diff options
Diffstat (limited to 'import_source.py')
| -rw-r--r-- | import_source.py | 1153 |
1 file changed, 1153 insertions, 0 deletions
diff --git a/import_source.py b/import_source.py new file mode 100644 index 0000000..f802a57 --- /dev/null +++ b/import_source.py @@ -0,0 +1,1153 @@ +#!/usr/bin/python3 + +#---------------------------------------------------------------------- +# Backend utilities for the Klimatanalys Norr project (import source layers) +# Copyright © 2024-2025 Guilhem Moulin <info@guilhem.se> +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <https://www.gnu.org/licenses/>. 
#----------------------------------------------------------------------

# pylint: disable=invalid-name, missing-module-docstring, fixme

import logging
import tempfile
import re
from fnmatch import fnmatchcase
from pathlib import Path
from datetime import datetime, timedelta, UTC
from typing import Any, Callable, Final, Iterator, Optional
import traceback
from enum import Enum, unique as enum_unique
from hashlib import sha256
import struct

from osgeo import gdal, ogr, osr
from osgeo.gdalconst import (
    OF_ALL as GDAL_OF_ALL,
    OF_READONLY as GDAL_OF_READONLY,
    OF_UPDATE as GDAL_OF_UPDATE,
    OF_VECTOR as GDAL_OF_VECTOR,
    OF_VERBOSE_ERROR as GDAL_OF_VERBOSE_ERROR,
    DCAP_CREATE as GDAL_DCAP_CREATE,
)
from osgeo import gdalconst

from common import BadConfiguration, escape_identifier, escape_literal_str
from common_gdal import (
    gdalSetOpenExArgs,
    gdalGetMetadataItem,
    formatTZFlag,
    getSpatialFilterFromGeometry,
    getEscapedTableName,
    executeSQL,
)

def openOutputDS(def_dict : dict[str, Any]) -> gdal.Dataset:
    """Open and return the output DS.  It is created if create=True or
    create-options is a non-empty dictionary.

    (Fixed docstring: the code below raises unless def_dict.get('create')
    is truthy or 'create-options' is present, so creation happens for
    create=True, not create=False.)

    :param def_dict: datasource definition; must contain 'path', and may
        contain 'create', 'create-options' and OpenEx-related settings
        consumed by gdalSetOpenExArgs().
    :raises RuntimeError: when the datasource can neither be opened nor
        created.
    """

    path = def_dict['path']
    kwargs, drv = gdalSetOpenExArgs(def_dict,
        flags=GDAL_OF_VECTOR|GDAL_OF_UPDATE|GDAL_OF_VERBOSE_ERROR)
    try:
        logging.debug('OpenEx(%s, %s)', path, str(kwargs))
        return gdal.OpenEx(path, **kwargs)
    except RuntimeError as e:
        if not (gdal.GetLastErrorType() >= gdalconst.CE_Failure and
                gdal.GetLastErrorNo() == gdalconst.CPLE_OpenFailed):
            # not an open error
            raise e

        # probe whether the path exists at all: if any driver can open it
        # (read-write or read-only), the original failure was not "missing
        # datasource" and must be re-raised rather than masked by creation
        dso2 = None
        try:
            dso2 = gdal.OpenEx(path, nOpenFlags=GDAL_OF_ALL | GDAL_OF_UPDATE)
        except RuntimeError:
            pass
        if dso2 is not None:
            # path exists but can't be open with OpenEx(path, **kwargs)
            raise e

        try:
            dso2 = gdal.OpenEx(path, nOpenFlags=GDAL_OF_ALL)
        except RuntimeError:
            pass
        if dso2 is not None:
            # path exists but can't be open with OpenEx(path, **kwargs)
            raise e

        dsco = def_dict.get('create-options', None)
        if not def_dict.get('create', False) and dsco is None:
            # not configured for creation
            raise e
        if drv is None or not gdalGetMetadataItem(drv, GDAL_DCAP_CREATE):
            # not capable of creation
            raise e

        if 'open_options' in kwargs:
            # like ogr2ogr(1)
            logging.warning('Destination\'s open options ignored '
                            'when creating the output datasource')

        kwargs2 = { 'eType': gdalconst.GDT_Unknown }
        if dsco is not None:
            kwargs2['options'] = [ k + '=' + str(v) for k, v in dsco.items() ]

        logging.debug('Create(%s, %s, eType=%s%s)', drv.ShortName, path, kwargs2['eType'],
                      ', options=' + str(kwargs2['options']) if 'options' in kwargs2 else '')
        # XXX racy, a GDAL equivalent of O_EXCL would be nice
        return drv.Create(path, 0, 0, 0, **kwargs2)
