aboutsummaryrefslogtreecommitdiffstats
path: root/common_gdal.py
diff options
context:
space:
mode:
Diffstat (limited to 'common_gdal.py')
-rw-r--r--common_gdal.py383
1 files changed, 383 insertions, 0 deletions
diff --git a/common_gdal.py b/common_gdal.py
new file mode 100644
index 0000000..b5570eb
--- /dev/null
+++ b/common_gdal.py
@@ -0,0 +1,383 @@
+#!/usr/bin/python3
+
+#----------------------------------------------------------------------
+# Backend utilities for the Klimatanalys Norr project (common GDAL functions)
+# Copyright © 2024-2025 Guilhem Moulin <info@guilhem.se>
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+#----------------------------------------------------------------------
+
+# pylint: disable=invalid-name, missing-module-docstring
+
+import logging
+import math
+import re
+from typing import Any, Optional
+
+from osgeo import gdal, ogr, osr
+
+from common import BadConfiguration, escape_identifier, getEscapedTableNamePG
+
# pylint: disable-next=redefined-builtin
def gdalVersionMin(maj : int = 0, min : int = 0, rev : int = 0) -> bool:
    """Tell whether the installed GDAL version is at least maj.min.rev.

    Requests for anything older than 1.10 are answered with True without
    looking at the installed library."""

    # The GDAL_VERSION_NUM() macro changed its encoding in 1.10 (released in
    # 2013); anything installed nowadays is assumed to be more recent.
    if (maj, min) < (1, 10):
        return True

    # same encoding as GDAL_COMPUTE_VERSION(maj,min,rev) in
    # gcore/gdal_version.h.in
    required = maj*1000000 + min*10000 + rev*100
    installed = int(gdal.VersionInfo())
    return installed >= required
+
def gdalGetMetadataItem(obj : gdal.MajorObject, k : str) -> bool:
    """Interpret the metadata item k of obj as a boolean flag.

    Return True only when obj.GetMetadataItem(k) yields a string equal to
    'YES' (compared case-insensitively); a missing item or any other value
    gives False."""

    value = obj.GetMetadataItem(k)
    # isinstance() already rules out None
    return isinstance(value, str) and value.upper() == 'YES'
+
def gdalSetOpenExArgs(option_dict : Optional[dict[str, Any]] = None,
                      flags : int = 0) -> tuple[dict[str, int|list[str]], Optional[gdal.Driver]]:
    """Return a pair (kwargs, driver) to use with gdal.OpenEx().

    option_dict may hold a 'format' key (driver name) and an 'open-options'
    key (mapping rendered as a list of 'KEY=VALUE' strings).  None is treated
    like an empty dictionary.  flags is passed through as 'nOpenFlags'.

    The returned driver is None when no 'format' was given.  Raise
    RuntimeError if the named driver is unknown or lacks the vector/raster
    capability requested in flags."""

    # never share a mutable default between calls, and accept an explicit None
    if option_dict is None:
        option_dict = {}

    kwargs : dict[str, int|list[str]] = { 'nOpenFlags': flags }

    drv = None
    fmt = option_dict.get('format', None)
    if fmt is not None:
        drv = gdal.GetDriverByName(fmt)
        if drv is None:
            raise RuntimeError(f'Unknown driver name "{fmt}"')
        if flags & gdal.OF_VECTOR and not gdalGetMetadataItem(drv, gdal.DCAP_VECTOR):
            raise RuntimeError(f'Driver "{drv.ShortName}" has no vector capabilities')
        if flags & gdal.OF_RASTER and not gdalGetMetadataItem(drv, gdal.DCAP_RASTER):
            raise RuntimeError(f'Driver "{drv.ShortName}" has no raster capabilities')
        # restrict gdal.OpenEx() to the requested driver
        kwargs['allowed_drivers'] = [ drv.ShortName ]

    oo = option_dict.get('open-options', None)
    if oo is not None:
        kwargs['open_options'] = [ f'{k}={v}' for k, v in oo.items() ]
    return kwargs, drv
+
def getSRS(srs_str : Optional[str]) -> Optional[osr.SpatialReference]:
    """Return the decoded Spatial Reference System.

    Only strings of the form 'EPSG:<code>' are understood.  Return None when
    srs_str is None.  Raise RuntimeError when the string is not understood
    or when the EPSG code cannot be imported."""

    if srs_str is None:
        return None

    if not srs_str.startswith('EPSG:'):
        raise RuntimeError(f'Unknown SRS {srs_str}')
    try:
        code = int(srs_str.removeprefix('EPSG:'))
    except ValueError as exc:
        # non-numeric code: report it like any other unsupported SRS string
        raise RuntimeError(f'Unknown SRS {srs_str}') from exc

    srs = osr.SpatialReference()
    # don't carry on with an empty SRS if the import fails; the failure is
    # only visible through the return code when GDAL exceptions are disabled
    if srs.ImportFromEPSG(code) != ogr.OGRERR_NONE:
        raise RuntimeError(f'Could not import SRS {srs_str}')

    logging.debug('Default SRS: "%s" (%s)', srs.ExportToProj4(), srs.GetName())
    return srs
