aboutsummaryrefslogtreecommitdiffstats
path: root/import_source.py
diff options
context:
space:
mode:
Diffstat (limited to 'import_source.py')
-rw-r--r--import_source.py1153
1 files changed, 1153 insertions, 0 deletions
diff --git a/import_source.py b/import_source.py
new file mode 100644
index 0000000..f802a57
--- /dev/null
+++ b/import_source.py
@@ -0,0 +1,1153 @@
+#!/usr/bin/python3
+
+#----------------------------------------------------------------------
+# Backend utilities for the Klimatanalys Norr project (import source layers)
+# Copyright © 2024-2025 Guilhem Moulin <info@guilhem.se>
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+#----------------------------------------------------------------------
+
+# pylint: disable=invalid-name, missing-module-docstring, fixme
+
+import logging
+import tempfile
+import re
+from fnmatch import fnmatchcase
+from pathlib import Path
+from datetime import datetime, timedelta, UTC
+from typing import Any, Callable, Final, Iterator, Optional
+import traceback
+from enum import Enum, unique as enum_unique
+from hashlib import sha256
+import struct
+
+from osgeo import gdal, ogr, osr
+from osgeo.gdalconst import (
+ OF_ALL as GDAL_OF_ALL,
+ OF_READONLY as GDAL_OF_READONLY,
+ OF_UPDATE as GDAL_OF_UPDATE,
+ OF_VECTOR as GDAL_OF_VECTOR,
+ OF_VERBOSE_ERROR as GDAL_OF_VERBOSE_ERROR,
+ DCAP_CREATE as GDAL_DCAP_CREATE,
+)
+from osgeo import gdalconst
+
+from common import BadConfiguration, escape_identifier, escape_literal_str
+from common_gdal import (
+ gdalSetOpenExArgs,
+ gdalGetMetadataItem,
+ formatTZFlag,
+ getSpatialFilterFromGeometry,
+ getEscapedTableName,
+ executeSQL,
+)
+
def openOutputDS(def_dict : dict[str, Any]) -> gdal.Dataset:
    """Open and return the output DS. It is created if create=True or
    create-options is a non-empty dictionary.

    def_dict holds the datasource definition: 'path' (mandatory), plus the
    optional 'create' boolean and 'create-options' dictionary used when the
    datasource doesn't exist yet.  Re-raises the original OpenEx() error
    unless creation is both configured and supported by the driver."""

    path = def_dict['path']
    kwargs, drv = gdalSetOpenExArgs(def_dict,
        flags=GDAL_OF_VECTOR|GDAL_OF_UPDATE|GDAL_OF_VERBOSE_ERROR)
    try:
        logging.debug('OpenEx(%s, %s)', path, str(kwargs))
        return gdal.OpenEx(path, **kwargs)
    except RuntimeError as e:
        if not (gdal.GetLastErrorType() >= gdalconst.CE_Failure and
                gdal.GetLastErrorNo() == gdalconst.CPLE_OpenFailed):
            # not an open error
            raise e

        # Probe whether the path exists at all: first with any driver in
        # update mode, then read-only.  If either succeeds, the failure above
        # was not "missing datasource" and must not be masked by creation.
        dso2 = None
        try:
            dso2 = gdal.OpenEx(path, nOpenFlags=GDAL_OF_ALL | GDAL_OF_UPDATE)
        except RuntimeError:
            pass
        if dso2 is not None:
            # path exists but can't be open with OpenEx(path, **kwargs)
            raise e

        try:
            dso2 = gdal.OpenEx(path, nOpenFlags=GDAL_OF_ALL)
        except RuntimeError:
            pass
        if dso2 is not None:
            # path exists but can't be open with OpenEx(path, **kwargs)
            raise e

        dsco = def_dict.get('create-options', None)
        if not def_dict.get('create', False) and dsco is None:
            # not configured for creation
            raise e
        if drv is None or not gdalGetMetadataItem(drv, GDAL_DCAP_CREATE):
            # not capable of creation
            raise e

        if 'open_options' in kwargs:
            # like ogr2ogr(1)
            logging.warning('Destination\'s open options ignored '
                            'when creating the output datasource')

        # vector datasource: eType=GDT_Unknown and zero raster dimensions
        kwargs2 = { 'eType': gdalconst.GDT_Unknown }
        if dsco is not None:
            kwargs2['options'] = [ k + '=' + str(v) for k, v in dsco.items() ]

        logging.debug('Create(%s, %s, eType=%s%s)', drv.ShortName, path, kwargs2['eType'],
                      ', options=' + str(kwargs2['options']) if 'options' in kwargs2 else '')
        # XXX racy, a GDAL equivalent of O_EXCL would be nice
        return drv.Create(path, 0, 0, 0, **kwargs2)