"{layername}"') + + logging.info('Creating new destination layer "%s"', layername) + geom_type = options['geometry-type'] + lco = options.get('options', None) + + drv = ds.GetDriver() + if geom_type != ogr.wkbNone and drv.ShortName == 'PostgreSQL': + # “Important to set to 2 for 2D layers as it has constraints on the geometry + # dimension during loading.” + # — https://gdal.org/drivers/vector/pg.html#layer-creation-options + if ogr.GT_HasM(geom_type): + if ogr.GT_HasZ(geom_type): + dim = 'XYZM' + else: + dim = 'XYM' + elif ogr.GT_HasZ(geom_type): + dim = '3' + else: + dim = '2' + if lco is None: + lco = [] + lco = ['DIM=' + dim] + lco # prepend DIM= + + kwargs = { 'geom_type': geom_type } + if srs is not None: + kwargs['srs'] = srs + if lco is not None: + kwargs['options'] = lco + logging.debug('CreateLayer(%s, geom_type="%s"%s%s)', layername, + ogr.GeometryTypeToName(geom_type), + ', srs="' + kwargs['srs'].GetName() + '"' if 'srs' in kwargs else '', + ', options=' + str(kwargs['options']) if 'options' in kwargs else '') + lyr = ds.CreateLayer(layername, **kwargs) + if lyr is None: + raise RuntimeError(f'Could not create destination layer "{layername}"') + + fields = options['fields'] + if len(fields) > 0 and not lyr.TestCapability(ogr.OLCCreateField): + raise RuntimeError(f'Destination layer "{layername}" lacks field creation capability') + + # set up output schema + for fld in fields: + fldName = fld['Name'] + defn = ogr.FieldDefn() + defn.SetName(fldName) + + if 'AlternativeName' in fld: + v = fld['AlternativeName'] + logging.debug('Set AlternativeName="%s" on output field "%s"', str(v), fldName) + defn.SetAlternativeName(v) + + if 'Comment' in fld: + v = fld['Comment'] + logging.debug('Set Comment="%s" on output field "%s"', str(v), fldName) + defn.SetComment(v) + + if 'Type' in fld: + v = fld['Type'] + logging.debug('Set Type=%d (%s) on output field "%s"', + v, ogr.GetFieldTypeName(v), fldName) + defn.SetType(v) + + if 'SubType' in fld: + v = fld['SubType'] + 
logging.debug('Set SubType=%d (%s) on output field "%s"', + v, ogr.GetFieldSubTypeName(v), fldName) + defn.SetSubType(v) + + if 'TZFlag' in fld: + v = fld['TZFlag'] + logging.debug('Set TZFlag=%d (%s) on output field "%s"', + v, formatTZFlag(v), fldName) + defn.SetTZFlag(v) + + if 'Precision' in fld: + v = fld['Precision'] + logging.debug('Set Precision=%d on output field "%s"', v, fldName) + defn.SetPrecision(v) + + if 'Width' in fld: + v = fld['Width'] + logging.debug('Set Width=%d on output field "%s"', v, fldName) + defn.SetWidth(v) + + if 'Default' in fld: + v = fld['Default'] + logging.debug('Set Default=%s on output field "%s"', v, fldName) + defn.SetDefault(v) + + if 'Nullable' in fld: + v = fld['Nullable'] + logging.debug('Set Nullable=%s on output field "%s"', v, fldName) + defn.SetNullable(v) + + if 'Unique' in fld: + v = fld['Unique'] + logging.debug('Set Unique=%s on output field "%s"', v, fldName) + defn.SetUnique(v) + + if lyr.CreateField(defn, approx_ok=False) != gdalconst.CE_None: + raise RuntimeError(f'Could not create field "{fldName}"') + logging.debug('Added field "%s" to output layer "%s"', fldName, layername) + + if lyr.TestCapability(ogr.OLCAlterGeomFieldDefn): + # it appears using .CreateLayerFromGeomFieldDefn() on a a non-nullable + # GeomFieldDefn doesn't do anything, so we alter it after the fact instead + # (GPKG doesn't support this, use GEOMETRY_NULLABLE=NO in layer creation + # options instead) + flags = drv.GetMetadataItem(gdal.DMD_ALTER_GEOM_FIELD_DEFN_FLAGS) + if flags is not None and 'nullable' in flags.lower().split(' '): + geom_field = ogr.GeomFieldDefn(None, geom_type) + geom_field.SetNullable(False) + lyr.AlterGeomFieldDefn(0, geom_field, ogr.ALTER_GEOM_FIELD_DEFN_NULLABLE_FLAG) + + # TODO evaluate use external storage not main storage for geometries + # https://blog.cleverelephant.ca/2018/09/postgis-external-storage.html + + # sync before calling StartTransaction() so we're not trying to rollback changes + # on a 
def clusterLayer(lyr : ogr.Layer,
                 index_name : Optional[str] = None,
                 column_name : Optional[str] = None,
                 analyze : bool = True) -> bool:
    """CLUSTER the table backing the given layer.

    When only a column name is supplied, the index touching that column with
    the fewest columns overall is selected; when neither an index nor a column
    is supplied the table is re-clustered on its previous index.  When
    `analyze` is True (the default) the table is ANALYZEd afterwards.  See
    https://www.postgresql.org/docs/current/sql-cluster.html .
    Requires that the dataset driver is PostgreSQL.  Returns True on success,
    False when clustering was skipped."""
    dataset = lyr.GetDataset()
    if dataset.GetDriver().ShortName != 'PostgreSQL':
        logging.warning('clusterLayer() called on a non-PostgreSQL dataset, ignoring')
        return False

    layername_esc = getEscapedTableName(lyr)
    if index_name is None and column_name is not None:
        # look up candidate indices covering column_name in the catalog
        lookup_sql = ('WITH indices AS ('
            'SELECT i.relname AS index, array_agg(a.attname) AS columns '
            'FROM pg_class t, pg_class i, pg_index ix, pg_attribute a '
            'WHERE t.oid = ix.indrelid '
            'AND i.oid = ix.indexrelid '
            'AND a.attrelid = t.oid '
            'AND a.attnum = ANY(ix.indkey) '
            'AND t.relkind = \'r\' '
            'AND ix.indrelid = ' + escape_literal_str(layername_esc) + '::regclass '
            'GROUP BY 1) '
            'SELECT index, array_length(columns, 1) AS len '
            'FROM indices '
            'WHERE ' + escape_literal_str(column_name) + ' = ANY(columns)'
            # pick the index involving the least number of columns
            'ORDER BY 2,1 LIMIT 1')
        with executeSQL(dataset, statement=lookup_sql) as cursor:
            col_idx = cursor.GetLayerDefn().GetFieldIndex('index')
            first = cursor.GetNextFeature()
            if first is not None and first.IsFieldSetAndNotNull(col_idx):
                index_name = first.GetFieldAsString(col_idx)

        if index_name is None:
            logging.warning('Layer %s has no index on column %s, cannot CLUSTER',
                            lyr.GetName(), column_name)
            return False

    cluster_sql = 'CLUSTER ' + layername_esc
    if index_name is not None:
        cluster_sql += ' USING ' + escape_identifier(index_name)
    executeSQL(dataset, statement=cluster_sql)

    if analyze:
        # "Because the planner records statistics about the ordering of tables, it is
        # advisable to run ANALYZE on the newly clustered table.  Otherwise, the planner
        # might make poor choices of query plans."
        executeSQL(dataset, statement='ANALYZE ' + layername_esc)

    return True