+
def getExtent(extent : Optional[tuple[float, float, float, float]],
              srs : Optional[osr.SpatialReference] = None) -> Optional[ogr.Geometry]:
    """Convert extent (minX, minY, maxX, maxY) into a polygon and assign the
    given SRS.

    Return None when extent is None.  Raise RuntimeError when extent is not
    a 4-tuple, when an extent is given without SRS, or when the polygon
    cannot be transformed to the given SRS."""
    if extent is None:
        return None

    if not isinstance(extent, tuple) or len(extent) != 4:
        raise RuntimeError(f'Invalid extent {extent}')
    if srs is None:
        raise RuntimeError('Configured extent but no SRS')

    logging.debug('Configured extent in %s: %s', srs.GetName(),
                  ', '.join(map(str, extent)))

    minX, minY, maxX, maxY = extent
    ring = ogr.Geometry(ogr.wkbLinearRing)
    # closed ring: the first corner is repeated at the end
    for x, y in ((minX, minY), (maxX, minY), (maxX, maxY), (minX, maxY), (minX, minY)):
        ring.AddPoint_2D(x, y)

    polygon = ogr.Geometry(ogr.wkbPolygon)
    polygon.AddGeometry(ring)

    # we expressed extent as minX, minY, maxX, maxY (easting/northing
    # ordered, i.e., in traditional GIS order)
    srs2 = srs.Clone()
    srs2.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
    polygon.AssignSpatialReference(srs2)
    # check the outcome, like getSpatialFilterFromGeometry() does, rather
    # than silently returning an untransformed polygon
    if not srs2.IsSame(srs) and polygon.TransformTo(srs) != ogr.OGRERR_NONE:
        raise RuntimeError(f'Could not transform {polygon.ExportToWkt()} to {srs.GetName()}')

    return polygon
+
def getSpatialFilterFromGeometry(geom : ogr.Geometry, srs : osr.SpatialReference) -> ogr.Geometry:
    """Make the geometry suitable for use as a spatial filter in the given
    SRS.

    When the SRS of the geometry is not equivalent to srs, a copy of the
    geometry is densified first.  The geometry is then transformed to srs
    unless both SRS:s are already the same; the input geometry is never
    modified.  Raise RuntimeError if the transformation fails."""
    src_srs = geom.GetSpatialReference()
    work = None # private copy of geom, created only once a change is needed

    equivalent = src_srs.IsSame(srs, [ 'IGNORE_DATA_AXIS_TO_SRS_AXIS_MAPPING=YES',
                                       'CRITERION=EQUIVALENT'])
    if not equivalent:
        # densify the geometry (a rectangle) to avoid issues when reprojecting,
        # cf. apps/ogr2ogr_lib.cpp:ApplySpatialFilter()
        spacing_metre = 10 * 1000
        max_length = None
        if src_srs.IsGeographic():
            # math.radians(a) == a*pi/180, i.e. the semi-major axis value
            # scaled to one degree; the spacing is thus expressed in degrees
            max_length = spacing_metre / math.radians(src_srs.GetSemiMajor())
        elif src_srs.IsProjected():
            max_length = spacing_metre / src_srs.GetLinearUnits()

        if max_length is not None:
            work = geom.Clone()
            work.Segmentize(max_length)

    if src_srs.IsSame(srs):
        return geom if work is None else work

    if work is None:
        work = geom.Clone()
    if work.TransformTo(srs) != ogr.OGRERR_NONE:
        raise RuntimeError(f'Could not transform {work.ExportToWkt()} to {srs.GetName()}')
    return work
+
def formatTZFlag(tzFlag : int) -> str:
    """Pretty-print timezone flag, cf. ogr/ogrutils.cpp:OGRGetISO8601DateTime()

    Return 'none', 'local' or 'UTC' for the corresponding ogr.TZFLAG_*
    constants, and a signed '+HHMM' / '-HHMM' offset otherwise.  Raise
    RuntimeError when tzFlag is None."""
    if tzFlag is None:
        raise RuntimeError('formatTZFlag(None)')
    if tzFlag == ogr.TZFLAG_UNKNOWN:
        return 'none'
    if tzFlag == ogr.TZFLAG_LOCALTIME:
        return 'local'
    if tzFlag == ogr.TZFLAG_UTC:
        return 'UTC'

    # each unit away from TZFLAG_UTC is worth 15 minutes
    tzHour, tzMinute = divmod(abs(tzFlag - ogr.TZFLAG_UTC) * 15, 60)
    tzSign = '+' if tzFlag > ogr.TZFLAG_UTC else '-'
    return f'{tzSign}{tzHour:02}{tzMinute:02}'