+
def createOutputLayer(ds : gdal.Dataset,
                      layername : str,
                      srs : Optional[osr.SpatialReference] = None,
                      options : dict[str, Any]|None = None) -> ogr.Layer:
    """Create output layer.

    options must hold the layer schema: 'geometry-type' (mandatory),
    'fields' (list of field definitions, mandatory) and optional layer
    creation 'options'.  Raises BadConfiguration when no schema is given
    and RuntimeError when layer or field creation fails."""

    if options is None or len(options) < 1:
        raise BadConfiguration(f'Missing schema for new output layer "{layername}"')

    logging.info('Creating new destination layer "%s"', layername)
    geom_type = options['geometry-type']
    lco = options.get('options', None)

    drv = ds.GetDriver()
    if geom_type != ogr.wkbNone and drv.ShortName == 'PostgreSQL':
        # “Important to set to 2 for 2D layers as it has constraints on the geometry
        # dimension during loading.”
        # — https://gdal.org/drivers/vector/pg.html#layer-creation-options
        if ogr.GT_HasM(geom_type):
            if ogr.GT_HasZ(geom_type):
                dim = 'XYZM'
            else:
                dim = 'XYM'
        elif ogr.GT_HasZ(geom_type):
            dim = '3'
        else:
            dim = '2'
        if lco is None:
            lco = []
        lco = ['DIM=' + dim] + lco # prepend DIM=

    kwargs = { 'geom_type': geom_type }
    if srs is not None:
        kwargs['srs'] = srs
    if lco is not None:
        kwargs['options'] = lco
    logging.debug('CreateLayer(%s, geom_type="%s"%s%s)', layername,
                  ogr.GeometryTypeToName(geom_type),
                  ', srs="' + kwargs['srs'].GetName() + '"' if 'srs' in kwargs else '',
                  ', options=' + str(kwargs['options']) if 'options' in kwargs else '')
    lyr = ds.CreateLayer(layername, **kwargs)
    if lyr is None:
        raise RuntimeError(f'Could not create destination layer "{layername}"')

    fields = options['fields']
    if len(fields) > 0 and not lyr.TestCapability(ogr.OLCCreateField):
        raise RuntimeError(f'Destination layer "{layername}" lacks field creation capability')

    # set up output schema; each entry supports the usual OGR FieldDefn
    # attributes, applied only when present in the configuration
    for fld in fields:
        fldName = fld['Name']
        defn = ogr.FieldDefn()
        defn.SetName(fldName)

        if 'AlternativeName' in fld:
            v = fld['AlternativeName']
            logging.debug('Set AlternativeName="%s" on output field "%s"', str(v), fldName)
            defn.SetAlternativeName(v)

        if 'Comment' in fld:
            v = fld['Comment']
            logging.debug('Set Comment="%s" on output field "%s"', str(v), fldName)
            defn.SetComment(v)

        if 'Type' in fld:
            v = fld['Type']
            logging.debug('Set Type=%d (%s) on output field "%s"',
                          v, ogr.GetFieldTypeName(v), fldName)
            defn.SetType(v)

        if 'SubType' in fld:
            v = fld['SubType']
            logging.debug('Set SubType=%d (%s) on output field "%s"',
                          v, ogr.GetFieldSubTypeName(v), fldName)
            defn.SetSubType(v)

        if 'TZFlag' in fld:
            v = fld['TZFlag']
            logging.debug('Set TZFlag=%d (%s) on output field "%s"',
                          v, formatTZFlag(v), fldName)
            defn.SetTZFlag(v)

        if 'Precision' in fld:
            v = fld['Precision']
            logging.debug('Set Precision=%d on output field "%s"', v, fldName)
            defn.SetPrecision(v)

        if 'Width' in fld:
            v = fld['Width']
            logging.debug('Set Width=%d on output field "%s"', v, fldName)
            defn.SetWidth(v)

        if 'Default' in fld:
            v = fld['Default']
            logging.debug('Set Default=%s on output field "%s"', v, fldName)
            defn.SetDefault(v)

        if 'Nullable' in fld:
            v = fld['Nullable']
            logging.debug('Set Nullable=%s on output field "%s"', v, fldName)
            defn.SetNullable(v)

        if 'Unique' in fld:
            v = fld['Unique']
            logging.debug('Set Unique=%s on output field "%s"', v, fldName)
            defn.SetUnique(v)

        if lyr.CreateField(defn, approx_ok=False) != gdalconst.CE_None:
            raise RuntimeError(f'Could not create field "{fldName}"')
        logging.debug('Added field "%s" to output layer "%s"', fldName, layername)

    if lyr.TestCapability(ogr.OLCAlterGeomFieldDefn):
        # it appears using .CreateLayerFromGeomFieldDefn() on a non-nullable
        # GeomFieldDefn doesn't do anything, so we alter it after the fact instead
        # (GPKG doesn't support this, use GEOMETRY_NULLABLE=NO in layer creation
        # options instead)
        flags = drv.GetMetadataItem(gdal.DMD_ALTER_GEOM_FIELD_DEFN_FLAGS)
        if flags is not None and 'nullable' in flags.lower().split(' '):
            geom_field = ogr.GeomFieldDefn(None, geom_type)
            geom_field.SetNullable(False)
            lyr.AlterGeomFieldDefn(0, geom_field, ogr.ALTER_GEOM_FIELD_DEFN_NULLABLE_FLAG)

    # TODO evaluate use external storage not main storage for geometries
    # https://blog.cleverelephant.ca/2018/09/postgis-external-storage.html

    # sync before calling StartTransaction() so we're not trying to rollback changes
    # on a non-existing table
    lyr.SyncToDisk()
    return lyr
+
def clusterLayer(lyr : ogr.Layer,
                 index_name : Optional[str] = None,
                 column_name : Optional[str] = None,
                 analyze : bool = True) -> bool:
    """Cluster a table according to an index. If no index name is given and a column name is
    given instead, then the index involving the least number of columns (including the given
    column) is chosen. If neither index name or column name is given, then recluster the table
    using the same index as before. An optional boolean (default value: True) indicates whether
    to ANALYZE the table after clustering. See
    https://www.postgresql.org/docs/current/sql-cluster.html .
    Requires that the dataset driver is PostgreSQL.

    Returns True when the CLUSTER was issued, False when skipped (non-PostgreSQL
    dataset, or no suitable index found for column_name)."""
    ds = lyr.GetDataset()
    if ds.GetDriver().ShortName != 'PostgreSQL':
        logging.warning('clusterLayer() called on a non-PostgreSQL dataset, ignoring')
        return False

    layername_esc = getEscapedTableName(lyr)
    if index_name is None and column_name is not None:
        # find out which indices involve lyr's column_name
        # (queries pg_index/pg_attribute to list each index with its columns)
        with executeSQL(ds, statement='WITH indices AS ('
                            'SELECT i.relname AS index, array_agg(a.attname) AS columns '
                            'FROM pg_class t, pg_class i, pg_index ix, pg_attribute a '
                            'WHERE t.oid = ix.indrelid '
                            'AND i.oid = ix.indexrelid '
                            'AND a.attrelid = t.oid '
                            'AND a.attnum = ANY(ix.indkey) '
                            'AND t.relkind = \'r\' '
                            'AND ix.indrelid = ' + escape_literal_str(layername_esc) + '::regclass '
                            'GROUP BY 1) '
                        'SELECT index, array_length(columns, 1) AS len '
                        'FROM indices '
                        'WHERE ' + escape_literal_str(column_name) + ' = ANY(columns)'
                        # pick the index involving the least number of columns
                        'ORDER BY 2,1 LIMIT 1') as res:
            defn = res.GetLayerDefn()
            i = defn.GetFieldIndex('index')
            row = res.GetNextFeature()
            if row is not None and row.IsFieldSetAndNotNull(i):
                index_name = row.GetFieldAsString(i)

        if index_name is None:
            logging.warning('Layer %s has no index on column %s, cannot CLUSTER',
                            lyr.GetName(), column_name)
            return False

    # without USING, PostgreSQL reclusters on the previously-used index
    statement = 'CLUSTER ' + layername_esc
    if index_name is not None:
        statement += ' USING ' + escape_identifier(index_name)
    executeSQL(ds, statement=statement)

    if analyze:
        # "Because the planner records statistics about the ordering of tables, it is
        # advisable to run ANALYZE on the newly clustered table. Otherwise, the planner
        # might make poor choices of query plans."
        executeSQL(ds, statement='ANALYZE ' + layername_esc)

    return True