def _validateFieldAttr(defn : ogr.FieldDefn, fldName : str,
                       attr : str, expected : Any) -> bool:
    """Compare one schema attribute of an existing field against its expected
    value; log a warning and return False on mismatch, True otherwise."""
    if attr in ('Nullable', 'Unique'):
        actual = bool(getattr(defn, 'Is' + attr)())
    else:
        actual = getattr(defn, 'Get' + attr)()
    if actual == expected:
        return True
    # render the same message text as the former copy-pasted per-attribute blocks
    if attr in ('Type', 'SubType'):
        name = getattr(ogr, 'GetField' + attr + 'Name')
        logging.warning('Field "%s" has %s=%d (%s), expected %d (%s)',
                        fldName, attr, actual, name(actual), expected, name(expected))
    elif attr == 'TZFlag':
        logging.warning('Field "%s" has %s=%d (%s), expected %d (%s)',
                        fldName, attr, actual, formatTZFlag(actual),
                        expected, formatTZFlag(expected))
    elif attr in ('Precision', 'Width'):
        logging.warning('Field "%s" has %s=%d, expected %d', fldName, attr, actual, expected)
    elif attr in ('Nullable', 'Unique'):
        logging.warning('Field "%s" has %s=%s, expected %s', fldName, attr, actual, expected)
    else:
        # AlternativeName, Comment, Default
        logging.warning('Field "%s" has %s="%s", expected "%s"', fldName, attr, actual, expected)
    return False

# attributes checked for each field, in the order they were formerly hard-coded
_FIELD_ATTRS : Final = ('AlternativeName', 'Comment', 'Type', 'SubType', 'TZFlag',
                        'Precision', 'Width', 'Default', 'Nullable', 'Unique')

def validateOutputLayer(lyr : ogr.Layer,
                        srs : Optional[osr.SpatialReference] = None,
                        options : Optional[dict[str, Any]] = None) -> bool:
    """Validate the output layer against the provided SRS and creation options.
    Returns True when the layer matches, False otherwise (mismatches are logged)."""
    ok = True

    # ensure the output SRS is equivalent
    if srs is not None:
        srs2 = lyr.GetSpatialRef()
        # cf. apps/ogr2ogr_lib.cpp
        srs_options = [
            'IGNORE_DATA_AXIS_TO_SRS_AXIS_MAPPING=YES',
            'CRITERION=EQUIVALENT'
        ]
        if not srs.IsSame(srs2, srs_options):
            logging.warning('Output layer "%s" has SRS %s,\nexpected %s',
                            lyr.GetName(),
                            srs2.ExportToPrettyWkt(),
                            srs.ExportToPrettyWkt())
            ok = False

    if options is None:
        return ok

    layerDefn = lyr.GetLayerDefn()
    n = layerDefn.GetGeomFieldCount()
    if n != 1:
        if n == 0:
            raise RuntimeError(f'Output layer "{lyr.GetName()}" has no geometry fields')
        logging.warning('Output layer "%s" has %d != 1 geometry fields', lyr.GetName(), n)

    iGeomField = 0
    geomField = layerDefn.GetGeomFieldDefn(iGeomField)
    geomType = geomField.GetType()
    logging.debug('Geometry column #%d: name="%s\", type="%s", srs=%s, nullable=%s',
                  iGeomField, geomField.GetName(),
                  ogr.GeometryTypeToName(geomType),
                  '-' if geomField.GetSpatialRef() is None
                      else '"' + geomField.GetSpatialRef().GetName() + '"',
                  bool(geomField.IsNullable()))
    if geomField.IsNullable():
        logging.warning('Geometry column #%d "%s" of output layer "%s" is nullable',
                        iGeomField, geomField.GetName(), lyr.GetName())

    geomType2 = options['geometry-type']
    if geomType != geomType2:
        logging.warning('Output layer "%s" has geometry type #%d (%s), expected #%d (%s)',
                        lyr.GetName(),
                        geomType, ogr.GeometryTypeToName(geomType),
                        geomType2, ogr.GeometryTypeToName(geomType2))
        ok = False

    fields = options.get('fields', None)
    if fields is not None:
        for fld in fields:
            fldName = fld['Name']

            idx = layerDefn.GetFieldIndex(fldName)
            if idx < 0:
                logging.warning('Output layer "%s" has no field named "%s"',
                                lyr.GetName(), fldName)
                ok = False
                continue
            defn = layerDefn.GetFieldDefn(idx)

            for attr in _FIELD_ATTRS:
                if attr in fld and not _validateFieldAttr(defn, fldName, attr, fld[attr]):
                    ok = False

    return ok
def clearLayer(lyr : ogr.Layer, identity : str = 'CONTINUE IDENTITY') -> None:
    """Clear the given layer (wipe all its features).

    Uses TRUNCATE (with the given identity behavior) on PostgreSQL, and a
    plain DELETE elsewhere.  A no-op when the layer is already empty and the
    driver can tell us so cheaply."""
    n = -1
    if lyr.TestCapability(ogr.OLCFastFeatureCount):
        n = lyr.GetFeatureCount(force=0)
    if n == 0:
        # nothing to clear, we're good
        return

    ds = lyr.GetDataset()
    if ds.GetDriver().ShortName == 'PostgreSQL':
        # https://www.postgresql.org/docs/15/sql-truncate.html
        statement = 'TRUNCATE TABLE {table} ' + identity + ' CASCADE'
        op = 'Truncating'
    else:
        statement = 'DELETE FROM {table}'
        op = 'Clearing'
    logging.info('%s table %s (former feature count: %s)', op,
                 lyr.GetName(), str(n) if n >= 0 else 'unknown')
    executeSQL(ds, statement=statement.format(table=getEscapedTableName(lyr)))

def extractArchive(path : Path, destdir : str,
                   fmt : str|None = None,
                   patterns : list[str]|None = None,
                   exact_matches : list[str]|None = None) -> None:
    """Extract an archive file into the given destination directory.

    :param path: path to the archive file.
    :param destdir: directory to extract into.
    :param fmt: archive format; inferred from the file suffix when None.
    :param patterns: optional fnmatch patterns selecting members to extract.
    :param exact_matches: optional exact member names to extract.
    :raises RuntimeError: when the format can't be inferred or is unsupported.
    """
    if fmt is None:
        # Path.suffix is always a str ('' when there is no extension), so the
        # former `suffix is None or suffix == ''` checks were redundant
        suffix = path.suffix
        if not suffix.startswith('.'):
            raise RuntimeError(f'Could not infer archive format from "{path}"')
        fmt = suffix.removeprefix('.')

    fmt = fmt.lower()
    logging.debug('Unpacking %s archive %s into %s', fmt, path, destdir)

    if fmt == 'zip':
        import zipfile # pylint: disable=import-outside-toplevel
        logging.debug('Opening %s as ZipFile', path)
        with zipfile.ZipFile(path, mode='r') as z:
            namelist = listArchiveMembers(z.namelist(),
                                          patterns=patterns,
                                          exact_matches=exact_matches)
            z.extractall(path=destdir, members=namelist)
    else:
        raise RuntimeError(f'Unknown archive format "{fmt}"')
