diff options
| author | Guilhem Moulin <guilhem@fripost.org> | 2026-03-06 09:00:51 +0100 |
|---|---|---|
| committer | Guilhem Moulin <guilhem@fripost.org> | 2026-03-06 09:37:51 +0100 |
| commit | 2fa39019cd4bbe0c221b084a9bd17698f8ffd767 (patch) | |
| tree | d63e1b73434a6a239251907b86680cf30f4f6af8 /import_source.py | |
| parent | f2393202a5343dcaeffdbec518aa241573185335 (diff) | |
Add wrapper for gdal.Dataset.ExecuteSQL().
Diffstat (limited to 'import_source.py')
| -rw-r--r-- | import_source.py | 77 |
1 files changed, 33 insertions, 44 deletions
diff --git a/import_source.py b/import_source.py index 650c31d..bf30430 100644 --- a/import_source.py +++ b/import_source.py @@ -50,6 +50,7 @@ from common_gdal import ( formatTZFlag, getSpatialFilterFromGeometry, getEscapedTableName, + executeSQL, ) def openOutputDS(def_dict : dict[str, Any]) -> gdal.Dataset: @@ -253,23 +254,21 @@ def clusterLayer(lyr : ogr.Layer, layername_esc = getEscapedTableName(lyr) if index_name is None and column_name is not None: # find out which indices involve lyr's column_name - query = 'WITH indices AS (' - query += 'SELECT i.relname AS index, array_agg(a.attname) AS columns ' - query += 'FROM pg_class t, pg_class i, pg_index ix, pg_attribute a ' - query += 'WHERE t.oid = ix.indrelid ' - query += 'AND i.oid = ix.indexrelid ' - query += 'AND a.attrelid = t.oid ' - query += 'AND a.attnum = ANY(ix.indkey) ' - query += 'AND t.relkind = \'r\' ' - query += 'AND ix.indrelid = ' + escape_literal_str(layername_esc) + '::regclass ' - query += 'GROUP BY 1) ' - query += 'SELECT index, array_length(columns, 1) AS len ' - query += 'FROM indices ' - query += 'WHERE ' + escape_literal_str(column_name) + ' = ANY(columns)' - query += 'ORDER BY 2,1 LIMIT 1' # pick the index involving the least number of columns - - logging.debug('ExecuteSQL(%s)', query) - with ds.ExecuteSQL(query) as res: + with executeSQL(ds, statement='WITH indices AS (' + 'SELECT i.relname AS index, array_agg(a.attname) AS columns ' + 'FROM pg_class t, pg_class i, pg_index ix, pg_attribute a ' + 'WHERE t.oid = ix.indrelid ' + 'AND i.oid = ix.indexrelid ' + 'AND a.attrelid = t.oid ' + 'AND a.attnum = ANY(ix.indkey) ' + 'AND t.relkind = \'r\' ' + 'AND ix.indrelid = ' + escape_literal_str(layername_esc) + '::regclass ' + 'GROUP BY 1) ' + 'SELECT index, array_length(columns, 1) AS len ' + 'FROM indices ' + 'WHERE ' + escape_literal_str(column_name) + ' = ANY(columns)' + # pick the index involving the least number of columns + 'ORDER BY 2,1 LIMIT 1') as res: defn = res.GetLayerDefn() i = defn.GetFieldIndex('index') row = res.GetNextFeature() @@ -281,19 +280,16 @@ def clusterLayer(lyr : ogr.Layer, lyr.GetName(), column_name) return False - query = 'CLUSTER ' + layername_esc + statement = 'CLUSTER ' + layername_esc if index_name is not None: - query += ' USING ' + escape_identifier(index_name) - logging.debug('ExecuteSQL(%s)', query) - ds.ExecuteSQL(query) + statement += ' USING ' + escape_identifier(index_name) + executeSQL(ds, statement=statement) if analyze: # "Because the planner records statistics about the ordering of tables, it is # advisable to run ANALYZE on the newly clustered table. Otherwise, the planner # might make poor choices of query plans." - query = 'ANALYZE ' + layername_esc - logging.debug('ExecuteSQL(%s)', query) - ds.ExecuteSQL(query) + executeSQL(ds, statement='ANALYZE ' + layername_esc) return True @@ -461,16 +457,14 @@ def clearLayer(lyr : ogr.Layer) -> None: ds = lyr.GetDataset() if ds.GetDriver().ShortName == 'PostgreSQL': # https://www.postgresql.org/docs/15/sql-truncate.html - query = 'TRUNCATE TABLE {table} CONTINUE IDENTITY CASCADE' + statement = 'TRUNCATE TABLE {table} CONTINUE IDENTITY CASCADE' op = 'Truncating' else: - query = 'DELETE FROM {table}' + statement = 'DELETE FROM {table}' op = 'Clearing' logging.info('%s table %s (former feature count: %s)', op, lyr.GetName(), str(n) if n >= 0 else 'unknown') - query = query.format(table=getEscapedTableName(lyr)) - logging.debug('ExecuteSQL(%s)', query) - ds.ExecuteSQL(query) + executeSQL(ds, statement=statement.format(table=getEscapedTableName(lyr))) def extractArchive(path : Path, destdir : str, fmt : str|None = None, @@ -555,8 +549,7 @@ def importSources(lyr : ogr.Layer, if dsoTransaction: # declare a SAVEPOINT (nested transaction) within the DS-level transaction lyrTransaction = 'SAVEPOINT ' + escape_identifier('savept_' + layername) - logging.debug('ExecuteSQL(%s)', lyrTransaction) - dso.ExecuteSQL(lyrTransaction) + executeSQL(dso, lyrTransaction) elif lyr.TestCapability(ogr.OLCTransactions): # try to start transaction on the layer logging.debug('Starting transaction on output layer "%s"', layername) @@ -605,13 +598,13 @@ def importSources(lyr : ogr.Layer, except Exception: # pylint: disable=broad-exception-caught rv = ImportStatus.IMPORT_ERROR if isinstance(lyrTransaction, str): - query = 'ROLLBACK TO ' + lyrTransaction - logging.exception('Exception occured within transaction. ExecuteSQL(%s)', query) + statement = 'ROLLBACK TO ' + lyrTransaction + logging.exception('Exception occured within transaction') # don't unset lyrTransaction here as we want to RELEASE SAVEPOINT try: - dso.ExecuteSQL(query) + executeSQL(dso, statement=statement) except Exception: # pylint: disable=broad-exception-caught - logging.exception('Could not execute SQL: %s', query) + logging.exception('Could not execute SQL: %s', statement) elif isinstance(lyrTransaction, bool) and lyrTransaction: logging.exception('Exception occured within transaction on output ' 'layer "%s": ROLLBACK', layername) @@ -626,13 +619,12 @@ def importSources(lyr : ogr.Layer, finally: if isinstance(lyrTransaction, str): - query = 'RELEASE ' + lyrTransaction - logging.debug('ExecuteSQL(%s)', query) + statement = 'RELEASE ' + lyrTransaction try: - dso.ExecuteSQL(query) + executeSQL(dso, statement) except Exception: # pylint: disable=broad-exception-caught rv = ImportStatus.IMPORT_ERROR - logging.exception('Could not execute SQL: %s', query) + logging.exception('Could not execute SQL: %s', statement) elif isinstance(lyrTransaction, bool) and lyrTransaction: try: if lyr.CommitTransaction() != ogr.OGRERR_NONE: @@ -1051,9 +1043,8 @@ def updateLayerCache(lyr : ogr.Layer, cache : ogr.Layer, query += ' ORDER BY ' + ','.join(['t.' + escape_identifier(c) for c in sort_by]) struct_dgst : Final = struct.Struct('@qq').pack - logging.debug('ExecuteSQL(%s)', query) ds = lyr.GetDataset() - with ds.ExecuteSQL(query) as lyr2: + with executeSQL(ds, query) as lyr2: defn2 = lyr2.GetLayerDefn() assert defn2.GetFieldDefn(0).GetName() == 'hash_properties' assert defn2.GetFieldDefn(1).GetName() == 'hash_geom' @@ -1119,10 +1110,8 @@ def updateLayerCache(lyr : ogr.Layer, cache : ogr.Layer, if isinstance(lyrTransaction, str): logging.info('Updated layer "%s" has identical fingerprint %s, rolling back', layername, fingerprint.hex()[:8]) - query = 'ROLLBACK TO ' + lyrTransaction - logging.debug('ExecuteSQL(%s)', query) try: - ds.ExecuteSQL(query) + executeSQL(ds, 'ROLLBACK TO ' + lyrTransaction) except Exception: # pylint: disable=broad-exception-caught logging.exception('Could not execute SQL: %s', query) else: |