+
# pylint: disable-next=too-many-branches
def validateOutputLayer(lyr : ogr.Layer,
                        srs : Optional[osr.SpatialReference] = None,
                        options : Optional[dict[str, Any]] = None) -> bool:
    """Validate the output layer against the provided SRS and creation options.

    Logs a warning for every mismatch found (SRS, geometry type, field
    attributes) and returns False if any was found, True otherwise.
    Raises RuntimeError when the layer has no geometry field at all."""
    ok = True

    # ensure the output SRS is equivalent
    if srs is not None:
        srs2 = lyr.GetSpatialRef()
        # cf. apps/ogr2ogr_lib.cpp
        srs_options = [
            'IGNORE_DATA_AXIS_TO_SRS_AXIS_MAPPING=YES',
            'CRITERION=EQUIVALENT'
        ]
        if not srs.IsSame(srs2, srs_options):
            logging.warning('Output layer "%s" has SRS %s,\nexpected %s',
                            lyr.GetName(),
                            srs2.ExportToPrettyWkt(),
                            srs.ExportToPrettyWkt())
            ok = False

    if options is None:
        return ok

    layerDefn = lyr.GetLayerDefn()
    n = layerDefn.GetGeomFieldCount()
    if n != 1:
        if n == 0:
            raise RuntimeError(f'Output layer "{lyr.GetName()}" has no geometry fields')
        logging.warning('Output layer "%s" has %d != 1 geometry fields', lyr.GetName(), n)

    # only the first geometry field is validated
    iGeomField = 0
    geomField = layerDefn.GetGeomFieldDefn(iGeomField)
    geomType = geomField.GetType()
    logging.debug('Geometry column #%d: name="%s\", type="%s", srs=%s, nullable=%s',
                  iGeomField, geomField.GetName(),
                  ogr.GeometryTypeToName(geomType),
                  '-' if geomField.GetSpatialRef() is None
                      else '"' + geomField.GetSpatialRef().GetName() + '"',
                  bool(geomField.IsNullable()))
    if geomField.IsNullable():
        logging.warning('Geometry column #%d "%s" of output layer "%s" is nullable',
                        iGeomField, geomField.GetName(), lyr.GetName())

    geomType2 = options['geometry-type']
    if geomType != geomType2:
        logging.warning('Output layer "%s" has geometry type #%d (%s), expected #%d (%s)',
                        lyr.GetName(),
                        geomType, ogr.GeometryTypeToName(geomType),
                        geomType2, ogr.GeometryTypeToName(geomType2))
        ok = False

    # compare each configured field attribute with the live layer definition
    fields = options.get('fields', None)
    if fields is not None:
        for fld in fields:
            fldName = fld['Name']

            idx = layerDefn.GetFieldIndex(fldName)
            if idx < 0:
                logging.warning('Output layer "%s" has no field named "%s"',
                                lyr.GetName(), fldName)
                ok = False
                continue
            defn = layerDefn.GetFieldDefn(idx)

            if 'AlternativeName' in fld:
                v1 = defn.GetAlternativeName()
                v2 = fld['AlternativeName']
                if v1 != v2:
                    logging.warning('Field "%s" has AlternativeName="%s", expected "%s"',
                                    fldName, v1, v2)
                    ok = False

            if 'Comment' in fld:
                v1 = defn.GetComment()
                v2 = fld['Comment']
                if v1 != v2:
                    logging.warning('Field "%s" has Comment="%s", expected "%s"',
                                    fldName, v1, v2)
                    ok = False

            if 'Type' in fld:
                v1 = defn.GetType()
                v2 = fld['Type']
                if v1 != v2:
                    logging.warning('Field "%s" has Type=%d (%s), expected %d (%s)',
                                    fldName,
                                    v1, ogr.GetFieldTypeName(v1),
                                    v2, ogr.GetFieldTypeName(v2))
                    ok = False

            if 'SubType' in fld:
                v1 = defn.GetSubType()
                v2 = fld['SubType']
                if v1 != v2:
                    logging.warning('Field "%s" has SubType=%d (%s), expected %d (%s)',
                                    fldName,
                                    v1, ogr.GetFieldSubTypeName(v1),
                                    v2, ogr.GetFieldSubTypeName(v2))
                    ok = False

            if 'TZFlag' in fld:
                v1 = defn.GetTZFlag()
                v2 = fld['TZFlag']
                if v1 != v2:
                    logging.warning('Field "%s" has TZFlag=%d (%s), expected %d (%s)',
                                    fldName, v1, formatTZFlag(v1), v2, formatTZFlag(v2))
                    ok = False

            if 'Precision' in fld:
                v1 = defn.GetPrecision()
                v2 = fld['Precision']
                if v1 != v2:
                    logging.warning('Field "%s" has Precision=%d, expected %d',
                                    fldName, v1, v2)
                    ok = False

            if 'Width' in fld:
                v1 = defn.GetWidth()
                v2 = fld['Width']
                if v1 != v2:
                    logging.warning('Field "%s" has Width=%d, expected %d',
                                    fldName, v1, v2)
                    ok = False

            if 'Default' in fld:
                v1 = defn.GetDefault()
                v2 = fld['Default']
                if v1 != v2:
                    logging.warning('Field "%s" has Default="%s", expected "%s"',
                                    fldName, v1, v2)
                    ok = False

            if 'Nullable' in fld:
                v1 = bool(defn.IsNullable())
                v2 = fld['Nullable']
                if v1 != v2:
                    logging.warning('Field "%s" has Nullable=%s, expected %s',
                                    fldName, v1, v2)
                    ok = False

            if 'Unique' in fld:
                v1 = bool(defn.IsUnique())
                v2 = fld['Unique']
                if v1 != v2:
                    logging.warning('Field "%s" has Unique=%s, expected %s',
                                    fldName, v1, v2)
                    ok = False

    return ok
+
def clearLayer(lyr : ogr.Layer, identity : str = 'CONTINUE IDENTITY') -> None:
    """Clear the given layer (wipe all its features).

    Uses TRUNCATE on PostgreSQL (with the given identity clause, either
    'CONTINUE IDENTITY' or 'RESTART IDENTITY') and a plain DELETE otherwise.
    No-op when the layer is already known to be empty."""
    n = -1
    if lyr.TestCapability(ogr.OLCFastFeatureCount):
        n = lyr.GetFeatureCount(force=0)
    if n == 0:
        # nothing to clear, we're good
        return

    ds = lyr.GetDataset()
    if ds.GetDriver().ShortName == 'PostgreSQL':
        # https://www.postgresql.org/docs/15/sql-truncate.html
        statement = 'TRUNCATE TABLE {table} ' + identity + ' CASCADE'
        op = 'Truncating'
    else:
        statement = 'DELETE FROM {table}'
        op = 'Clearing'
    logging.info('%s table %s (former feature count: %s)', op,
                 lyr.GetName(), str(n) if n >= 0 else 'unknown')
    executeSQL(ds, statement=statement.format(table=getEscapedTableName(lyr)))