disable=import-outside-toplevel + logging.debug('Opening %s as ZipFile', path) + with zipfile.ZipFile(path, mode='r') as z: + namelist = listArchiveMembers(z.namelist(), + patterns=patterns, + exact_matches=exact_matches) + z.extractall(path=destdir, members=namelist) + else: + raise RuntimeError(f'Unknown archive format "{fmt}"') + +def listArchiveMembers(namelist : list[str], + patterns : list[str]|None = None, + exact_matches : list[str]|None = None) -> list[str]: + """List archive members matching the given parterns and/or exact matches.""" + if patterns is None and exact_matches is None: + # if neither patterns nor exact_matches are given we'll extract the entire archive + return namelist + if patterns is None: + patterns = [] + if exact_matches is None: + exact_matches = [] + + members = [] + for name in namelist: + ok = False + if name in exact_matches: + # try exact matches first + logging.debug('Listed archive member %s (exact match)', name) + members.append(name) + ok = True + continue + # if there are no exact matches, try patterns one by one in the supplied order + for pat in patterns: + if fnmatchcase(name, pat): + logging.debug('Listed archive member %s (matching pattern "%s")', name, pat) + members.append(name) + ok = True + break + if not ok: + logging.debug('Ignoring archive member %s', name) + return members + +@enum_unique +class ImportStatus(Enum): + """Return value for importSources(): success, error, or no-change.""" + IMPORT_SUCCESS = 0 + IMPORT_ERROR = 1 + IMPORT_NOCHANGE = 255 + + def __str__(self): + return self.name.removeprefix('IMPORT_') + +# pylint: disable-next=dangerous-default-value +def importSources(lyr : ogr.Layer, + sources : dict[str,Any] = {}, + cachedir : Path|None = None, + extent : ogr.Geometry|None = None, + dsoTransaction : bool = False, + lyrcache : ogr.Layer|None = None, + force : bool = False, + cluster_geometry : bool = False) -> ImportStatus: + """Clear lyr and import source layers to it.""" + + dso = 
def importSources(lyr : ogr.Layer,
                  sources : list[dict[str, Any]]|None = None,
                  cachedir : Path|None = None,
                  extent : ogr.Geometry|None = None,
                  dsoTransaction : bool = False,
                  lyrcache : ogr.Layer|None = None,
                  force : bool = False,
                  cluster_geometry : bool = False) -> ImportStatus:
    """Clear lyr and import source layers to it.

    Runs inside a SAVEPOINT when a dataset-level transaction is active
    (dsoTransaction=True), otherwise inside a layer-level transaction when the
    driver supports one.  Returns an ImportStatus describing the outcome.

    (Fixes: the former mutable default `sources={}` is replaced by a None
    sentinel, and the annotation now reflects that sources is a sequence of
    {'source': …, 'import': …} dictionaries, not a single dict.)"""
    if sources is None:
        sources = []

    dso = lyr.GetDataset()
    layername = lyr.GetName()
    if dsoTransaction:
        # declare a SAVEPOINT (nested transaction) within the DS-level transaction
        lyrTransaction = 'SAVEPOINT ' + escape_identifier('savept_' + layername)
        executeSQL(dso, lyrTransaction)
    elif lyr.TestCapability(ogr.OLCTransactions):
        # try to start transaction on the layer
        logging.debug('Starting transaction on output layer "%s"', layername)
        lyrTransaction = lyr.StartTransaction() == ogr.OGRERR_NONE
        if not lyrTransaction:
            logging.warning('Couldn\'t start transaction on output layer "%s"', layername)
    else:
        logging.warning('Unsafe update, output layer "%s" doesn\'t support transactions',
                        layername)
        lyrTransaction = False

    rv = ImportStatus.IMPORT_NOCHANGE
    now = datetime.now().astimezone()
    try:
        clearLayer(lyr) # TODO conditional (only if not new)?

        for source in sources:
            importSource0(lyr, **source['source'],
                          args=source['import'],
                          cachedir=cachedir,
                          extent=extent,
                          callback=_importSource2)

        # force the PG driver to call EndCopy() to detect errors and trigger a
        # rollback if needed
        dso.FlushCache()

        if lyrcache is None:
            rv = ImportStatus.IMPORT_SUCCESS
        elif updateLayerCache(cache=lyrcache,
                              lyr=lyr,
                              force=force,
                              lyrTransaction=lyrTransaction,
                              last_updated=now):
            rv = ImportStatus.IMPORT_SUCCESS
        else:
            rv = ImportStatus.IMPORT_NOCHANGE
            if isinstance(lyrTransaction, bool):
                # the transaction on lyr was already rolled back
                lyrTransaction = False

        if (rv == ImportStatus.IMPORT_SUCCESS and cluster_geometry
                and lyr.GetLayerDefn().GetGeomType() != ogr.wkbNone):
            clusterLayer(lyr, column_name=lyr.GetGeometryColumn())

    except Exception: # pylint: disable=broad-exception-caught
        rv = ImportStatus.IMPORT_ERROR
        if isinstance(lyrTransaction, str):
            statement = 'ROLLBACK TO ' + lyrTransaction
            logging.exception('Exception occurred within transaction')
            # don't unset lyrTransaction here as we want to RELEASE SAVEPOINT
            try:
                executeSQL(dso, statement=statement)
            except Exception: # pylint: disable=broad-exception-caught
                logging.exception('Could not execute SQL: %s', statement)
        elif isinstance(lyrTransaction, bool) and lyrTransaction:
            logging.exception('Exception occurred within transaction on output '
                              'layer "%s": ROLLBACK', layername)
            lyrTransaction = None
            try:
                if lyr.RollbackTransaction() != ogr.OGRERR_NONE:
                    logging.error('Could not rollback transaction on layer "%s"', layername)
            except Exception: # pylint: disable=broad-exception-caught
                logging.exception('Could not rollback transaction on layer "%s"', layername)
        else:
            traceback.print_exc()

    finally:
        if isinstance(lyrTransaction, str):
            statement = 'RELEASE ' + lyrTransaction
            try:
                executeSQL(dso, statement)
            except Exception: # pylint: disable=broad-exception-caught
                rv = ImportStatus.IMPORT_ERROR
                logging.exception('Could not execute SQL: %s', statement)
        elif isinstance(lyrTransaction, bool) and lyrTransaction:
            try:
                if lyr.CommitTransaction() != ogr.OGRERR_NONE:
                    rv = ImportStatus.IMPORT_ERROR
                    logging.error('Could not commit transaction')
            except Exception: # pylint: disable=broad-exception-caught
                rv = ImportStatus.IMPORT_ERROR
                logging.exception('Could not commit transaction on layer "%s"', layername)
    return rv