+
def fromGeomTypeName(name : str) -> int:
    """Translate a geometry type name into an OGR geometry type, cf.
    ogr/ogrgeometry.cpp:OGRFromOGCGeomType()

    The name is matched case-insensitively.  A trailing 'M' marks the type
    as measured; a 'Z' found at the end once that 'M' is removed marks it
    as 3D.  An unrecognized base name maps to ogr.wkbUnknown, to which the
    Z/M modifiers are still applied."""
    base = name.upper()

    # modifiers are peeled off from the right: first M, then Z
    measured = base.endswith('M')
    if measured:
        base = base[:-1]
    threeD = base.endswith('Z')
    if threeD:
        base = base[:-1]

    # base name -> name of the ogr module attribute (resolved lazily below)
    attr = {
        'POINT':              'wkbPoint',
        'LINESTRING':         'wkbLineString',
        'POLYGON':            'wkbPolygon',
        'MULTIPOINT':         'wkbMultiPoint',
        'MULTILINESTRING':    'wkbMultiLineString',
        'MULTIPOLYGON':       'wkbMultiPolygon',
        'GEOMETRYCOLLECTION': 'wkbGeometryCollection',
        'CIRCULARSTRING':     'wkbCircularString',
        'COMPOUNDCURVE':      'wkbCompoundCurve',
        'CURVEPOLYGON':       'wkbCurvePolygon',
        'MULTICURVE':         'wkbMultiCurve',
        'MULTISURFACE':       'wkbMultiSurface',
        'TRIANGLE':           'wkbTriangle',
        'POLYHEDRALSURFACE':  'wkbPolyhedralSurface',
        'TIN':                'wkbTIN',
        'CURVE':              'wkbCurve',
        'SURFACE':            'wkbSurface',
    }.get(base, 'wkbUnknown')
    gtype = getattr(ogr, attr)

    if threeD:
        gtype = ogr.GT_SetZ(gtype)
    if measured:
        gtype = ogr.GT_SetM(gtype)
    return gtype
+
def parseGeomType(name : str|None) -> int:
    """Parse a configured geometry type name, cf. ogr2ogr_lib.cpp

    None yields ogr.wkbUnknown.  Names are case-insensitive and may carry a
    '25D' or 'Z' suffix to request a 3D type.  'NONE' and 'GEOMETRY' map to
    ogr.wkbNone and ogr.wkbUnknown respectively; any other name is handed
    to fromGeomTypeName().  Raise BadConfiguration if that lookup yields
    ogr.wkbUnknown."""
    if name is None:
        return ogr.wkbUnknown

    base = name.upper()
    want3D = True
    if base.endswith('25D'):
        base = base.removesuffix('25D') # alias
    elif base.endswith('Z'):
        base = base.removesuffix('Z')
    else:
        want3D = False

    if base == 'NONE':
        gtype = ogr.wkbNone
    elif base == 'GEOMETRY':
        gtype = ogr.wkbUnknown
    else:
        gtype = fromGeomTypeName(base)
        if gtype == ogr.wkbUnknown:
            raise BadConfiguration(f'Unknown geometry type "{name}"')

    # wkbNone is left alone even when a 3D suffix was given
    if want3D and gtype != ogr.wkbNone:
        gtype = ogr.GT_SetZ(gtype)
    return gtype
+
+
def parseFieldType(name : str) -> int:
    """Parse field type, cf. ogr/ogr_core.h's enum OGRFieldType

    The name is matched case-insensitively.  Raise RuntimeError when name
    is None and BadConfiguration when it is not a known field type."""
    if name is None:
        raise RuntimeError('parseFieldType(None)')

    # lower-cased name -> name of the ogr module attribute
    attr = {
        'integer':       'OFTInteger',       # simple 32bit integer
        'integerlist':   'OFTIntegerList',   # List of 32bit integers
        'real':          'OFTReal',          # Double Precision floating point
        'reallist':      'OFTRealList',      # List of doubles
        'string':        'OFTString',        # String of ASCII chars
        'stringlist':    'OFTStringList',    # Array of strings
        'binary':        'OFTBinary',        # Raw Binary data
        'date':          'OFTDate',          # Date
        'time':          'OFTTime',          # Time
        'datetime':      'OFTDateTime',      # Date and Time
        'integer64':     'OFTInteger64',     # Single 64bit integer
        'integer64list': 'OFTInteger64List', # List of 64bit integers
    }.get(name.lower())
    if attr is None:
        raise BadConfiguration(f'Unknown field type "{name}"')
    return getattr(ogr, attr)