+
def extractArchive(path : Path, destdir : str,
                   fmt : str|None = None,
                   patterns : list[str]|None = None,
                   exact_matches : list[str]|None = None) -> None:
    """Extract an archive file into the given destination directory.

    When fmt is None the format is inferred from the file suffix.  Only
    members selected by listArchiveMembers() (via patterns/exact_matches)
    are extracted.  Currently only the 'zip' format is supported; any
    other raises RuntimeError."""
    if fmt is None:
        suffix = path.suffix
        if suffix is None or suffix == '' or not suffix.startswith('.'):
            raise RuntimeError(f'Could not infer archive format from "{path}"')
        fmt = suffix.removeprefix('.')

    fmt = fmt.lower()
    logging.debug('Unpacking %s archive %s into %s', fmt, path, destdir)

    if fmt == 'zip':
        # deferred import: zipfile is only needed for zip archives
        import zipfile # pylint: disable=import-outside-toplevel
        logging.debug('Opening %s as ZipFile', path)
        with zipfile.ZipFile(path, mode='r') as z:
            namelist = listArchiveMembers(z.namelist(),
                                          patterns=patterns,
                                          exact_matches=exact_matches)
            z.extractall(path=destdir, members=namelist)
    else:
        raise RuntimeError(f'Unknown archive format "{fmt}"')
+
+def listArchiveMembers(namelist : list[str],
+ patterns : list[str]|None = None,
+ exact_matches : list[str]|None = None) -> list[str]:
+ """List archive members matching the given parterns and/or exact matches."""
+ if patterns is None and exact_matches is None:
+ # if neither patterns nor exact_matches are given we'll extract the entire archive
+ return namelist
+ if patterns is None:
+ patterns = []
+ if exact_matches is None:
+ exact_matches = []
+
+ members = []
+ for name in namelist:
+ ok = False
+ if name in exact_matches:
+ # try exact matches first
+ logging.debug('Listed archive member %s (exact match)', name)
+ members.append(name)
+ ok = True
+ continue
+ # if there are no exact matches, try patterns one by one in the supplied order
+ for pat in patterns:
+ if fnmatchcase(name, pat):
+ logging.debug('Listed archive member %s (matching pattern "%s")', name, pat)
+ members.append(name)
+ ok = True
+ break
+ if not ok:
+ logging.debug('Ignoring archive member %s', name)
+ return members
+
@enum_unique
class ImportStatus(Enum):
    """Return value for importSources(): success, error, or no-change."""
    IMPORT_SUCCESS = 0
    IMPORT_ERROR = 1
    IMPORT_NOCHANGE = 255

    def __str__(self) -> str:
        # render as the bare status word, e.g. "SUCCESS" for IMPORT_SUCCESS
        prefix = 'IMPORT_'
        name = self.name
        if name.startswith(prefix):
            return name[len(prefix):]
        return name
+
# pylint: disable-next=dangerous-default-value
def importSources(lyr : ogr.Layer,
                  sources : dict[str,Any] = {},
                  cachedir : Path|None = None,
                  extent : ogr.Geometry|None = None,
                  dsoTransaction : bool = False,
                  lyrcache : ogr.Layer|None = None,
                  force : bool = False,
                  cluster_geometry : bool = False) -> ImportStatus:
    """Clear lyr and import source layers to it.

    The truncation, imports and layer-cache update run under a single
    transaction scope: a SAVEPOINT nested in the caller's DS-level
    transaction when dsoTransaction is True, otherwise a layer-level
    transaction when supported.  Exceptions are caught, logged, rolled
    back and reported as IMPORT_ERROR."""

    dso = lyr.GetDataset()
    layername = lyr.GetName()
    # lyrTransaction encodes the active protection: a str ('SAVEPOINT <name>')
    # for a nested transaction, True for a layer-level transaction, False/None
    # for no transaction at all.
    if dsoTransaction:
        # declare a SAVEPOINT (nested transaction) within the DS-level transaction
        lyrTransaction = 'SAVEPOINT ' + escape_identifier('savept_' + layername)
        executeSQL(dso, lyrTransaction)
    elif lyr.TestCapability(ogr.OLCTransactions):
        # try to start transaction on the layer
        logging.debug('Starting transaction on output layer "%s"', layername)
        lyrTransaction = lyr.StartTransaction() == ogr.OGRERR_NONE
        if not lyrTransaction:
            logging.warning('Couldn\'t start transaction on output layer "%s"', layername)
    else:
        logging.warning('Unsafe update, output layer "%s" doesn\'t support transactions',
                        layername)
        lyrTransaction = False

    rv = ImportStatus.IMPORT_NOCHANGE
    now = datetime.now().astimezone()
    try:
        clearLayer(lyr) # TODO conditional (only if not new)?

        for source in sources:
            importSource0(lyr, **source['source'],
                          args=source['import'],
                          cachedir=cachedir,
                          extent=extent,
                          callback=_importSource2)

        # force the PG driver to call EndCopy() to detect errors and trigger a
        # rollback if needed
        dso.FlushCache()

        if lyrcache is None:
            rv = ImportStatus.IMPORT_SUCCESS
        elif updateLayerCache(cache=lyrcache,
                              lyr=lyr,
                              force=force,
                              lyrTransaction=lyrTransaction,
                              last_updated=now):
            # NOTE(review): updateLayerCache() is defined further down the file
            # (outside this view); presumably returns False when the layer
            # fingerprint is unchanged — confirm against its definition.
            rv = ImportStatus.IMPORT_SUCCESS
        else:
            rv = ImportStatus.IMPORT_NOCHANGE
            if isinstance(lyrTransaction, bool):
                # the transaction on lyr was already rolled back
                lyrTransaction = False

        if (rv == ImportStatus.IMPORT_SUCCESS and cluster_geometry
                and lyr.GetLayerDefn().GetGeomType() != ogr.wkbNone):
            clusterLayer(lyr, column_name=lyr.GetGeometryColumn())

    except Exception: # pylint: disable=broad-exception-caught
        rv = ImportStatus.IMPORT_ERROR
        if isinstance(lyrTransaction, str):
            # yields 'ROLLBACK TO SAVEPOINT <name>' since lyrTransaction
            # starts with 'SAVEPOINT '
            statement = 'ROLLBACK TO ' + lyrTransaction
            logging.exception('Exception occured within transaction')
            # don't unset lyrTransaction here as we want to RELEASE SAVEPOINT
            try:
                executeSQL(dso, statement=statement)
            except Exception: # pylint: disable=broad-exception-caught
                logging.exception('Could not execute SQL: %s', statement)
        elif isinstance(lyrTransaction, bool) and lyrTransaction:
            logging.exception('Exception occured within transaction on output '
                              'layer "%s": ROLLBACK', layername)
            # unset before RollbackTransaction() so the finally block won't commit
            lyrTransaction = None
            try:
                if lyr.RollbackTransaction() != ogr.OGRERR_NONE:
                    logging.error('Could not rollback transaction on layer "%s"', layername)
            except Exception: # pylint: disable=broad-exception-caught
                logging.exception('Could not rollback transaction on layer "%s"', layername)
        else:
            traceback.print_exc()

    finally:
        if isinstance(lyrTransaction, str):
            # yields 'RELEASE SAVEPOINT <name>'
            statement = 'RELEASE ' + lyrTransaction
            try:
                executeSQL(dso, statement)
            except Exception: # pylint: disable=broad-exception-caught
                rv = ImportStatus.IMPORT_ERROR
                logging.exception('Could not execute SQL: %s', statement)
        elif isinstance(lyrTransaction, bool) and lyrTransaction:
            try:
                if lyr.CommitTransaction() != ogr.OGRERR_NONE:
                    rv = ImportStatus.IMPORT_ERROR
                    logging.error('Could not commit transaction')
            except Exception: # pylint: disable=broad-exception-caught
                rv = ImportStatus.IMPORT_ERROR
                logging.exception('Could not commit transaction on layer "%s"', layername)
    return rv