def importSource0(lyr : ogr.Layer|None = None,
                  path : str = '/nonexistent',
                  unar : dict[str,Any]|None = None,
                  args : dict[str,Any]|None = None,
                  cachedir : Path|None = None,
                  extent : ogr.Geometry|None = None,
                  callback : Callable[[ogr.Layer|None, str, dict[str,Any], Path|None,
                                       ogr.Geometry|None], None]|None = None) -> None:
    """Import a source layer.

    When `unar` is None the source is used in place; otherwise the relevant
    member is extracted from the archive at `path` into a temporary directory
    first.  The actual import is delegated to `callback`.

    (Fix: the former dangerous mutable default `args={}` is replaced by a
    None sentinel with the same observable behavior.)"""
    if args is None:
        args = {}
    if unar is None:
        return callback(lyr, path, args=args, basedir=cachedir, extent=extent)

    ds_srcpath = Path(args['path'])
    if ds_srcpath.is_absolute():
        # treat absolute paths as relative to the archive root
        logging.warning('%s is absolute, removing leading anchor', ds_srcpath)
        ds_srcpath = ds_srcpath.relative_to(ds_srcpath.anchor)

    ds_srcpath = str(ds_srcpath)
    with tempfile.TemporaryDirectory() as tmpdir:
        logging.debug('Created temporary directory %s', tmpdir)
        extractArchive(Path(path) if cachedir is None else cachedir.joinpath(path),
                       tmpdir,
                       fmt=unar.get('format', None),
                       patterns=unar.get('patterns', None),
                       exact_matches=[ds_srcpath])
        return callback(lyr, ds_srcpath, args=args, basedir=Path(tmpdir), extent=extent)

def setFieldMapValue(fld : ogr.FieldDefn,
                     idx : int,
                     val : None|int|str|bytes|float) -> None|int|str|bytes|float:
    """Validate field value mapping: check that `val` (mapping rule #idx) is
    compatible with the field's OGR type, returning it (promoted to float for
    OFTReal) or raising RuntimeError."""
    if val is None:
        if not fld.IsNullable():
            logging.warning('Field "%s" is not NULLable but remaps NULL', fld.GetName())
        return None

    fldType = fld.GetType()
    if fldType in (ogr.OFTInteger, ogr.OFTInteger64):
        if isinstance(val, int):
            return val
    elif fldType == ogr.OFTString:
        if isinstance(val, str):
            return val
    elif fldType == ogr.OFTBinary:
        if isinstance(val, bytes):
            return val
    elif fldType == ogr.OFTReal:
        # ints are acceptable for real fields, but normalize them to float
        if isinstance(val, int):
            return float(val)
        if isinstance(val, float):
            return val

    raise RuntimeError(f'Field "{fld.GetName()}" mapping #{idx} has incompatible type '
                       f'for {ogr.GetFieldTypeName(fldType)}')
def _importSource2(lyr_dst : ogr.Layer, path : str, args : dict[str,Any],
                   basedir : Path|None, extent : ogr.Geometry|None) -> None:
    """Import a source layer (already extracted).
    This is more or less like ogr2ogr/GDALVectorTranslate() but we roll
    out our own (slower) version because GDALVectorTranslate() insists on
    calling StartTransaction() https://github.com/OSGeo/gdal/issues/3403
    while we want a single transaction for the entire destination layer,
    including truncation, source imports, and metadata changes."""
    kwargs, _ = gdalSetOpenExArgs(args, flags=GDAL_OF_VECTOR|GDAL_OF_READONLY|GDAL_OF_VERBOSE_ERROR)
    path2 = path if basedir is None else str(basedir.joinpath(path))

    logging.debug('OpenEx(%s, %s)', path2, str(kwargs))
    ds = gdal.OpenEx(path2, **kwargs)
    if ds is None:
        raise RuntimeError(f'Could not open {path2}')

    # resolve the source layer: by name when configured, by index 0 otherwise;
    # `msg` carries a human-readable identification for logs and errors
    layername = args.get('layername', None)
    if layername is None:
        idx = 0
        lyr = ds.GetLayerByIndex(idx)
        msg = '#' + str(idx)
        if lyr is not None:
            layername = lyr.GetName()
            msg += ' ("' + layername + '")'
    else:
        lyr = ds.GetLayerByName(layername)
        msg = '"' + layername + '"'
    if lyr is None:
        raise RuntimeError(f'Could not get requested layer {msg} from {path2}')

    logging.info('Importing layer %s from "%s"', msg, path)
    importLayer(lyr_dst, lyr, args=args, extent=extent)