+
def parseSubFieldType(name : str) -> int:
    """Parse subfield type, cf. ogr/ogr_core.h's enum OGRFieldSubType

    The name is matched case-insensitively.  Raise RuntimeError when name
    is None and BadConfiguration when it is not a known field subtype."""
    if name is None:
        raise RuntimeError('parseSubFieldType(None)')

    # lower-cased name -> name of the ogr module attribute
    attr = {
        # No subtype. This is the default value.
        'none':    'OFSTNone',
        # Boolean integer. Only valid for OFTInteger and OFTIntegerList.
        'bool':    'OFSTBoolean',
        # Signed 16-bit integer. Only valid for OFTInteger and OFTIntegerList.
        'int16':   'OFSTInt16',
        # Single precision (32 bit) floating point. Only valid for OFTReal and OFTRealList.
        'float32': 'OFSTFloat32',
        # JSON content. Only valid for OFTString.
        'json':    'OFSTJSON',
        # UUID string representation. Only valid for OFTString.
        'uuid':    'OFSTUUID',
    }.get(name.lower())
    if attr is None:
        raise BadConfiguration(f'Unknown field subtype "{name}"')
    return getattr(ogr, attr)
+
# Numeric UTC offset: optional 'UTC' prefix, optional sign, two-digit hours,
# optional colon, two-digit minutes.
TZ_RE = re.compile(r'(?:UTC\b)?([\+\-]?)([0-9][0-9]):?([0-9][0-9])', flags=re.IGNORECASE)
def parseTimeZone(tz : str) -> int:
    """Parse a timezone specification into an OGR timezone flag.

    Accepted values (case-insensitive) are 'none', 'local', 'utc', 'gmt',
    or a numeric offset matching TZ_RE.  The offset must be a multiple of
    15 minutes with hours no greater than 14.  Raise RuntimeError when tz
    is None and BadConfiguration when it cannot be parsed."""
    if tz is None:
        raise RuntimeError('parseTimeZone(None)')

    keyword = tz.lower()
    if keyword == 'none':
        return ogr.TZFLAG_UNKNOWN
    if keyword == 'local':
        return ogr.TZFLAG_LOCALTIME
    if keyword in ('utc', 'gmt'):
        return ogr.TZFLAG_UTC

    parsed = TZ_RE.fullmatch(tz)
    if parsed is None:
        raise BadConfiguration(f'Invalid timezone "{tz}"')

    sign, hours, minutes = parsed.groups()
    hours = int(hours)
    minutes = int(minutes)
    if hours > 14 or minutes >= 60 or minutes % 15 != 0:
        raise BadConfiguration(f'Invalid timezone "{tz}"')

    # the flag counts quarters of an hour away from TZFLAG_UTC
    quarters = hours*4 + minutes // 15
    if sign == '-':
        return ogr.TZFLAG_UTC - quarters
    return ogr.TZFLAG_UTC + quarters
+
def getEscapedTableName(lyr : ogr.Layer, extract_schema_from_layer_name : bool = True) -> str:
    """Return the name of the layer as an escaped identifier, suitable for
    injection into SQL queries.

    Layers whose dataset driver is 'PostgreSQL' are escaped with
    getEscapedTableNamePG(), to which extract_schema_from_layer_name
    (default: True) is forwarded; that flag indicates whether the first dot
    character separates the schema from the table name.  Layers from any
    other driver are escaped with escape_identifier()."""
    name = lyr.GetName()
    driver_name = lyr.GetDataset().GetDriver().ShortName
    if driver_name != 'PostgreSQL':
        return escape_identifier(name)
    return getEscapedTableNamePG(name, extract_schema_from_layer_name)
+
def executeSQL(ds : gdal.Dataset, statement : str,
               spatialFilter : Optional[ogr.Geometry] = None) -> ogr.Layer|None:
    """Log the statement (and the spatial filter as WKT, if any) at debug
    level, then run it through gdal.Dataset.ExecuteSQL() and return the
    result.
    https://gdal.org/en/stable/api/python/raster_api.html#osgeo.gdal.Dataset.ExecuteSQL"""
    details = [ statement ]
    if spatialFilter is not None:
        details.append('spatialFilter=' + spatialFilter.ExportToWkt())
    logging.debug('ExecuteSQL(%s)', ', '.join(details))
    return ds.ExecuteSQL(statement, spatialFilter=spatialFilter)