+
# pylint: disable-next=dangerous-default-value
def importSource0(lyr : ogr.Layer|None = None,
                  path : str = '/nonexistent',
                  unar : dict[str,Any]|None = None,
                  args : dict[str,Any] = {},
                  cachedir : Path|None = None,
                  extent : ogr.Geometry|None = None,
                  callback : Callable[[ogr.Layer|None, str, dict[str,Any], Path|None,
                                       ogr.Geometry|None], None]|None = None) -> None:
    """Import a source layer.

    When unar is None the (possibly cached) path is handed to the callback
    directly; otherwise the required member is first extracted from the
    archive into a temporary directory.

    NOTE(review): callback is annotated Optional but called unconditionally —
    passing callback=None raises TypeError; callers here always supply one."""
    if unar is None:
        return callback(lyr, path, args=args, basedir=cachedir, extent=extent)

    ds_srcpath = Path(args['path'])
    if ds_srcpath.is_absolute():
        # treat absolute paths as relative to the archive root
        logging.warning('%s is absolute, removing leading anchor', ds_srcpath)
        ds_srcpath = ds_srcpath.relative_to(ds_srcpath.anchor)

    ds_srcpath = str(ds_srcpath)
    with tempfile.TemporaryDirectory() as tmpdir:
        logging.debug('Created temporary directory %s', tmpdir)
        # extract only the member we need, then import it from tmpdir;
        # the temporary directory is removed when the with-block exits
        extractArchive(Path(path) if cachedir is None else cachedir.joinpath(path),
                       tmpdir,
                       fmt=unar.get('format', None),
                       patterns=unar.get('patterns', None),
                       exact_matches=[ds_srcpath])
        return callback(lyr, ds_srcpath, args=args, basedir=Path(tmpdir), extent=extent)
+
def setFieldMapValue(fld : ogr.FieldDefn,
                     idx : int,
                     val : None|int|str|bytes|float) -> None|int|str|bytes|float:
    """Validate a field value mapping entry against the field's OGR type.

    Returns the (possibly coerced) value when compatible with fld's type,
    warns when NULL is mapped on a non-NULLable field, and raises
    RuntimeError on a type mismatch."""
    if val is None:
        # NULL is always returned as-is, but flag non-NULLable fields
        if not fld.IsNullable():
            logging.warning('Field "%s" is not NULLable but remaps NULL', fld.GetName())
        return None

    fldType = fld.GetType()
    # integers are acceptable for real fields, but get coerced to float
    if fldType == ogr.OFTReal and isinstance(val, int):
        return float(val)

    # expected python type for each supported OGR field type
    compatible = {
        ogr.OFTInteger: int,
        ogr.OFTInteger64: int,
        ogr.OFTString: str,
        ogr.OFTBinary: bytes,
        ogr.OFTReal: float,
    }
    pyType = compatible.get(fldType, None)
    if pyType is not None and isinstance(val, pyType):
        return val

    raise RuntimeError(f'Field "{fld.GetName()}" mapping #{idx} has incompatible type '
                       f'for {ogr.GetFieldTypeName(fldType)}')
+
def _importSource2(lyr_dst : ogr.Layer, path : str, args : dict[str,Any],
                   basedir : Path|None, extent : ogr.Geometry|None) -> None:
    """Import a source layer (already extracted)
    This is more or less like ogr2ogr/GDALVectorTranslate() but we roll
    out our own (slower) version because GDALVectorTranslate() insists in
    calling StartTransaction() https://github.com/OSGeo/gdal/issues/3403
    while we want a single transaction for the entire destination layer,
    including truncation, source imports, and metadata changes."""
    kwargs, _ = gdalSetOpenExArgs(args, flags=GDAL_OF_VECTOR|GDAL_OF_READONLY|GDAL_OF_VERBOSE_ERROR)
    path2 = path if basedir is None else str(basedir.joinpath(path))

    logging.debug('OpenEx(%s, %s)', path2, str(kwargs))
    ds = gdal.OpenEx(path2, **kwargs)
    if ds is None:
        raise RuntimeError(f'Could not open {path2}')

    # select the source layer: by name when configured, else the first layer
    layername = args.get('layername', None)
    if layername is None:
        idx = 0
        lyr = ds.GetLayerByIndex(idx)
        msg = '#' + str(idx)
        if lyr is not None:
            layername = lyr.GetName()
            msg += ' ("' + layername + '")'
    else:
        lyr = ds.GetLayerByName(layername)
        msg = '"' + layername + '"'
    if lyr is None:
        raise RuntimeError(f'Could not get requested layer {msg} from {path2}')

    logging.info('Importing layer %s from "%s"', msg, path)
    importLayer(lyr_dst, lyr, args=args, extent=extent)
