From 2fa39019cd4bbe0c221b084a9bd17698f8ffd767 Mon Sep 17 00:00:00 2001 From: Guilhem Moulin Date: Fri, 6 Mar 2026 09:00:51 +0100 Subject: Add wrapper for gdal.Dataset.ExecuteSQL(). --- common_gdal.py | 10 ++++++++ export_mvt.py | 12 ++++++--- import_source.py | 77 ++++++++++++++++++++++++-------------------------------- 3 files changed, 51 insertions(+), 48 deletions(-) diff --git a/common_gdal.py b/common_gdal.py index 401e0a5..b5570eb 100644 --- a/common_gdal.py +++ b/common_gdal.py @@ -371,3 +371,13 @@ def getEscapedTableName(lyr : ogr.Layer, extract_schema_from_layer_name : bool = if lyr.GetDataset().GetDriver().ShortName == 'PostgreSQL': return getEscapedTableNamePG(layername, extract_schema_from_layer_name) return escape_identifier(layername) + +def executeSQL(ds : gdal.Dataset, statement : str, + spatialFilter : Optional[ogr.Geometry] = None) -> ogr.Layer|None: + """Wrapper for gdal.Dataset.ExecuteSQL(). + https://gdal.org/en/stable/api/python/raster_api.html#osgeo.gdal.Dataset.ExecuteSQL""" + msg = statement + if spatialFilter is not None: + msg += ', spatialFilter=' + spatialFilter.ExportToWkt() + logging.debug('ExecuteSQL(%s)', msg) + return ds.ExecuteSQL(statement, spatialFilter=spatialFilter) diff --git a/export_mvt.py b/export_mvt.py index b893d79..b56fba1 100644 --- a/export_mvt.py +++ b/export_mvt.py @@ -35,7 +35,13 @@ import brotli from osgeo import gdal, ogr, osr from common import BadConfiguration, escape_identifier, format_bytes, format_time -from common_gdal import getExtent, getSRS, getSpatialFilterFromGeometry, getEscapedTableName +from common_gdal import ( + getExtent, + getSRS, + getSpatialFilterFromGeometry, + getEscapedTableName, + executeSQL, +) from rename_exchange import rename_exchange def parseTilingScheme(scheme : list[Any]) -> tuple[osr.SpatialReference, ogr.Geometry|None]: @@ -184,9 +190,7 @@ def exportSourceLayer(lyr_src : ogr.Layer, query += ' WHERE ' + cond.strip() ds_src = lyr_src.GetDataset() - logging.debug('ExecuteSQL(%s%s)', query, - '' if spatialFilter is None else ', spatialFilter=' + spatialFilter.ExportToWkt()) - with ds_src.ExecuteSQL(query, spatialFilter=spatialFilter) as lyr_src2: + with executeSQL(ds_src, query, spatialFilter=spatialFilter) as lyr_src2: count1 = -1 if lyr_src2.TestCapability(ogr.OLCFastFeatureCount): count1 = lyr_src2.GetFeatureCount(force=0) diff --git a/import_source.py b/import_source.py index 650c31d..bf30430 100644 --- a/import_source.py +++ b/import_source.py @@ -50,6 +50,7 @@ from common_gdal import ( formatTZFlag, getSpatialFilterFromGeometry, getEscapedTableName, + executeSQL, ) def openOutputDS(def_dict : dict[str, Any]) -> gdal.Dataset: @@ -253,23 +254,21 @@ def clusterLayer(lyr : ogr.Layer, layername_esc = getEscapedTableName(lyr) if index_name is None and column_name is not None: # find out which indices involve lyr's column_name - query = 'WITH indices AS (' - query += 'SELECT i.relname AS index, array_agg(a.attname) AS columns ' - query += 'FROM pg_class t, pg_class i, pg_index ix, pg_attribute a ' - query += 'WHERE t.oid = ix.indrelid ' - query += 'AND i.oid = ix.indexrelid ' - query += 'AND a.attrelid = t.oid ' - query += 'AND a.attnum = ANY(ix.indkey) ' - query += 'AND t.relkind = \'r\' ' - query += 'AND ix.indrelid = ' + escape_literal_str(layername_esc) + '::regclass ' - query += 'GROUP BY 1) ' - query += 'SELECT index, array_length(columns, 1) AS len ' - query += 'FROM indices ' - query += 'WHERE ' + escape_literal_str(column_name) + ' = ANY(columns)' - query += 'ORDER BY 2,1 LIMIT 1' # pick the index involving the least number of columns - - logging.debug('ExecuteSQL(%s)', query) - with ds.ExecuteSQL(query) as res: + with executeSQL(ds, statement='WITH indices AS (' + 'SELECT i.relname AS index, array_agg(a.attname) AS columns ' + 'FROM pg_class t, pg_class i, pg_index ix, pg_attribute a ' + 'WHERE t.oid = ix.indrelid ' + 'AND i.oid = ix.indexrelid ' + 'AND a.attrelid = t.oid ' + 'AND a.attnum = ANY(ix.indkey) ' + 'AND t.relkind = \'r\' ' + 'AND ix.indrelid = ' + escape_literal_str(layername_esc) + '::regclass ' + 'GROUP BY 1) ' + 'SELECT index, array_length(columns, 1) AS len ' + 'FROM indices ' + 'WHERE ' + escape_literal_str(column_name) + ' = ANY(columns)' + # pick the index involving the least number of columns + 'ORDER BY 2,1 LIMIT 1') as res: defn = res.GetLayerDefn() i = defn.GetFieldIndex('index') row = res.GetNextFeature() @@ -281,19 +280,16 @@ def clusterLayer(lyr : ogr.Layer, lyr.GetName(), column_name) return False - query = 'CLUSTER ' + layername_esc + statement = 'CLUSTER ' + layername_esc if index_name is not None: - query += ' USING ' + escape_identifier(index_name) - logging.debug('ExecuteSQL(%s)', query) - ds.ExecuteSQL(query) + statement += ' USING ' + escape_identifier(index_name) + executeSQL(ds, statement=statement) if analyze: # "Because the planner records statistics about the ordering of tables, it is # advisable to run ANALYZE on the newly clustered table. Otherwise, the planner # might make poor choices of query plans." - query = 'ANALYZE ' + layername_esc - logging.debug('ExecuteSQL(%s)', query) - ds.ExecuteSQL(query) + executeSQL(ds, statement='ANALYZE ' + layername_esc) return True @@ -461,16 +457,14 @@ def clearLayer(lyr : ogr.Layer) -> None: ds = lyr.GetDataset() if ds.GetDriver().ShortName == 'PostgreSQL': # https://www.postgresql.org/docs/15/sql-truncate.html - query = 'TRUNCATE TABLE {table} CONTINUE IDENTITY CASCADE' + statement = 'TRUNCATE TABLE {table} CONTINUE IDENTITY CASCADE' op = 'Truncating' else: - query = 'DELETE FROM {table}' + statement = 'DELETE FROM {table}' op = 'Clearing' logging.info('%s table %s (former feature count: %s)', op, lyr.GetName(), str(n) if n >= 0 else 'unknown') - query = query.format(table=getEscapedTableName(lyr)) - logging.debug('ExecuteSQL(%s)', query) - ds.ExecuteSQL(query) + executeSQL(ds, statement=statement.format(table=getEscapedTableName(lyr))) def extractArchive(path : Path, destdir : str, fmt : str|None = None, @@ -555,8 +549,7 @@ def importSources(lyr : ogr.Layer, if dsoTransaction: # declare a SAVEPOINT (nested transaction) within the DS-level transaction lyrTransaction = 'SAVEPOINT ' + escape_identifier('savept_' + layername) - logging.debug('ExecuteSQL(%s)', lyrTransaction) - dso.ExecuteSQL(lyrTransaction) + executeSQL(dso, lyrTransaction) elif lyr.TestCapability(ogr.OLCTransactions): # try to start transaction on the layer logging.debug('Starting transaction on output layer "%s"', layername) @@ -605,13 +598,13 @@ def importSources(lyr : ogr.Layer, except Exception: # pylint: disable=broad-exception-caught rv = ImportStatus.IMPORT_ERROR if isinstance(lyrTransaction, str): - query = 'ROLLBACK TO ' + lyrTransaction - logging.exception('Exception occured within transaction. ExecuteSQL(%s)', query) + statement = 'ROLLBACK TO ' + lyrTransaction + logging.exception('Exception occured within transaction') # don't unset lyrTransaction here as we want to RELEASE SAVEPOINT try: - dso.ExecuteSQL(query) + executeSQL(dso, statement=statement) except Exception: # pylint: disable=broad-exception-caught - logging.exception('Could not execute SQL: %s', query) + logging.exception('Could not execute SQL: %s', statement) elif isinstance(lyrTransaction, bool) and lyrTransaction: logging.exception('Exception occured within transaction on output ' 'layer "%s": ROLLBACK', layername) @@ -626,13 +619,12 @@ def importSources(lyr : ogr.Layer, finally: if isinstance(lyrTransaction, str): - query = 'RELEASE ' + lyrTransaction - logging.debug('ExecuteSQL(%s)', query) + statement = 'RELEASE ' + lyrTransaction try: - dso.ExecuteSQL(query) + executeSQL(dso, statement) except Exception: # pylint: disable=broad-exception-caught rv = ImportStatus.IMPORT_ERROR - logging.exception('Could not execute SQL: %s', query) + logging.exception('Could not execute SQL: %s', statement) elif isinstance(lyrTransaction, bool) and lyrTransaction: try: if lyr.CommitTransaction() != ogr.OGRERR_NONE: @@ -1051,9 +1043,8 @@ def updateLayerCache(lyr : ogr.Layer, cache : ogr.Layer, query += ' ORDER BY ' + ','.join(['t.' + escape_identifier(c) for c in sort_by]) struct_dgst : Final = struct.Struct('@qq').pack - logging.debug('ExecuteSQL(%s)', query) ds = lyr.GetDataset() - with ds.ExecuteSQL(query) as lyr2: + with executeSQL(ds, query) as lyr2: defn2 = lyr2.GetLayerDefn() assert defn2.GetFieldDefn(0).GetName() == 'hash_properties' assert defn2.GetFieldDefn(1).GetName() == 'hash_geom' @@ -1119,10 +1110,8 @@ def updateLayerCache(lyr : ogr.Layer, cache : ogr.Layer, if isinstance(lyrTransaction, str): logging.info('Updated layer "%s" has identical fingerprint %s, rolling back', layername, fingerprint.hex()[:8]) - query = 'ROLLBACK TO ' + lyrTransaction - logging.debug('ExecuteSQL(%s)', query) try: - ds.ExecuteSQL(query) + executeSQL(ds, 'ROLLBACK TO ' + lyrTransaction) except Exception: # pylint: disable=broad-exception-caught logging.exception('Could not execute SQL: %s', query) else: -- cgit v1.2.3