def importLayer(lyr_dst : ogr.Layer, lyr : ogr.Layer,
                args : dict[str,Any]|None = None,
                extent : Optional[ogr.Geometry] = None) -> None:
    """Import a source layer (already opened) into lyr_dst.

    Applies field-map, value-map and rstrip-strings rules from `args`,
    an optional spatial filter derived from `extent`, a coordinate
    transformation when the SRSs differ, and geometry-type promotion
    (single → multi) when needed.

    (Fixes: dangerous mutable default `args={}` → None sentinel;
    log-message typo "Creating transforming" → "Creating transformation".)"""
    if args is None:
        args = {}

    layername = lyr.GetName()
    srs = lyr.GetSpatialRef()
    if srs is None:
        raise RuntimeError(f'Source layer {layername} has no SRS')

    srs_dst = lyr_dst.GetSpatialRef()
    if srs_dst is None:
        logging.warning('Destination has no SRS, skipping coordinate transformation')
        ct = None
    elif srs_dst.IsSame(srs):
        logging.debug('Both source and destination have the same SRS (%s), '
                      'skipping coordinate transformation',
                      srs_dst.GetName())
        ct = None
    else:
        # TODO Use GetSupportedSRSList() and SetActiveSRS() with GDAL ≥3.7.0
        # when possible, see apps/ogr2ogr_lib.cpp
        # pylint: disable=duplicate-code
        logging.debug('Creating transformation from source SRS (%s) to destination SRS (%s)',
                      srs.GetName(), srs_dst.GetName())
        ct = osr.CoordinateTransformation(srs, srs_dst)
        if ct is None:
            raise RuntimeError(f'Could not create transformation from source SRS ({srs.GetName()}) '
                               f'to destination SRS ({srs_dst.GetName()})')

    defn = lyr.GetLayerDefn()
    geomFieldCount = defn.GetGeomFieldCount()
    if geomFieldCount != 1:
        # TODO Add support for multiple geometry fields (also in the fingerprinting logic below)
        logging.warning('Source layer "%s" has %d != 1 geometry fields', layername, geomFieldCount)

    fieldCount = defn.GetFieldCount()
    fieldMap = [-1] * fieldCount
    fields = args['field-map']
    fieldSet = set()

    canIgnoreFields = lyr.TestCapability(ogr.OLCIgnoreFields)
    ignoredFieldNames = []
    for i in range(fieldCount):
        fld = defn.GetFieldDefn(i)
        fldName = fld.GetName()
        fieldMap[i] = v = fields.get(fldName, -1)
        fieldSet.add(fldName)

        if v < 0 and canIgnoreFields:
            # call SetIgnored() on unwanted source fields
            logging.debug('Set Ignored=True on output field "%s"', fldName)
            fld.SetIgnored(True)
            ignoredFieldNames.append(fldName)

    count0 = -1
    if lyr.TestCapability(ogr.OLCFastFeatureCount):
        count0 = lyr.GetFeatureCount(force=0)

    if count0 == 0 and len(fieldSet) == 0:
        # skip the below warning in some cases (e.g., GeoJSON source)
        logging.info('Source layer "%s" has no fields nor features, skipping', layername)
        return

    logging.debug('Field map: %s', str(fieldMap))
    for fld in fields:
        if not fld in fieldSet:
            logging.warning('Source layer "%s" has no field named "%s", ignoring', layername, fld)

    count1 = -1
    if args.get('spatial-filter', True) and extent is not None:
        spatialFilter = getSpatialFilterFromGeometry(extent, srs)
        logging.debug('Setting spatial filter to %s', spatialFilter.ExportToWkt())
        lyr.SetSpatialFilter(spatialFilter)

        if lyr.TestCapability(ogr.OLCFastFeatureCount):
            count1 = lyr.GetFeatureCount(force=0)

    if count0 >= 0:
        if count1 >= 0:
            logging.info('Source layer "%s" has %d features (%d of which intersecting extent)',
                         layername, count0, count1)
        else:
            logging.info('Source layer "%s" has %d features', layername, count0)

    logging.info('Ignored fields from source layer: %s',
                 '-' if len(ignoredFieldNames) == 0 else ', '.join(ignoredFieldNames))

    # build a list of triplets (field index, replacement_for_null, [(from_value, to_value), …])
    valueMap = []
    for fldName, rules in args.get('value-map', {}).items():
        i = defn.GetFieldIndex(fldName)
        if i < 0:
            raise RuntimeError(f'Source layer "{layername}" has no field named "{fldName}"')
        if fieldMap[i] < 0:
            logging.warning('Ignored source field "%s" has value map', fldName)
            continue

        hasNullReplacement = False
        nullReplacement = None
        mapping = []
        fld = defn.GetFieldDefn(i)
        for idx, (rFrom, rTo) in enumerate(rules):
            # use fld for both 'from' and 'to' (the types must match, casting is not
            # allowed in the mapping)
            if rFrom is None:
                if hasNullReplacement:
                    logging.warning('Field "%s" has duplicate NULL replacement',
                                    fld.GetName())
                else:
                    setFieldMapValue(fld, idx, None) # validate NULL
                    rTo = setFieldMapValue(fld, idx, rTo)
                    hasNullReplacement = True
                    nullReplacement = rTo
            elif isinstance(rFrom, re.Pattern):
                # validate but keep the rFrom regex
                setFieldMapValue(fld, idx, str(rFrom))
                rTo = setFieldMapValue(fld, idx, rTo)
                mapping.append( (rFrom, rTo, 1) )
            else:
                rFrom = setFieldMapValue(fld, idx, rFrom)
                rTo = setFieldMapValue(fld, idx, rTo)
                mapping.append( (rFrom, rTo, 0) )

        if nullReplacement is not None or len(mapping) > 0:
            valueMap.append( (i, nullReplacement, mapping) )

    if args.get('rstrip-strings', False):
        stringFieldsIdx = [ i for i in range(fieldCount)
                            if defn.GetFieldDefn(i).GetType() == ogr.OFTString and
                               fieldMap[i] >= 0 ]
        logging.debug('Source field indices to rstrip: %s', str(stringFieldsIdx))
        bStringFields = len(stringFieldsIdx) > 0
    else:
        bStringFields = False

    bValueMap = len(valueMap) > 0
    defn = None

    defn_dst = lyr_dst.GetLayerDefn()
    eGType_dst = defn_dst.GetGeomType()
    eGType_dst_HasZ = ogr.GT_HasZ(eGType_dst)
    eGType_dst_HasM = ogr.GT_HasM(eGType_dst)
    dGeomIsUnknown = ogr.GT_Flatten(eGType_dst) == ogr.wkbUnknown

    if bValueMap:
        valueMapCounts = [0] * fieldCount

    featureCount = 0
    mismatch = {}
    feature = lyr.GetNextFeature()
    while feature is not None:
        if bStringFields:
            for i in stringFieldsIdx:
                if feature.IsFieldSetAndNotNull(i):
                    v = feature.GetField(i)
                    feature.SetField(i, v.rstrip())

        if bValueMap:
            for i, nullReplacement, mapping in valueMap:
                if not feature.IsFieldSet(i):
                    continue
                if feature.IsFieldNull(i):
                    if nullReplacement is not None:
                        # replace NULL with non-NULL value
                        feature.SetField(i, nullReplacement)
                        valueMapCounts[i] += 1
                    continue

                v = feature.GetField(i)
                for rFrom, rTo, rType in mapping:
                    if rType == 0:
                        # literal
                        if v != rFrom:
                            continue
                    elif rType == 1:
                        # regex
                        m = rFrom.fullmatch(v)
                        if m is None:
                            continue
                        if rTo is not None:
                            rTo = rTo.format(*m.groups())
                    else:
                        raise RuntimeError(str(rType))

                    if rTo is None:
                        # replace non-NULL value with NULL
                        feature.SetFieldNull(i)
                    else:
                        # replace non-NULL value with non-NULL value
                        feature.SetField(i, rTo)
                    valueMapCounts[i] += 1
                    break

        feature2 = ogr.Feature(defn_dst)
        feature2.SetFromWithMap(feature, False, fieldMap)

        geom = feature2.GetGeometryRef()
        if geom is None:
            if eGType_dst != ogr.wkbNone:
                logging.warning('Source feature #%d has no geometry, trying to transfer anyway',
                                feature.GetFID())
        else:
            if ct is not None and geom.Transform(ct) != ogr.OGRERR_NONE:
                raise RuntimeError('Could not apply coordinate transformation')

            eGType = geom.GetGeometryType()
            if eGType != eGType_dst and not dGeomIsUnknown:
                # Promote to multi, cf. apps/ogr2ogr_lib.cpp:ConvertType()
                eGType2 = eGType
                if eGType in (ogr.wkbTriangle, ogr.wkbTIN, ogr.wkbPolyhedralSurface):
                    eGType2 = ogr.wkbMultiPolygon
                elif not ogr.GT_IsSubClassOf(eGType, ogr.wkbGeometryCollection):
                    eGType2 = ogr.GT_GetCollection(eGType)

                eGType2 = ogr.GT_SetModifier(eGType2, eGType_dst_HasZ, eGType_dst_HasM)
                if eGType2 == eGType_dst:
                    mismatch[eGType] = mismatch.get(eGType, 0) + 1
                    geom = ogr.ForceTo(geom, eGType_dst)
                    # TODO call MakeValid()?
                else:
                    raise RuntimeError(f'Conversion from {ogr.GeometryTypeToName(eGType)} '
                                       f'to {ogr.GeometryTypeToName(eGType_dst)} not implemented')
                feature2.SetGeometryDirectly(geom)

        if lyr_dst.CreateFeature(feature2) != ogr.OGRERR_NONE:
            raise RuntimeError(f'Could not transfer source feature #{feature.GetFID()}')

        featureCount += 1
        feature = lyr.GetNextFeature()

    if bValueMap:
        valueMapCounts = [ (lyr.GetLayerDefn().GetFieldDefn(i).GetName(), k)
                           for i,k in enumerate(valueMapCounts) if k > 0 ]

    lyr = None
    logging.info('Imported %d features from source layer "%s"', featureCount, layername)

    if bValueMap:
        if len(valueMapCounts) > 0:
            valueMapCounts = ', '.join([ str(k) + '× "' + n + '"' for n,k in valueMapCounts ])
        else:
            valueMapCounts = '-'
        logging.info('Field substitutions: %s', valueMapCounts)

    if len(mismatch) > 0:
        mismatches = [ str(n) + '× ' + ogr.GeometryTypeToName(t)
                       for t,n in sorted(mismatch.items(), key=lambda x: x[1]) ]
        logging.info('Forced conversion to %s: %s',
                     ogr.GeometryTypeToName(eGType_dst), ', '.join(mismatches))