+
# pylint: disable-next=too-many-branches, too-many-statements, dangerous-default-value
def importLayer(lyr_dst : ogr.Layer, lyr : ogr.Layer,
                args : dict[str,Any] = {},
                extent : Optional[ogr.Geometry] = None) -> None:
    """Import a source layer (already opened).

    Transfers every feature from lyr to lyr_dst, applying (in order):
    optional spatial filtering to extent, field mapping (args['field-map']),
    optional right-stripping of string fields (args['rstrip-strings']),
    value mapping (args['value-map']), coordinate transformation, and
    geometry-type promotion to the destination type."""
    layername = lyr.GetName()
    srs = lyr.GetSpatialRef()
    if srs is None:
        raise RuntimeError(f'Source layer {layername} has no SRS')

    # decide whether a coordinate transformation is needed (ct is None if not)
    srs_dst = lyr_dst.GetSpatialRef()
    if srs_dst is None:
        logging.warning('Destination has no SRS, skipping coordinate transformation')
        ct = None
    elif srs_dst.IsSame(srs):
        logging.debug('Both source and destination have the same SRS (%s), '
                      'skipping coordinate transformation',
                      srs_dst.GetName())
        ct = None
    else:
        # TODO Use GetSupportedSRSList() and SetActiveSRS() with GDAL ≥3.7.0
        # when possible, see apps/ogr2ogr_lib.cpp
        # pylint: disable=duplicate-code
        logging.debug('Creating transforming from source SRS (%s) to destination SRS (%s)',
                      srs.GetName(), srs_dst.GetName())
        ct = osr.CoordinateTransformation(srs, srs_dst)
        if ct is None:
            raise RuntimeError(f'Could not create transformation from source SRS ({srs.GetName()}) '
                               f'to destination SRS ({srs_dst.GetName()})')

    defn = lyr.GetLayerDefn()
    geomFieldCount = defn.GetGeomFieldCount()
    if geomFieldCount != 1:
        # TODO Add support for multiple geometry fields (also in the fingerprinting logic below)
        logging.warning('Source layer "%s" has %d != 1 geometry fields', layername, geomFieldCount)

    fieldCount = defn.GetFieldCount()
    # fieldMap[i] = destination field index for source field #i, or -1 to drop it
    fieldMap = [-1] * fieldCount
    fields = args['field-map']
    fieldSet = set()

    canIgnoreFields = lyr.TestCapability(ogr.OLCIgnoreFields)
    ignoredFieldNames = []
    for i in range(fieldCount):
        fld = defn.GetFieldDefn(i)
        fldName = fld.GetName()
        fieldMap[i] = v = fields.get(fldName, -1)
        fieldSet.add(fldName)

        if v < 0 and canIgnoreFields:
            # call SetIgnored() on unwanted source fields
            logging.debug('Set Ignored=True on output field "%s"', fldName)
            fld.SetIgnored(True)
            ignoredFieldNames.append(fldName)

    count0 = -1
    if lyr.TestCapability(ogr.OLCFastFeatureCount):
        count0 = lyr.GetFeatureCount(force=0)

    if count0 == 0 and len(fieldSet) == 0:
        # skip the below warning in some cases (e.g., GeoJSON source)
        logging.info('Source layer "%s" has no fields nor features, skipping', layername)
        return

    logging.debug('Field map: %s', str(fieldMap))
    # warn about configured field-map entries absent from the source schema
    for fld in fields:
        if not fld in fieldSet:
            logging.warning('Source layer "%s" has no field named "%s", ignoring', layername, fld)

    count1 = -1
    if args.get('spatial-filter', True) and extent is not None:
        spatialFilter = getSpatialFilterFromGeometry(extent, srs)
        logging.debug('Setting spatial filter to %s', spatialFilter.ExportToWkt())
        lyr.SetSpatialFilter(spatialFilter)

        if lyr.TestCapability(ogr.OLCFastFeatureCount):
            count1 = lyr.GetFeatureCount(force=0)

    if count0 >= 0:
        if count1 >= 0:
            logging.info('Source layer "%s" has %d features (%d of which intersecting extent)',
                         layername, count0, count1)
        else:
            logging.info('Source layer "%s" has %d features', layername, count0)

    logging.info('Ignored fields from source layer: %s',
                 '-' if len(ignoredFieldNames) == 0 else ', '.join(ignoredFieldNames))

    # build a list of triplets (field index, replacement_for_null, [(from_value, to_value), …])
    valueMap = []
    for fldName, rules in args.get('value-map', {}).items():
        i = defn.GetFieldIndex(fldName)
        if i < 0:
            raise RuntimeError(f'Source layer "{layername}" has no field named "{fldName}"')
        if fieldMap[i] < 0:
            logging.warning('Ignored source field "%s" has value map', fldName)
            continue

        hasNullReplacement = False
        nullReplacement = None
        mapping = []
        fld = defn.GetFieldDefn(i)
        for idx, (rFrom, rTo) in enumerate(rules):
            # use fld for both 'from' and 'to' (the types must match, casting is not
            # allowed in the mapping)
            if rFrom is None:
                if hasNullReplacement:
                    logging.warning('Field "%s" has duplicate NULL replacement',
                                    fld.GetName())
                else:
                    setFieldMapValue(fld, idx, None) # validate NULL
                    rTo = setFieldMapValue(fld, idx, rTo)
                    hasNullReplacement = True
                    nullReplacement = rTo
            elif isinstance(rFrom, re.Pattern):
                # validate but keep the rFrom regex
                # (mapping entries are (from, to, kind): kind 1 = regex, 0 = literal)
                setFieldMapValue(fld, idx, str(rFrom))
                rTo = setFieldMapValue(fld, idx, rTo)
                mapping.append( (rFrom, rTo, 1) )
            else:
                rFrom = setFieldMapValue(fld, idx, rFrom)
                rTo = setFieldMapValue(fld, idx, rTo)
                mapping.append( (rFrom, rTo, 0) )

        if nullReplacement is not None or len(mapping) > 0:
            valueMap.append( (i, nullReplacement, mapping) )

    if args.get('rstrip-strings', False):
        # indices of mapped string fields whose values will be right-stripped
        stringFieldsIdx = [ i for i in range(fieldCount)
                            if defn.GetFieldDefn(i).GetType() == ogr.OFTString and
                               fieldMap[i] >= 0 ]
        logging.debug('Source field indices to rstrip: %s', str(stringFieldsIdx))
        bStringFields = len(stringFieldsIdx) > 0
    else:
        bStringFields = False

    bValueMap = len(valueMap) > 0
    defn = None

    defn_dst = lyr_dst.GetLayerDefn()
    eGType_dst = defn_dst.GetGeomType()
    eGType_dst_HasZ = ogr.GT_HasZ(eGType_dst)
    eGType_dst_HasM = ogr.GT_HasM(eGType_dst)
    dGeomIsUnknown = ogr.GT_Flatten(eGType_dst) == ogr.wkbUnknown

    if bValueMap:
        # per-source-field count of substitutions actually performed
        valueMapCounts = [0] * fieldCount

    featureCount = 0
    mismatch = {}   # source geometry type -> number of forced conversions
    feature = lyr.GetNextFeature()
    while feature is not None:
        if bStringFields:
            for i in stringFieldsIdx:
                if feature.IsFieldSetAndNotNull(i):
                    v = feature.GetField(i)
                    feature.SetField(i, v.rstrip())

        if bValueMap:
            for i, nullReplacement, mapping in valueMap:
                if not feature.IsFieldSet(i):
                    continue
                if feature.IsFieldNull(i):
                    if nullReplacement is not None:
                        # replace NULL with non-NULL value
                        feature.SetField(i, nullReplacement)
                        valueMapCounts[i] += 1
                    continue

                v = feature.GetField(i)
                for rFrom, rTo, rType in mapping:
                    if rType == 0:
                        # literal
                        if v != rFrom:
                            continue
                    elif rType == 1:
                        # regex
                        m = rFrom.fullmatch(v)
                        if m is None:
                            continue
                        if rTo is not None:
                            # expand capture groups in the replacement value
                            rTo = rTo.format(*m.groups())
                    else:
                        raise RuntimeError(str(rType))

                    if rTo is None:
                        # replace non-NULL value with NULL
                        feature.SetFieldNull(i)
                    else:
                        # replace non-NULL value with non-NULL value
                        feature.SetField(i, rTo)
                    valueMapCounts[i] += 1
                    break

        # copy attributes to a fresh destination feature through fieldMap
        feature2 = ogr.Feature(defn_dst)
        feature2.SetFromWithMap(feature, False, fieldMap)

        geom = feature2.GetGeometryRef()
        if geom is None:
            if eGType_dst != ogr.wkbNone:
                logging.warning('Source feature #%d has no geometry, trying to transfer anyway',
                                feature.GetFID())
        else:
            if ct is not None and geom.Transform(ct) != ogr.OGRERR_NONE:
                raise RuntimeError('Could not apply coordinate transformation')

            eGType = geom.GetGeometryType()
            if eGType != eGType_dst and not dGeomIsUnknown:
                # Promote to multi, cf. apps/ogr2ogr_lib.cpp:ConvertType()
                eGType2 = eGType
                if eGType in (ogr.wkbTriangle, ogr.wkbTIN, ogr.wkbPolyhedralSurface):
                    eGType2 = ogr.wkbMultiPolygon
                elif not ogr.GT_IsSubClassOf(eGType, ogr.wkbGeometryCollection):
                    eGType2 = ogr.GT_GetCollection(eGType)

                eGType2 = ogr.GT_SetModifier(eGType2, eGType_dst_HasZ, eGType_dst_HasM)
                if eGType2 == eGType_dst:
                    mismatch[eGType] = mismatch.get(eGType, 0) + 1
                    geom = ogr.ForceTo(geom, eGType_dst)
                    # TODO call MakeValid()?
                else:
                    raise RuntimeError(f'Conversion from {ogr.GeometryTypeToName(eGType)} '
                                       f'to {ogr.GeometryTypeToName(eGType_dst)} not implemented')
            feature2.SetGeometryDirectly(geom)

        if lyr_dst.CreateFeature(feature2) != ogr.OGRERR_NONE:
            raise RuntimeError(f'Could not transfer source feature #{feature.GetFID()}')

        featureCount += 1
        feature = lyr.GetNextFeature()

    if bValueMap:
        # turn counters into (field name, count) pairs for the summary log
        valueMapCounts = [ (lyr.GetLayerDefn().GetFieldDefn(i).GetName(), k)
                           for i,k in enumerate(valueMapCounts) if k > 0 ]

    lyr = None
    logging.info('Imported %d features from source layer "%s"', featureCount, layername)

    if bValueMap:
        if len(valueMapCounts) > 0:
            valueMapCounts = ', '.join([ str(k) + '× "' + n + '"' for n,k in valueMapCounts ])
        else:
            valueMapCounts = '-'
        logging.info('Field substitutions: %s', valueMapCounts)

    if len(mismatch) > 0:
        mismatches = [ str(n) + '× ' + ogr.GeometryTypeToName(t)
                       for t,n in sorted(mismatch.items(), key=lambda x: x[1]) ]
        logging.info('Forced conversion to %s: %s',
                     ogr.GeometryTypeToName(eGType_dst), ', '.join(mismatches))