+ else: + raise RuntimeError(f'Conversion from {ogr.GeometryTypeToName(eGType)} ' + f'to {ogr.GeometryTypeToName(eGType_dst)} not implemented') + feature2.SetGeometryDirectly(geom) + + if lyr_dst.CreateFeature(feature2) != ogr.OGRERR_NONE: + raise RuntimeError(f'Could not transfer source feature #{feature.GetFID()}') + + featureCount += 1 + feature = lyr.GetNextFeature() + + if bValueMap: + valueMapCounts = [ (lyr.GetLayerDefn().GetFieldDefn(i).GetName(), k) + for i,k in enumerate(valueMapCounts) if k > 0 ] + + lyr = None + logging.info('Imported %d features from source layer "%s"', featureCount, layername) + + if bValueMap: + if len(valueMapCounts) > 0: + valueMapCounts = ', '.join([ str(k) + '× "' + n + '"' for n,k in valueMapCounts ]) + else: + valueMapCounts = '-' + logging.info('Field substitutions: %s', valueMapCounts) + + if len(mismatch) > 0: + mismatches = [ str(n) + '× ' + ogr.GeometryTypeToName(t) + for t,n in sorted(mismatch.items(), key=lambda x: x[1]) ] + logging.info('Forced conversion to %s: %s', + ogr.GeometryTypeToName(eGType_dst), ', '.join(mismatches)) + +def listFieldsOrderBy(defn : ogr.FeatureDefn, + unique : bool|None = None, + nullable : bool|None = None) -> Iterator[str]: + """Return an iterator of column names suitable for ORDER BY.""" + fields_str = {} + for i in range(defn.GetFieldCount()): + fld = defn.GetFieldDefn(i) + if (fld.IsIgnored() or + # if 'unique' or 'unable' is not None then skip the field + # unless the boolean matches .IsUnique() resp. 
.IsNullable() + not (unique is None or fld.IsUnique() == unique) or + not (nullable is None or fld.IsNullable() == nullable)): + continue + if fld.GetType() in (ogr.OFTInteger, ogr.OFTInteger64): + # list integers first + yield fld.GetName() + elif fld.GetType() == ogr.OFTString: + w = fld.GetWidth() + if 0 < w < 256: + # only consider short-ish strings + fields_str[fld.GetName()] = w + # order string columns by width + for c,_ in sorted(fields_str.items(), key=lambda x:x[1]): + yield c + +# pylint: disable-next=too-many-branches, too-many-statements +def updateLayerCache(lyr : ogr.Layer, cache : ogr.Layer, + last_updated : datetime, + lyrTransaction : str|bool|None = None, + force : bool = False) -> bool: + """Update attributes in the layer cache for the given layer name. + Return a boolean indicating whether changes to layername were *not* + rolled back (hence might still be outstanding in the transaction).""" + layername = lyr.GetName() + + dgst = sha256() + defn = lyr.GetLayerDefn() + fields = [] + for i in range(defn.GetFieldCount()): + fields.append('t.' + escape_identifier(defn.GetFieldDefn(i).GetName())) + if len(fields) == 0: + fields = ['0 AS hash_properties'] + else: + fields = [ 'hash_record_extended(ROW(' + ','.join(fields) + '),0) AS hash_properties' ] + + fidColumn = lyr.GetFIDColumn() + if fidColumn is None or fidColumn == '': + raise RuntimeError(f'Couldn\'t find FID column for "{layername}"') + # defn.GetGeomFieldCount() != 1 is not supported and yields a warning in _importSource2() + geometryColumn = lyr.GetGeometryColumn() + if geometryColumn is None or geometryColumn == '': + raise RuntimeError(f'Couldn\'t find geometry column for "{layername}"') + + fields.append('sha256(COALESCE(' + + 'ST_AsEWKB(t.' 
+ escape_identifier(geometryColumn) + '),' + + '\'\')) AS hash_geom') + if len(fields) == 0: + raise RuntimeError('Empty field list in SELECT') + query = 'SELECT ' + ','.join(fields) + ' FROM ' + getEscapedTableName(lyr) + ' t' + + sort_by = next(listFieldsOrderBy(defn, unique=True, nullable=False), None) + if sort_by is not None: + sort_by = [ sort_by ] + else: + count = lyr.GetFeatureCount(force=0) + if count is None or count < 0 or count > 5000: + logging.warning('Layer "%s" has many (%s) features but no UNIQUE NOT NULL constraint, ' + 'sorting might be unstable and slow', layername, + str(count) if (count is not None and count >= 0) else 'N/A') + sort_by = list(listFieldsOrderBy(defn)) + [ geometryColumn, fidColumn ] + query += ' ORDER BY ' + ','.join(['t.' + escape_identifier(c) for c in sort_by]) + + struct_dgst : Final = struct.Struct('@qq').pack + ds = lyr.GetDataset() + with executeSQL(ds, query) as lyr2: + defn2 = lyr2.GetLayerDefn() + assert defn2.GetFieldDefn(0).GetName() == 'hash_properties' + assert defn2.GetFieldDefn(1).GetName() == 'hash_geom' + feature = lyr2.GetNextFeature() + while feature is not None: + dgst.update(struct_dgst(feature.GetFID(), feature.GetFieldAsInteger64(0))) + dgst.update(feature.GetFieldAsBinary(1)) + feature = lyr2.GetNextFeature() + fingerprint = dgst.digest() + + attributeFilter = 'layername = ' + escape_literal_str(layername) + logging.debug('SetAttributeFilter("%s", "%s")', cache.GetName(), attributeFilter) + cache.SetAttributeFilter(attributeFilter) + + feature = cache.GetNextFeature() + if feature is None: + # not in cache + logging.debug('Creating new feature in layer cache for %s', attributeFilter) + update = False + feature = ogr.Feature(cache.GetLayerDefn()) + feature.SetFieldString(0, layername) + fingerprint_old = None + else: + logging.debug('Updating existing feature in layer cache for %s', attributeFilter) + update = True + fingerprint_old = feature.GetFieldAsBinary(2) if feature.IsFieldSetAndNotNull(2) else 
None + assert cache.GetNextFeature() is None + + if last_updated.tzinfo == UTC: + tzFlag = ogr.TZFLAG_UTC + else: + td = last_updated.utcoffset() + # 15min increments/decrements per unit above/below UTC, cf. + # https://gdal.org/en/stable/api/vector_c_api.html#c.OGR_TZFLAG_UTC + tzFlag = td.days * 96 + td.seconds // 900 + if timedelta(seconds=tzFlag*900) != td or abs(tzFlag) > 56: # max ±14:00 + raise RuntimeError(f'Invalid UTC offset {td}') + tzFlag += ogr.TZFLAG_UTC + + feature.SetField(1, last_updated.year, + last_updated.month, + last_updated.day, + last_updated.hour, + last_updated.minute, + float(last_updated.second) + float(last_updated.microsecond)/1000000., + tzFlag) + if fingerprint is None: + feature.SetFieldNull(2) + else: + # https://lists.osgeo.org/pipermail/gdal-dev/2020-December/053170.html + feature.SetFieldBinaryFromHexString(2, fingerprint.hex()) + + ret = True + if update: + if (fingerprint is None or fingerprint_old is None or fingerprint != fingerprint_old): + logging.info('Updated layer "%s" has new fingerprint %s', layername, + fingerprint.hex()[:8] if fingerprint is not None else 'N/A') + elif force: + logging.info('Updated layer "%s" has identical fingerprint %s', + layername, fingerprint.hex()[:8]) + else: + # no change: rollback (*before* updating the cache) if possible to retain FID values + if isinstance(lyrTransaction, str): + logging.info('Updated layer "%s" has identical fingerprint %s, rolling back', + layername, fingerprint.hex()[:8]) + try: + executeSQL(ds, 'ROLLBACK TO ' + lyrTransaction) + except Exception: # pylint: disable=broad-exception-caught + logging.exception('Could not execute SQL: %s', query) + else: + ret = False + elif isinstance(lyrTransaction, bool) and lyrTransaction: + logging.info('Updated layer "%s" has identical fingerprint %s, rolling back', + layername, fingerprint.hex()[:8]) + try: + if lyr.RollbackTransaction() == ogr.OGRERR_NONE: + ret = False + else: + logging.error('Could not rollback transaction on 
layer "%s"', + layername) + except Exception: # pylint: disable=broad-exception-caught + logging.exception('Could not rollback transaction on layer "%s"', + layername) + else: + logging.info('Updated layer "%s" has identical fingerprint %s', + layername, fingerprint.hex()[:8]) + + if cache.UpdateFeature(feature, [1,2], [], False) != ogr.OGRERR_NONE: + raise RuntimeError('Could not update feature in layer cache') + else: + if cache.CreateFeature(feature) != ogr.OGRERR_NONE: + raise RuntimeError('Could not create new feature in layer cache') + + # force the PG driver to call EndCopy() to detect errors and trigger a + # rollback if needed + ds.FlushCache() + return ret |