+
def listFieldsOrderBy(defn : ogr.FeatureDefn,
                      unique : bool|None = None,
                      nullable : bool|None = None) -> Iterator[str]:
    """Yield column names usable in an ORDER BY clause.

    Integer columns are yielded first, in definition order, followed by
    short string columns ordered by increasing declared width (stable
    for equal widths).  Ignored fields are always skipped; when 'unique'
    resp. 'nullable' is not None, only fields whose .IsUnique() resp.
    .IsNullable() flag matches the boolean are considered.
    """
    short_strings : list[tuple[str, int]] = []
    for idx in range(defn.GetFieldCount()):
        field = defn.GetFieldDefn(idx)
        # guard clauses: skip ignored fields and constraint mismatches
        if field.IsIgnored():
            continue
        if unique is not None and field.IsUnique() != unique:
            continue
        if nullable is not None and field.IsNullable() != nullable:
            continue
        fieldType = field.GetType()
        if fieldType in (ogr.OFTInteger, ogr.OFTInteger64):
            # integer columns come first
            yield field.GetName()
        elif fieldType == ogr.OFTString:
            width = field.GetWidth()
            # only consider short-ish strings of known, bounded width
            if 0 < width < 256:
                short_strings.append((field.GetName(), width))
    # string columns follow, narrowest first
    for name, _ in sorted(short_strings, key=lambda nw: nw[1]):
        yield name
+
# pylint: disable-next=too-many-branches, too-many-statements
def updateLayerCache(lyr : ogr.Layer, cache : ogr.Layer,
                     last_updated : datetime,
                     lyrTransaction : str|bool|None = None,
                     force : bool = False) -> bool:
    """Update attributes in the layer cache for the given layer name.

    A fingerprint is computed for `lyr` by hashing, per feature and in a
    stable order, its FID, a hash of its non-geometry properties and a
    hash of its geometry.  The fingerprint and `last_updated` are then
    stored in `cache` (fields 1 resp. 2 of the cache layer, keyed on
    field 0, the layer name).  When the fingerprint is unchanged and
    `force` is unset, pending changes on `lyr` are rolled back first —
    via 'ROLLBACK TO <savepoint>' when `lyrTransaction` is a string
    (savepoint name), or via lyr.RollbackTransaction() when it is True.

    `last_updated` must be timezone-aware (a naive value would make
    .utcoffset() return None and crash below).

    Return a boolean indicating whether changes to layername were *not*
    rolled back (hence might still be outstanding in the transaction).

    Raises RuntimeError when the FID or geometry column can't be
    determined, when `last_updated` has an invalid UTC offset, or when
    the cache feature can't be created/updated."""
    layername = lyr.GetName()

    dgst = sha256()
    defn = lyr.GetLayerDefn()
    fields = []
    for i in range(defn.GetFieldCount()):
        fields.append('t.' + escape_identifier(defn.GetFieldDefn(i).GetName()))
    if len(fields) == 0:
        # no property to hash: use a constant so the SELECT still has the column
        fields = ['0 AS hash_properties']
    else:
        fields = [ 'hash_record_extended(ROW(' + ','.join(fields) + '),0) AS hash_properties' ]

    fidColumn = lyr.GetFIDColumn()
    if fidColumn is None or fidColumn == '':
        raise RuntimeError(f'Couldn\'t find FID column for "{layername}"')
    # defn.GetGeomFieldCount() != 1 is not supported and yields a warning in _importSource2()
    geometryColumn = lyr.GetGeometryColumn()
    if geometryColumn is None or geometryColumn == '':
        raise RuntimeError(f'Couldn\'t find geometry column for "{layername}"')

    fields.append('sha256(COALESCE(' +
        'ST_AsEWKB(t.' + escape_identifier(geometryColumn) + '),' +
        '\'\')) AS hash_geom')
    # `fields` now always holds exactly hash_properties and hash_geom
    # (the former `len(fields) == 0` check here was unreachable and has been removed)
    query = 'SELECT ' + ','.join(fields) + ' FROM ' + getEscapedTableName(lyr) + ' t'

    # prefer a single UNIQUE NOT NULL column for a cheap, stable sort
    sort_by = next(listFieldsOrderBy(defn, unique=True, nullable=False), None)
    if sort_by is not None:
        sort_by = [ sort_by ]
    else:
        count = lyr.GetFeatureCount(force=0)
        if count is None or count < 0 or count > 5000:
            logging.warning('Layer "%s" has many (%s) features but no UNIQUE NOT NULL constraint, '
                            'sorting might be unstable and slow', layername,
                            str(count) if (count is not None and count >= 0) else 'N/A')
        # fall back to sorting on every eligible column plus geometry and FID
        sort_by = list(listFieldsOrderBy(defn)) + [ geometryColumn, fidColumn ]
    query += ' ORDER BY ' + ','.join(['t.' + escape_identifier(c) for c in sort_by])

    # NOTE: '@' (native byte order/alignment) makes the fingerprint
    # architecture-dependent; fine as long as comparisons happen on the
    # same machine — TODO confirm that assumption holds for deployments
    struct_dgst : Final = struct.Struct('@qq').pack
    ds = lyr.GetDataset()
    with executeSQL(ds, query) as lyr2:
        defn2 = lyr2.GetLayerDefn()
        assert defn2.GetFieldDefn(0).GetName() == 'hash_properties'
        assert defn2.GetFieldDefn(1).GetName() == 'hash_geom'
        feature = lyr2.GetNextFeature()
        while feature is not None:
            # fold (FID, hash_properties) then the raw geometry hash into the digest
            dgst.update(struct_dgst(feature.GetFID(), feature.GetFieldAsInteger64(0)))
            dgst.update(feature.GetFieldAsBinary(1))
            feature = lyr2.GetNextFeature()
    fingerprint = dgst.digest()

    attributeFilter = 'layername = ' + escape_literal_str(layername)
    logging.debug('SetAttributeFilter("%s", "%s")', cache.GetName(), attributeFilter)
    cache.SetAttributeFilter(attributeFilter)

    feature = cache.GetNextFeature()
    if feature is None:
        # not in cache
        logging.debug('Creating new feature in layer cache for %s', attributeFilter)
        update = False
        feature = ogr.Feature(cache.GetLayerDefn())
        feature.SetFieldString(0, layername)
        fingerprint_old = None
    else:
        logging.debug('Updating existing feature in layer cache for %s', attributeFilter)
        update = True
        fingerprint_old = feature.GetFieldAsBinary(2) if feature.IsFieldSetAndNotNull(2) else None
        # layername is the cache key, so at most one match is expected
        assert cache.GetNextFeature() is None

    if last_updated.tzinfo == UTC:
        tzFlag = ogr.TZFLAG_UTC
    else:
        td = last_updated.utcoffset()
        # 15min increments/decrements per unit above/below UTC, cf.
        # https://gdal.org/en/stable/api/vector_c_api.html#c.OGR_TZFLAG_UTC
        tzFlag = td.days * 96 + td.seconds // 900
        if timedelta(seconds=tzFlag*900) != td or abs(tzFlag) > 56: # max ±14:00
            raise RuntimeError(f'Invalid UTC offset {td}')
        tzFlag += ogr.TZFLAG_UTC

    # field 1: last_updated as an OGR datetime with TZ flag
    feature.SetField(1, last_updated.year,
                        last_updated.month,
                        last_updated.day,
                        last_updated.hour,
                        last_updated.minute,
                        float(last_updated.second) + float(last_updated.microsecond)/1000000.,
                        tzFlag)
    # field 2: fingerprint (sha256().digest() is never None; branch kept defensively)
    if fingerprint is None:
        feature.SetFieldNull(2)
    else:
        # https://lists.osgeo.org/pipermail/gdal-dev/2020-December/053170.html
        feature.SetFieldBinaryFromHexString(2, fingerprint.hex())

    ret = True
    if update:
        if (fingerprint is None or fingerprint_old is None or fingerprint != fingerprint_old):
            logging.info('Updated layer "%s" has new fingerprint %s', layername,
                         fingerprint.hex()[:8] if fingerprint is not None else 'N/A')
        elif force:
            logging.info('Updated layer "%s" has identical fingerprint %s',
                         layername, fingerprint.hex()[:8])
        else:
            # no change: rollback (*before* updating the cache) if possible to retain FID values
            if isinstance(lyrTransaction, str):
                logging.info('Updated layer "%s" has identical fingerprint %s, rolling back',
                             layername, fingerprint.hex()[:8])
                rollbackQuery = 'ROLLBACK TO ' + lyrTransaction
                try:
                    executeSQL(ds, rollbackQuery)
                except Exception: # pylint: disable=broad-exception-caught
                    # log the statement that actually failed (previously this
                    # erroneously logged the unrelated SELECT in `query`)
                    logging.exception('Could not execute SQL: %s', rollbackQuery)
                else:
                    ret = False
            elif isinstance(lyrTransaction, bool) and lyrTransaction:
                logging.info('Updated layer "%s" has identical fingerprint %s, rolling back',
                             layername, fingerprint.hex()[:8])
                try:
                    if lyr.RollbackTransaction() == ogr.OGRERR_NONE:
                        ret = False
                    else:
                        logging.error('Could not rollback transaction on layer "%s"',
                                      layername)
                except Exception: # pylint: disable=broad-exception-caught
                    logging.exception('Could not rollback transaction on layer "%s"',
                                      layername)
            else:
                logging.info('Updated layer "%s" has identical fingerprint %s',
                             layername, fingerprint.hex()[:8])

        # update only fields 1 (last_updated) and 2 (fingerprint)
        if cache.UpdateFeature(feature, [1,2], [], False) != ogr.OGRERR_NONE:
            raise RuntimeError('Could not update feature in layer cache')
    else:
        if cache.CreateFeature(feature) != ogr.OGRERR_NONE:
            raise RuntimeError('Could not create new feature in layer cache')

    # force the PG driver to call EndCopy() to detect errors and trigger a
    # rollback if needed
    ds.FlushCache()
    return ret