#!/usr/bin/python3

#----------------------------------------------------------------------
# Backend utilities for the Klimatanalys Norr project (extract/import layers)
# Copyright © 2024 Guilhem Moulin <info@guilhem.se>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#----------------------------------------------------------------------

import os
import logging
import argparse
import tempfile
import re
import math
from fnmatch import fnmatchcase
from pathlib import Path

from osgeo import gdal, ogr, osr
from osgeo.gdalconst import (
    OF_VECTOR as GDAL_OF_VECTOR,
    OF_ALL as GDAL_OF_ALL,
    OF_READONLY as GDAL_OF_READONLY,
    OF_UPDATE as GDAL_OF_UPDATE,
    OF_VERBOSE_ERROR as GDAL_OF_VERBOSE_ERROR,
    CE_None as GDAL_CE_None,
    DCAP_CREATE as GDAL_DCAP_CREATE,
    DCAP_VECTOR as GDAL_DCAP_VECTOR,
    DCAP_DEFAULT_FIELDS as GDAL_DCAP_DEFAULT_FIELDS,
    DCAP_NOTNULL_FIELDS as GDAL_DCAP_NOTNULL_FIELDS,
    DCAP_UNIQUE_FIELDS as GDAL_DCAP_UNIQUE_FIELDS,
)
import osgeo.gdalconst as gdalconst
gdal.UseExceptions()

import common

# Wrapper around gdal.MajorObject.GetMetadataItem(name)
def getMetadataItem(o, k):
    """Return whether metadata item k of GDAL object o is set to 'YES'.

    The comparison is case-insensitive.  A missing item (None) or any
    non-string value yields False.
    """
    value = o.GetMetadataItem(k)
    return isinstance(value, str) and value.upper() == 'YES'

# Return kwargs and driver for OpenEx()
def setOpenExArgs(option_dict, flags=0):
    """Build the keyword arguments for gdal.OpenEx() from a definition dict.

    option_dict may hold a 'format' (GDAL driver short name, which then
    becomes the only allowed driver) and 'open-options' (a mapping turned
    into a KEY=VALUE list).  flags are OR'ed with GDAL_OF_VECTOR.

    Returns a (kwargs, driver) pair, where driver is None when no format
    was given.  Raises Exception if the format names an unknown driver or
    one without vector capabilities.
    """
    kwargs = { 'nOpenFlags': GDAL_OF_VECTOR | flags }

    drv = None
    fmt = option_dict.get('format', None)
    if fmt is not None:
        drv = gdal.GetDriverByName(fmt)
        if drv is None:
            raise Exception(f'Unknown driver name "{fmt}"')
        if not getMetadataItem(drv, GDAL_DCAP_VECTOR):
            raise Exception(f'Driver "{drv.ShortName}" has no vector capabilities')
        kwargs['allowed_drivers'] = [ drv.ShortName ]

    open_options = option_dict.get('open-options', None)
    if open_options is not None:
        kwargs['open_options'] = [ key + '=' + str(value)
                                   for key, value in open_options.items() ]
    return kwargs, drv

# Open and return the output DS.  If it can't be opened, it is created
# provided create=True or create-options is set, and the driver supports
# creation.
def openOutputDS(def_dict):
    """Open the output dataset described by def_dict, creating it if needed.

    def_dict['path'] is opened in update mode, honouring the optional
    'format' and 'open-options' keys (cf. setOpenExArgs()).  If that fails
    with an open error (CPLE_OpenFailed) and the path can't be opened by any
    driver either, the dataset is created with drv.Create() -- but only when
    def_dict has a true 'create' key or a 'create-options' mapping, and a
    'format' was given whose driver advertises GDAL_DCAP_CREATE.  In every
    other failure case the original RuntimeError is re-raised.

    Returns the opened or newly created dataset.
    """
    path = def_dict['path']
    kwargs, drv = setOpenExArgs(def_dict, flags=GDAL_OF_UPDATE|GDAL_OF_VERBOSE_ERROR)
    try:
        logging.debug('OpenEx(%s, %s)', path, str(kwargs))
        return gdal.OpenEx(path, **kwargs)
    except RuntimeError as e:
        # gdal.UseExceptions() is in effect, so a failed OpenEx() raises;
        # only a genuine open failure may lead to creation
        if not (gdal.GetLastErrorType() >= gdalconst.CE_Failure and
                gdal.GetLastErrorNo() == gdalconst.CPLE_OpenFailed):
            # not an open error
            raise e

        # Probe the path with any driver, first in update mode then
        # read-only: if it opens, something already exists there that merely
        # doesn't match the requested driver/options, so don't create it.
        dso2 = None
        try:
            dso2 = gdal.OpenEx(path, nOpenFlags=GDAL_OF_ALL | GDAL_OF_UPDATE)
        except Exception:
            pass
        if dso2 is not None:
            # path exists but can't be open with OpenEx(path, **kwargs)
            raise e

        try:
            dso2 = gdal.OpenEx(path, nOpenFlags=GDAL_OF_ALL)
        except Exception:
            pass
        if dso2 is not None:
            # path exists but can't be open with OpenEx(path, **kwargs)
            raise e

        dsco = def_dict.get('create-options', None)
        if not def_dict.get('create', False) and dsco is None:
            # not configured for creation
            raise e
        if drv is None or not getMetadataItem(drv, GDAL_DCAP_CREATE):
            # not capable of creation
            raise e

    # Only reached when the dataset is to be created: every other path above
    # returns or raises, and dsco was bound in the except clause.
    if 'open_options' in kwargs:
        # like ogr2ogr(1)
        logging.warning('Destination\'s open options ignored ' +
            'when creating the output datasource')

    # dataset creation options (create-options) as a KEY=VALUE list
    kwargs2 = { 'eType': gdal.GDT_Unknown }
    if dsco is not None:
        kwargs2['options'] = [ k + '=' + str(v) for k, v in dsco.items() ]

    logging.debug('Create(%s, %s, eType=%s%s)', drv.ShortName, path, kwargs2['eType'],
        ', options=' + str(kwargs2['options']) if 'options' in kwargs2 else '')
    # XXX racy, a GDAL equivalent of O_EXCL would be nice
    # no raster part: 0x0 pixels and 0 bands
    return drv.Create(path, 0, 0, 0, **kwargs2)

# cf. ogr/ogrgeometry.cpp:OGRFromOGCGeomType()
def fromGeomTypeName(name):
    """Map an OGC geometry type name to an OGR wkb geometry type.

    Matching is case-insensitive.  A trailing 'M' requests the measured
    variant, and a 'Z' just before it (or trailing) the 3D variant, e.g.
    'PointZM'.  Unrecognized names map to ogr.wkbUnknown, with the Z/M
    modifiers still applied.
    """
    base = name.upper()

    measured = base.endswith('M')
    if measured:
        base = base.removesuffix('M')

    with_z = base.endswith('Z')
    if with_z:
        base = base.removesuffix('Z')

    # OGC name -> name of the matching ogr constant (looked up lazily)
    wkb_attr = {
        'POINT':              'wkbPoint',
        'LINESTRING':         'wkbLineString',
        'POLYGON':            'wkbPolygon',
        'MULTIPOINT':         'wkbMultiPoint',
        'MULTILINESTRING':    'wkbMultiLineString',
        'MULTIPOLYGON':       'wkbMultiPolygon',
        'GEOMETRYCOLLECTION': 'wkbGeometryCollection',
        'CIRCULARSTRING':     'wkbCircularString',
        'COMPOUNDCURVE':      'wkbCompoundCurve',
        'CURVEPOLYGON':       'wkbCurvePolygon',
        'MULTICURVE':         'wkbMultiCurve',
        'MULTISURFACE':       'wkbMultiSurface',
        'TRIANGLE':           'wkbTriangle',
        'POLYHEDRALSURFACE':  'wkbPolyhedralSurface',
        'TIN':                'wkbTIN',
        'CURVE':              'wkbCurve',
        'SURFACE':            'wkbSurface',
    }.get(base, 'wkbUnknown')
    geom_type = getattr(ogr, wkb_attr)

    if with_z:
        geom_type = ogr.GT_SetZ(geom_type)
    if measured:
        geom_type = ogr.GT_SetM(geom_type)
    return geom_type

# Parse geometry type, cf. ogr2ogr_lib.cpp
def parseGeomType(name):
    """Parse a geometry type name the way ogr2ogr's -nlt option does.

    None maps to ogr.wkbUnknown.  Matching is case-insensitive: 'NONE'
    maps to ogr.wkbNone, 'GEOMETRY' to ogr.wkbUnknown, and any other name
    is resolved with fromGeomTypeName().  A trailing 'Z' or '25D' (alias)
    requests the 3D variant.

    Raises Exception for unrecognized names.
    """
    if name is None:
        return ogr.wkbUnknown
    name2 = name.upper()

    is3D = False
    if name2.endswith('25D'):
        name2 = name2[:-3] # alias
        is3D = True
    elif name2.endswith('Z'):
        name2 = name2[:-1]
        is3D = True

    if name2 == 'NONE':
        eGType = ogr.wkbNone
    elif name2 == 'GEOMETRY':
        eGType = ogr.wkbUnknown
    else:
        eGType = fromGeomTypeName(name2)
        # fromGeomTypeName() applies the M/Z modifiers even when it doesn't
        # recognize the base name, so comparing against ogr.wkbUnknown alone
        # would let e.g. a misspelled 'POLYGOM' through as a measured
        # wkbUnknown.  Compare the flattened (2D, non-measured) type instead,
        # still accepting the legitimate 'GEOMETRYM' / 'GEOMETRYZM' forms.
        base = name2.removesuffix('M').removesuffix('Z')
        if ogr.GT_Flatten(eGType) == ogr.wkbUnknown and base != 'GEOMETRY':
            raise Exception(f'Unknown geometry type "{name}"')

    if eGType != ogr.wkbNone and is3D:
        eGType = ogr.GT_SetZ(eGType)

    return eGType

# cf. ogr/ogr_core.h's enum OGRFieldType;
def parseFieldType(name):
    """Map a case-insensitive field type name to its ogr.OFT* constant.

    Recognized names follow OGRFieldType: Integer, IntegerList, Real,
    RealList, String, StringList, Binary, Date, Time, DateTime, Integer64
    and Integer64List.  Raises Exception for None or an unknown name.
    """
    if name is None:
        raise Exception('parseFieldType(None)')

    # lowercase name -> ogr attribute name, resolved only once matched
    field_types = {
        'integer':       'OFTInteger',        # simple 32bit integer
        'integerlist':   'OFTIntegerList',    # list of 32bit integers
        'real':          'OFTReal',           # double precision floating point
        'reallist':      'OFTRealList',       # list of doubles
        'string':        'OFTString',         # string of ASCII chars
        'stringlist':    'OFTStringList',     # array of strings
        'binary':        'OFTBinary',         # raw binary data
        'date':          'OFTDate',           # date
        'time':          'OFTTime',           # time
        'datetime':      'OFTDateTime',       # date and time
        'integer64':     'OFTInteger64',      # single 64bit integer
        'integer64list': 'OFTInteger64List',  # list of 64bit integers
    }
    attr = field_types.get(name.lower())
    if attr is None:
        raise Exception(f'Unknown field type "{name}"')
    return getattr(ogr, attr)

# cf. ogr/ogr_core.h's enum OGRFieldSubType;
def parseSubFieldType(name):
    """Map a case-insensitive field subtype name to its ogr.OFST* constant.

    Recognized names: 'none', 'bool', 'int16', 'float32', 'json', 'uuid'
    (cf. OGRFieldSubType).  Raises Exception for None or an unknown name.
    """
    if name is None:
        raise Exception('parseSubFieldType(None)')

    # lowercase name -> ogr attribute name, resolved only once matched
    subtypes = {
        'none':    'OFSTNone',     # no subtype (the default)
        'bool':    'OFSTBoolean',  # boolean integer; OFTInteger/OFTIntegerList only
        'int16':   'OFSTInt16',    # signed 16-bit integer; OFTInteger/OFTIntegerList only
        'float32': 'OFSTFloat32',  # single precision float; OFTReal/OFTRealList only
        'json':    'OFSTJSON',     # JSON content; OFTString only
        'uuid':    'OFSTUUID',     # UUID string representation; OFTString only
    }
    attr = subtypes.get(name.lower())
    if attr is None:
        raise Exception(f'Unknown field subtype "{name}"')
    return getattr(ogr, attr)

# Parse timezone
TZ_RE = re.compile(r'(?:UTC\b)?([\+\-]?)([0-9][0-9]):?([0-9][0-9])', flags=re.IGNORECASE)
def parseTimeZone(tz):
    """Parse a timezone specification into an OGR TZFlag.

    Accepts, case-insensitively, 'none', 'local', 'UTC' or 'GMT', or an
    offset of the form [UTC][+|-]HH[:]MM with HH at most 14 and MM a
    multiple of 15.  Offsets are encoded as 100 (UTC) plus or minus the
    number of quarter hours.

    Raises Exception for None or an invalid specification.
    """
    if tz is None:
        raise Exception('parseTimeZone(None)')

    keyword = tz.lower()
    if keyword == 'none':
        return ogr.TZFLAG_UNKNOWN
    if keyword == 'local':
        return ogr.TZFLAG_LOCALTIME
    if keyword in ('utc', 'gmt'):
        return ogr.TZFLAG_UTC

    match = TZ_RE.fullmatch(tz)
    if match is None:
        raise Exception(f'Invalid timezone "{tz}"')
    sign, hours, minutes = match.group(1), int(match.group(2)), int(match.group(3))
    if hours > 14 or minutes >= 60 or minutes % 15 != 0:
        raise Exception(f'Invalid timezone "{tz}"')

    quarters = hours * 4 + minutes // 15
    return 100 - quarters if sign == '-' else 100 + quarters

# Pretty-print timezone flag, cf.
# ogr/ogrutils.cpp:OGRGetISO8601DateTime()
def formatTZFlag(tzFlag):
    """Format an OGR TZFlag for display (inverse of parseTimeZone()).

    Returns 'none', 'local', 'UTC' or 'mixed' for the special flags, and a
    signed offset '+HHMM' / '-HHMM' otherwise (100 is UTC, each step away
    from it a quarter hour).

    Raises Exception if tzFlag is None.
    """
    if tzFlag is None:
        raise Exception('formatTZFlag(None)')
    if tzFlag == ogr.TZFLAG_UNKNOWN:
        return 'none'
    if tzFlag == ogr.TZFLAG_LOCALTIME:
        return 'local'
    if tzFlag == ogr.TZFLAG_UTC:
        return 'UTC'
    # OGR_TZFLAG_MIXED_TZ (255) marks a field holding values with different
    # offsets; it is not itself an offset (it would otherwise print '+3845')
    if tzFlag == getattr(ogr, 'TZFLAG_MIXED_TZ', 255):
        return 'mixed'

    tzHour, tzMinute = divmod(abs(tzFlag - 100) * 15, 60)
    tzSign = '+' if tzFlag > 100 else '-'
    return f'{tzSign}{tzHour:02}{tzMinute:02}'

# Validate layer creation options and schema.  The schema is modified in
# place with the parsed result.
# (We need the driver of the output dataset to determine capability on
# constraints.)
def validateSchema(layers, drvo=None, lco_defaults=None):
    """Validate and normalize the layer creation schemas, in place.

    For each entry of `layers` (layer name -> definition) that has a
    non-empty 'create' dictionary, that dictionary is rewritten so that:
      * 'options' is a list of KEY=VALUE layer creation options, with the
        global lco_defaults (dataset:create-layer-options) first;
      * 'geometry-type' is the result of parseGeomType();
      * 'fields' is a list of dictionaries keyed by OGR field attribute
        names ('Name', 'Type', 'SubType', 'TZFlag', 'Width', ...).
    Layers without a creation schema are skipped with a warning.

    drvo is the output driver; field constraints (default, NOT NULL, unique)
    it does not support are dropped with a warning.  When drvo is None its
    capabilities can't be checked, and constraints are kept as configured.

    Raises Exception on an unnamed or duplicate field, an unknown field key,
    or a wrongly-typed width/precision/nullable/unique value.
    """
    # cache driver capabilities
    if drvo is None:
        drvoSupportsDefaultFields = True
        drvoSupportsNotNULLFields = True
        drvoSupportsUniqueFields  = True
    else:
        drvoSupportsDefaultFields = getMetadataItem(drvo, GDAL_DCAP_DEFAULT_FIELDS)
        drvoSupportsNotNULLFields = getMetadataItem(drvo, GDAL_DCAP_NOTNULL_FIELDS)
        drvoSupportsUniqueFields  = getMetadataItem(drvo, GDAL_DCAP_UNIQUE_FIELDS)

    for layername, layerdef in layers.items():
        create = layerdef.get('create', None)
        if create is None or len(create) < 1:
            logging.warning('Layer "%s" has no creation schema', layername)
            continue

        # prepend global layer creation options (dataset:create-layer-options)
        # and build the option=value list
        lco = create.get('options', None)
        if lco_defaults is not None or lco is not None:
            options = []
            if lco_defaults is not None:
                options += [ k + '=' + str(v) for k, v in lco_defaults.items() ]
            if lco is not None:
                options += [ k + '=' + str(v) for k, v in lco.items() ]
            create['options'] = options

        # parse geometry type
        create['geometry-type'] = parseGeomType(create.get('geometry-type', None))

        fields = create.get('fields', None)
        if fields is None:
            create['fields'] = []
            continue

        fields_set = set()
        for idx, fld_def in enumerate(fields):
            # keys are matched case-insensitively below, so 'name' is too
            fld_name = next((v for k, v in fld_def.items() if k.lower() == 'name'), None)
            if fld_name is None or fld_name == '':
                raise Exception(f'Field #{idx} has no name')
            if fld_name in fields_set:
                raise Exception(f'Duplicate field "{fld_name}"')
            fields_set.add(fld_name)

            fld_def2 = { 'Name': fld_name }
            for k, v in fld_def.items():
                k2 = k.lower()
                if k2 == 'name':
                    pass
                elif k2 == 'alternativename' or k2 == 'alias':
                    fld_def2['AlternativeName'] = v
                elif k2 == 'comment':
                    # (WARN support added in GDAL 3.7)
                    fld_def2['Comment'] = v

                elif k2 == 'type':
                    fld_def2['Type'] = parseFieldType(v)
                elif k2 == 'subtype':
                    fld_def2['SubType'] = parseSubFieldType(v)
                elif k2 == 'tz':
                    # (WARN support added in GDAL 3.8)
                    fld_def2['TZFlag'] = parseTimeZone(v)
                elif k2 == 'width' or k2 == 'precision':
                    # bool is a subclass of int but is no meaningful width
                    if not isinstance(v, int) or isinstance(v, bool):
                        raise Exception(f'Field "{fld_name}" has invalid {k} {v!r}, '
                            + 'expected an integer')
                    fld_def2['Width' if k2 == 'width' else 'Precision'] = v

                # constraints
                elif k2 == 'default':
                    if drvoSupportsDefaultFields:
                        fld_def2['Default'] = v
                    else:
                        logging.warning('%s driver lacks GDAL_DCAP_DEFAULT_FIELDS support',
                            drvo.ShortName)
                elif k2 == 'nullable':
                    if not isinstance(v, bool):
                        raise Exception(f'Field "{fld_name}" has invalid {k} {v!r}, '
                            + 'expected a boolean')
                    if drvoSupportsNotNULLFields:
                        fld_def2['Nullable'] = v
                    else:
                        logging.warning('%s driver lacks GDAL_DCAP_NOTNULL_FIELDS support',
                            drvo.ShortName)
                elif k2 == 'unique':
                    if not isinstance(v, bool):
                        raise Exception(f'Field "{fld_name}" has invalid {k} {v!r}, '
                            + 'expected a boolean')
                    if drvoSupportsUniqueFields:
                        fld_def2['Unique'] = v
                    else:
                        logging.warning('%s driver lacks GDAL_DCAP_UNIQUE_FIELDS support',
                            drvo.ShortName)
                else:
                    raise Exception(f'Field "{fld_name}" has unknown key "{k}"')

            fields[idx] = fld_def2

# Return the decoded Spatial Reference System
def getSRS(srs_str):
    """Decode an 'EPSG:<code>' string into an osr.SpatialReference.

    Returns None when srs_str is None.  Raises Exception for any other
    form of SRS definition.
    """
    if srs_str is None:
        return None

    srs = osr.SpatialReference()
    if not srs_str.startswith('EPSG:'):
        raise Exception(f'Unknown SRS {srs_str}')
    epsg_code = int(srs_str.removeprefix('EPSG:'))
    srs.ImportFromEPSG(epsg_code)

    logging.debug('Default SRS: "%s" (%s)', srs.ExportToProj4(), srs.GetName())
    return srs

# Convert extent [minX, minY, maxX, maxY] into a polygon and assign the
# given SRS.  Like apps/ogr2ogr_lib.cpp, we segmentize the polygon to
# make sure it is sufficiently densified when transforming to source
# layer SRS for spatial filtering.
def getExtent(extent, srs=None):
    """Turn an extent [minX, minY, maxX, maxY] into a densified polygon.

    The extent is read in traditional GIS (easting, northing) axis order
    and the returned polygon carries srs.  As in ogr2ogr, it is segmentized
    (10 km segments, converted to degrees or to the projection's linear
    unit) so that it stays accurate once reprojected for spatial filtering.

    Returns None when extent is None.  Raises Exception if extent is not a
    4-element list/tuple, or if no srs is given.
    """
    if extent is None:
        return None

    if not isinstance(extent, (list, tuple)) or len(extent) != 4:
        raise Exception(f'Invalid extent {extent}')
    if srs is None:
        raise Exception('Configured extent but no SRS')

    logging.debug('Configured extent in %s: %s',
        srs.GetName(), ', '.join(map(str, extent)))

    minX, minY, maxX, maxY = extent
    ring = ogr.Geometry(ogr.wkbLinearRing)
    for x, y in ((minX, minY), (maxX, minY), (maxX, maxY), (minX, maxY), (minX, minY)):
        ring.AddPoint_2D(x, y)

    polygon = ogr.Geometry(ogr.wkbPolygon)
    polygon.AddGeometry(ring)

    # the coordinates above are in easting/northing order; let OGR swap
    # them if srs uses another axis order
    gis_order_srs = srs.Clone()
    gis_order_srs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
    polygon.AssignSpatialReference(gis_order_srs)
    polygon.TransformTo(srs)

    segment_distance_metre = 10 * 1000
    max_length = None
    if srs.IsGeographic():
        max_length = segment_distance_metre / math.radians(srs.GetSemiMajor())
    elif srs.IsProjected():
        max_length = segment_distance_metre / srs.GetLinearUnits()
    if max_length is not None:
        polygon.Segmentize(max_length)

    return polygon

# Validate the output layer against the provided SRS and creation options
def validateOutputLayer(lyr, srs=None, options=None):
    """Check an existing output layer against the expected SRS and schema.

    srs, if given, must be equivalent to the layer's SRS (data axis mapping
    ignored).  options, if given, is a creation schema as normalized by
    validateSchema(): its 'geometry-type' and every entry of its 'fields'
    are compared with the layer definition.

    Mismatches are logged as warnings.  Returns False if the SRS, the
    geometry type or any configured field attribute differs, True otherwise
    (an unexpected geometry field count is only warned about).
    """
    ok = True
    layername = lyr.GetName()

    # ensure the output SRS is equivalent
    if srs is not None:
        srs2 = lyr.GetSpatialRef()
        if srs2 is None:
            logging.warning('Output layer "%s" has no SRS,\nexpected %s',
                layername, srs.ExportToPrettyWkt())
            ok = False
        else:
            # cf. apps/ogr2ogr_lib.cpp
            srs_options = [
                'IGNORE_DATA_AXIS_TO_SRS_AXIS_MAPPING=YES',
                'CRITERION=EQUIVALENT'
            ]
            if not srs.IsSame(srs2, srs_options):
                logging.warning('Output layer "%s" has SRS %s,\nexpected %s',
                    layername,
                    srs2.ExportToPrettyWkt(),
                    srs.ExportToPrettyWkt())
                ok = False

    if options is None:
        return ok

    layerDefn = lyr.GetLayerDefn()
    geom_type1 = lyr.GetGeomType()
    geom_type2 = options['geometry-type']

    # a geometry-less (wkbNone) layer has no geometry field at all
    n_expected = 0 if geom_type2 == ogr.wkbNone else 1
    n = layerDefn.GetGeomFieldCount()
    if n != n_expected:
        logging.warning('Output layer "%s" has %d != %d geometry fields',
            layername, n, n_expected)

    if geom_type1 != geom_type2:
        logging.warning('Output layer "%s" has geometry type #%d (%s), expected #%d (%s)',
            layername,
            geom_type1, ogr.GeometryTypeToName(geom_type1),
            geom_type2, ogr.GeometryTypeToName(geom_type2))
        ok = False

    fields = options.get('fields', None)
    if fields is None:
        return ok

    for fld in fields:
        fldName = fld['Name']

        idx = layerDefn.GetFieldIndex(fldName)
        if idx < 0:
            logging.warning('Output layer "%s" has no field named "%s"',
                layername, fldName)
            ok = False
            continue
        defn = layerDefn.GetFieldDefn(idx)

        if 'AlternativeName' in fld:
            v1 = defn.GetAlternativeName()
            v2 = fld['AlternativeName']
            if v1 != v2:
                logging.warning('Field "%s" has AlternativeName="%s", expected "%s"',
                    fldName, v1, v2)
                ok = False

        if 'Comment' in fld:
            v1 = defn.GetComment()
            v2 = fld['Comment']
            if v1 != v2:
                logging.warning('Field "%s" has Comment="%s", expected "%s"',
                    fldName, v1, v2)
                ok = False

        if 'Type' in fld:
            v1 = defn.GetType()
            v2 = fld['Type']
            if v1 != v2:
                logging.warning('Field "%s" has Type=%d (%s), expected %d (%s)',
                    fldName,
                    v1, ogr.GetFieldTypeName(v1),
                    v2, ogr.GetFieldTypeName(v2))
                ok = False

        if 'SubType' in fld:
            v1 = defn.GetSubType()
            v2 = fld['SubType']
            if v1 != v2:
                logging.warning('Field "%s" has SubType=%d (%s), expected %d (%s)',
                    fldName,
                    v1, ogr.GetFieldSubTypeName(v1),
                    v2, ogr.GetFieldSubTypeName(v2))
                ok = False

        if 'TZFlag' in fld:
            v1 = defn.GetTZFlag()
            v2 = fld['TZFlag']
            if v1 != v2:
                logging.warning('Field "%s" has TZFlag=%d (%s), expected %d (%s)',
                    fldName, v1, formatTZFlag(v1), v2, formatTZFlag(v2))
                ok = False

        if 'Precision' in fld:
            v1 = defn.GetPrecision()
            v2 = fld['Precision']
            if v1 != v2:
                logging.warning('Field "%s" has Precision=%d, expected %d',
                    fldName, v1, v2)
                ok = False

        if 'Width' in fld:
            v1 = defn.GetWidth()
            v2 = fld['Width']
            if v1 != v2:
                logging.warning('Field "%s" has Width=%d, expected %d',
                    fldName, v1, v2)
                ok = False

        if 'Default' in fld:
            v1 = defn.GetDefault()
            v2 = fld['Default']
            if v1 != v2:
                logging.warning('Field "%s" has Default="%s", expected "%s"',
                    fldName, v1, v2)
                ok = False

        if 'Nullable' in fld:
            v1 = bool(defn.IsNullable())
            v2 = fld['Nullable']
            if v1 != v2:
                logging.warning('Field "%s" has Nullable=%s, expected %s',
                    fldName, v1, v2)
                ok = False

        if 'Unique' in fld:
            v1 = bool(defn.IsUnique())
            v2 = fld['Unique']
            if v1 != v2:
                logging.warning('Field "%s" has Unique=%s, expected %s',
                    fldName, v1, v2)
                ok = False

    return ok

# Create output layer
def createOutputLayer(ds, layername, srs=None, options=None):
    """Create layer `layername` in dataset ds and set up its fields.

    options is a creation schema as normalized by validateSchema(): its
    'geometry-type' and 'fields' keys are required, and 'options' (a list of
    KEY=VALUE layer creation options) is optional.  For PostgreSQL layers
    with a geometry, a DIM= option matching the geometry type is added
    unless one is already configured.  srs, if not None, is assigned to the
    new layer.

    Returns the new layer, synced to disk.  Raises Exception if the layer
    or one of its fields can't be created.
    """
    logging.info('Creating new destination layer "%s"', layername)
    geom_type = options['geometry-type']
    lco = options.get('options', None)

    drv = ds.GetDriver()
    if geom_type != ogr.wkbNone and drv.ShortName == 'PostgreSQL':
        # “Important to set to 2 for 2D layers as it has constraints on the geometry
        # dimension during loading.”
        # — https://gdal.org/drivers/vector/pg.html#layer-creation-options
        if ogr.GT_HasM(geom_type):
            if ogr.GT_HasZ(geom_type):
                dim = 'XYZM'
            else:
                dim = 'XYM'
        elif ogr.GT_HasZ(geom_type):
            dim = '3'
        else:
            dim = '2'
        if lco is None:
            lco = []
        # GDAL option lookups (CSLFetchNameValue) return the first match, so
        # a prepended dim= would silently override an explicitly configured
        # DIM: only add ours when none is set
        if not any(opt.upper().startswith('DIM=') for opt in lco):
            lco = ['dim=' + dim] + lco # prepend dim=

    kwargs = { 'geom_type': geom_type }
    if srs is not None:
        kwargs['srs'] = srs
    if lco is not None:
        kwargs['options'] = lco
    logging.debug('CreateLayer(%s, geom_type="%s"%s%s)', layername,
        ogr.GeometryTypeToName(geom_type),
        ', srs="' + kwargs['srs'].GetName() + '"' if 'srs' in kwargs else '',
        ', options=' + str(kwargs['options']) if 'options' in kwargs else '')
    lyr = ds.CreateLayer(layername, **kwargs)
    if lyr is None:
        raise Exception(f'Could not create destination layer "{layername}"')

    fields = options['fields']
    if len(fields) > 0 and not lyr.TestCapability(ogr.OLCCreateField):
        raise Exception(f'Destination layer "{layername}" lacks field creation capability')

    # set up output schema
    for fld in fields:
        fldName = fld['Name']
        defn = ogr.FieldDefn()
        defn.SetName(fldName)

        if 'AlternativeName' in fld:
            v = fld['AlternativeName']
            logging.debug('Set AlternativeName=%s on output field "%s"', str(v), fldName)
            defn.SetAlternativeName(v)

        if 'Comment' in fld:
            v = fld['Comment']
            logging.debug('Set Comment=%s on output field "%s"', str(v), fldName)
            defn.SetComment(v)

        if 'Type' in fld:
            v = fld['Type']
            logging.debug('Set Type=%d (%s) on output field "%s"',
                v, ogr.GetFieldTypeName(v), fldName)
            defn.SetType(v)

        if 'SubType' in fld:
            v = fld['SubType']
            logging.debug('Set SubType=%d (%s) on output field "%s"',
                v, ogr.GetFieldSubTypeName(v), fldName)
            defn.SetSubType(v)

        if 'TZFlag' in fld:
            v = fld['TZFlag']
            logging.debug('Set TZFlag=%d (%s) on output field "%s"',
                v, formatTZFlag(v), fldName)
            defn.SetTZFlag(v)

        if 'Precision' in fld:
            v = fld['Precision']
            logging.debug('Set Precision=%d on output field "%s"', v, fldName)
            defn.SetPrecision(v)

        if 'Width' in fld:
            v = fld['Width']
            logging.debug('Set Width=%d on output field "%s"', v, fldName)
            defn.SetWidth(v)

        if 'Default' in fld:
            v = fld['Default']
            logging.debug('Set Default=%s on output field "%s"', v, fldName)
            defn.SetDefault(v)

        if 'Nullable' in fld:
            v = fld['Nullable']
            logging.debug('Set Nullable=%s on output field "%s"', v, fldName)
            defn.SetNullable(v)

        if 'Unique' in fld:
            v = fld['Unique']
            logging.debug('Set Unique=%s on output field "%s"', v, fldName)
            defn.SetUnique(v)

        if lyr.CreateField(defn, approx_ok=False) != GDAL_CE_None:
            raise Exception(f'Could not create field "{fldName}"')
        logging.debug('Added field "%s" to output layer "%s"', fldName, layername)

    # flush before calling StartTransaction() so we're not trying to
    # rollback changes on a non-existing table
    lyr.SyncToDisk()
    return lyr

# Setup output field mapping, modifying the sources dictionary in place.
def setOutputFieldMap(defn, sources):
    """Resolve each source's field mapping to output field indices.

    defn is the output layer definition.  For every source, the optional
    source['import']['fields'] maps source field names to output field
    names (a list is shorthand for the identity mapping).  It is replaced in
    `sources` by a new dictionary mapping source field names to output field
    indices, or by {} when no mapping is configured.

    Returns a dictionary mapping each output field name to its index.
    Raises Exception if a mapping targets a field the output layer lacks.
    """
    fieldMap = { defn.GetFieldDefn(i).GetName(): i for i in range(defn.GetFieldCount()) }

    for source in sources:
        src = source['source']['path']
        fieldMap2 = source['import'].get('fields', None)
        if fieldMap2 is None:
            fieldMap2 = {}
        elif isinstance(fieldMap2, list):
            # convert list to identity dictionary
            fieldMap2 = { fld: fld for fld in fieldMap2 }

        # Build a new dictionary instead of rewriting the configured one in
        # place: the same mapping object may be shared by several sources,
        # and a second pass would then look up indices instead of names.
        resolved = {}
        for ifld, ofld in fieldMap2.items():
            i = fieldMap.get(ofld, None)
            if i is None:
                raise Exception(f'Output layer has no field named "{ofld}" (source "{src}")')
            resolved[ifld] = i

        source['import']['fields'] = resolved

    return fieldMap

# Escape the given identifier, cf.
# swig/python/gdal-utils/osgeo_utils/samples/validate_gpkg.py:_esc_id()
def escapeIdentifier(identifier):
    """Quote identifier as an SQL:1999 delimited identifier.

    Embedded double quotes are doubled.  Raises Exception if the identifier
    contains a NUL character.
    """
    if identifier.find('\x00') >= 0:
        raise Exception(f'Invalid identifier "{identifier}"')
    escaped = identifier.replace('"', '""')
    return '"' + escaped + '"'

# Clear the given layer (wipe all its features)
def clearLayer(ds, lyr):
    """Delete every feature of lyr, a layer of dataset ds.

    Nothing is done when the layer cheaply reports a feature count of 0.
    Otherwise the table is emptied via ds.ExecuteSQL(): TRUNCATE on
    PostgreSQL, DELETE FROM on other drivers.
    """
    # only ask for the count when that is cheap; -1 stands for "unknown"
    if lyr.TestCapability(ogr.OLCFastFeatureCount):
        count = lyr.GetFeatureCount(force=0)
    else:
        count = -1
    if count == 0:
        # nothing to clear, we're good
        return
    table = escapeIdentifier(lyr.GetName())

    # XXX GDAL <3.9 doesn't have lyr.GetDataset() so we pass the DS along with the layer
    if ds.GetDriver().ShortName == 'PostgreSQL':
        # https://www.postgresql.org/docs/15/sql-truncate.html
        op, query = 'Truncating', 'TRUNCATE TABLE {table} CONTINUE IDENTITY RESTRICT'
    else:
        op, query = 'Clearing', 'DELETE FROM {table}'
    former = str(count) if count >= 0 else 'unknown'
    logging.info('%s table %s (former feature count: %s)', op, table, former)
    ds.ExecuteSQL(query.format(table=table))

# Extract an archive file into the given destination directory.
def extractArchive(path, destdir, fmt=None, patterns=None, exact_matches=None):
    """Extract archive `path` (a pathlib.Path) into directory destdir.

    fmt defaults to the file's suffix and is matched case-insensitively;
    only 'zip' is supported.  patterns and exact_matches select the members
    to extract, cf. listArchiveMembers().

    Raises Exception if the format can't be inferred or is unsupported.
    """
    if fmt is None:
        suffix = path.suffix
        if not suffix or not suffix.startswith('.'):
            raise Exception(f'Could not infer archive format from "{path}"')
        fmt = suffix.removeprefix('.')

    fmt = fmt.lower()
    logging.debug('Unpacking %s archive %s into %s', fmt, path, destdir)

    if fmt != 'zip':
        raise Exception(f'Unknown archive format "{fmt}"')

    from zipfile import ZipFile
    logging.debug('Opening %s as ZipFile', path)
    with ZipFile(path, mode='r') as archive:
        members = listArchiveMembers(archive.namelist(),
            patterns=patterns, exact_matches=exact_matches)
        archive.extractall(path=destdir, members=members)

# List archive members matching the given patterns and/or exact matches
def listArchiveMembers(namelist, patterns, exact_matches=None):
    """Select the archive members to extract, preserving namelist order.

    A member is selected when it is one of exact_matches, or otherwise when
    it matches one of the (case-sensitive, fnmatch-style) patterns.  When
    neither patterns nor exact_matches is given, namelist is returned as is.
    """
    if patterns is None and exact_matches is None:
        # if neither patterns nor exact_matches are given we'll extract the entire archive
        return namelist
    patterns = [] if patterns is None else patterns
    exact_matches = [] if exact_matches is None else exact_matches

    selected = []
    for name in namelist:
        # exact matches are tried first
        if name in exact_matches:
            logging.debug('Listed archive member %s (exact match)', name)
            selected.append(name)
            continue
        # then the patterns, in the supplied order; the first match wins
        pat = next((p for p in patterns if fnmatchcase(name, p)), None)
        if pat is None:
            logging.debug('Ignoring archive member %s', name)
        else:
            logging.debug('Listed archive member %s (matching pattern "%s")', name, pat)
            selected.append(name)
    return selected

# Import a source layer
def importSource(lyr, path=None, unar=None, args=None, cachedir=None, extent=None):
    """Import a source layer into the destination layer lyr.

    Without unar, `path` (relative to cachedir, if not None) is handed to
    importSource2() as is.  With unar (archive settings: optional 'format'
    and 'patterns'), `path` -- joined to cachedir if given -- is an archive.
    It is extracted into a temporary directory, removed once the import is
    done, and args['path'] names the dataset to import from it; an absolute
    args['path'] is treated as relative to the archive root.

    args holds the source definition passed on to importSource2() (an
    empty dictionary when None).  Returns what importSource2() returns.
    """
    if args is None:
        # not a `{}` default: a mutable default would be shared across calls
        args = {}

    if unar is None:
        return importSource2(lyr, str(path), args=args,
                    basedir=cachedir, extent=extent)

    if cachedir is not None:
        path = cachedir.joinpath(path)

    ds_srcpath = Path(args['path'])
    if ds_srcpath.is_absolute():
        # treat absolute paths as relative to the archive root
        logging.warning('%s is absolute, removing leading anchor', ds_srcpath)
        ds_srcpath = ds_srcpath.relative_to(ds_srcpath.anchor)
    ds_srcpath = str(ds_srcpath)

    with tempfile.TemporaryDirectory() as tmpdir:
        logging.debug('Created temporary directory %s', tmpdir)
        extractArchive(path, tmpdir,
            fmt=unar.get('format', None),
            patterns=unar.get('patterns', None),
            exact_matches=[ds_srcpath])
        return importSource2(lyr, ds_srcpath, args=args,
                    basedir=Path(tmpdir), extent=extent)

# Import a source layer (already extracted)
# This is more or less like ogr2ogr/GDALVectorTranslate() but we roll
# our own (slower) version because GDALVectorTranslate() insists on
# calling StartTransaction() https://github.com/OSGeo/gdal/issues/3403
# while we want a single transaction for the entire destination layer,
# including truncation, source imports, and metadata changes.
def importSource2(lyr_dst, path, args={}, basedir=None, extent=None):
    kwargs, _ = setOpenExArgs(args, flags=GDAL_OF_READONLY|GDAL_OF_VERBOSE_ERROR)
    path2 = path if basedir is None else str(basedir.joinpath(path))

    logging.debug('OpenEx(%s, %s)', path2, str(kwargs))
    ds = gdal.OpenEx(path2, **kwargs)
    if ds is None:
        raise Exception(f'Could not open {path2}')

    layername = args.get('layername', None)
    if layername is None:
        idx = 0
        lyr = ds.GetLayerByIndex(idx)
        msg = '#' + str(idx)
        if lyr is not None:
            layername = lyr.GetName()
            msg += ' ("' + layername + '")'
    else:
        lyr = ds.GetLayerByName(layername)
        msg = '"' + layername + '"'
    if lyr is None:
        raise Exception(f'Could not get requested layer {msg} from {path2}')

    logging.info('Importing layer %s from "%s"', msg, path)
    canIgnoreFields = lyr.TestCapability(ogr.OLCIgnoreFields)

    srs = lyr.GetSpatialRef()
    if srs is None:
        raise Exception('Source layer has no SRS')

    srs_dst = lyr_dst.GetSpatialRef()
    if srs_dst.IsSame(srs):
        logging.debug('Both source and destination have the same SRS (%s), skipping coordinate transformation',
            srs_dst.GetName())
        ct = None
    else:
        # TODO Use GetSupportedSRSList() and SetActiveSRS() with GDAL ≥3.7.0
        # when possible, see apps/ogr2ogr_lib.cpp
        logging.debug('Creating transforming from source SRS (%s) to destination SRS (%s)',
            srs.GetName(), srs_dst.GetName())
        ct = osr.CoordinateTransformation(srs, srs_dst)
        if ct is None:
            raise Exception(f'Could not create transformation from source SRS ({srs.GetName()}) '
                + f'to destination SRS ({srs_dst.GetName()})')

    defn = lyr.GetLayerDefn()
    n = defn.GetGeomFieldCount()
    if n != 1: # TODO Add support for multiple geometry fields
        logging.warning('Source layer "%s" has %d != 1 geometry fields', layername, n)

    n = defn.GetFieldCount()
    fieldMap = [-1] * n
    fields = args['fields']
    fieldSet = set()

    for i in range(n):
        fld = defn.GetFieldDefn(i)
        fldName = fld.GetName()
        fieldMap[i] = v = fields.get(fldName, -1)
        fieldSet.add(fldName)

        if v < 0 and canIgnoreFields:
            # call SetIgnored() on unwanted source fields
            logging.debug('Set Ignored=True on output field "%s"', fldName)
            fld.SetIgnored(True)
    defn = None

    logging.debug('Field map: %s', str(fieldMap))
    for fld in fields:
        if not fld in fieldSet:
            logging.warning('Source layer "%s" has no field named "%s", ignoring', layername, fld)

    count0 = -1
    if lyr.TestCapability(ogr.OLCFastFeatureCount):
        count0 = lyr.GetFeatureCount(force=0)

    count1 = -1
    if args.get('spatial-filter', True) and extent is not None:
        if extent.GetSpatialReference().IsSame(srs):
            extent2 = extent
        else:
            extent2 = extent.Clone()
            if extent2.TransformTo(srs) != ogr.OGRERR_NONE:
                raise Exception(f'Could not transform extent {extent.ExportToWkt()} to {srs.GetName()}')

        #logging.debug('Applying extent: %s', extent2.ExportToWkt())
        lyr.SetSpatialFilter(extent2)

        if lyr.TestCapability(ogr.OLCFastFeatureCount):
            count1 = lyr.GetFeatureCount(force=0)

    if count0 >= 0:
        if count1 >= 0:
            logging.info('Source layer "%s" has %d features (of which %d within extent)',
                layername, count0, count1)
        else:
            logging.info('Source layer "%s" has %d features', layername, count0)

    defn_dst = lyr_dst.GetLayerDefn()
    eGType_dst = defn_dst.GetGeomType()
    eGType_dst_HasZ = ogr.GT_HasZ(eGType_dst)
    eGType_dst_HasM = ogr.GT_HasM(eGType_dst)
    dGeomIsUnknown = ogr.GT_Flatten(eGType_dst) == ogr.wkbUnknown

    n = 0
    mismatch = {}
    feature = lyr.GetNextFeature()
    while feature is not None:
        feature2 = ogr.Feature(defn_dst)
        feature2.SetFromWithMap(feature, False, fieldMap)

        geom = feature2.GetGeometryRef()
        if ct is not None and geom.Transform(ct) != ogr.OGRERR_NONE:
            raise Exception('Could not apply coordinate transformation')

        eGType = geom.GetGeometryType()
        if eGType != eGType_dst and not dGeomIsUnknown:
            # Promote to multi, cf. apps/ogr2ogr_lib.cpp:ConvertType()
            eGType2 = eGType
            if eGType == ogr.wkbTriangle or eGType == ogr.wkbTIN or eGType == ogr.wkbPolyhedralSurface:
                eGType2 = ogr.wkbMultiPolygon
            elif not ogr.GT_IsSubClassOf(eGType, ogr.wkbGeometryCollection):
                eGType2 = ogr.GT_GetCollection(eGType)

            eGType2 = ogr.GT_SetModifier(eGType2, eGType_dst_HasZ, eGType_dst_HasM)
            if eGType2 == eGType_dst:
                mismatch[eGType] = mismatch.get(eGType, 0) + 1
                geom = ogr.ForceTo(geom, eGType_dst)
                # TODO call MakeValid()?
            else:
                raise Exception(f'Conversion from {ogr.GeometryTypeToName(eGType)} '
                    f'to {ogr.GeometryTypeToName(eGType_dst)} not implemented')
            feature2.SetGeometryDirectly(geom)

        if lyr_dst.CreateFeature(feature2) != ogr.OGRERR_NONE:
            raise Exception(f'Could not transfer source feature #{feature.GetFID()}')

        n += 1
        feature = lyr.GetNextFeature()

    lyr = None
    logging.info('Imported %d features from source layer "%s"', n, layername)

    if len(mismatch) > 0:
        mismatches = [  str(n) + '× ' + ogr.GeometryTypeToName(t)
            for t,n in sorted(mismatch.items(), key=lambda x: x[1]) ]
        logging.info('Forced conversion to %s: %s',
            ogr.GeometryTypeToName(eGType_dst), ', '.join(mismatches))


if __name__ == '__main__':
    common.init_logger(app=os.path.basename(__file__), level=logging.INFO)

    parser = argparse.ArgumentParser(description='Extract and import GIS layers.')
    parser.add_argument('--cachedir', default=None,
        help=f'cache directory (default: {os.curdir})')
    parser.add_argument('--debug', action='count', default=0,
        help=argparse.SUPPRESS)
    parser.add_argument('groupname', nargs='*', help='group layer name(s) to process')
    args = parser.parse_args()

    if args.debug > 0:
        logging.getLogger().setLevel(logging.DEBUG)
    if args.debug > 1:
        gdal.ConfigurePythonLogging(enable_debug=True)

    # pass None rather than an empty list when no group name was given
    common.load_config(groupnames=args.groupname if len(args.groupname) > 0 else None)

    # validate configuration
    if 'dataset' not in common.config:
        raise Exception('Configuration does not specify output dataset')

    layers = common.config.get('layers', {})
    for layername, layerdefs in layers.items():
        # layers with missing/empty sources are skipped (with a warning) below
        for idx, layerdef in enumerate(layerdefs.get('sources', None) or []):
            importdef = layerdef.get('import', None)
            if importdef is None:
                raise Exception(f'Output layer "{layername}" source #{idx} has no import definition')

            sourcedef = layerdef.get('source', None)
            unar = None if sourcedef is None else sourcedef.get('unar', None)
            src = None if sourcedef is None else sourcedef['cache'].get('path', None)

            ds_srcpath = importdef.get('path', None)
            if src is None and unar is None and ds_srcpath is not None:
                # fall back to import:path if there is no unarchiving recipe
                src = Path(ds_srcpath)
            if unar is not None and ds_srcpath is None:
                raise Exception(f'Output layer "{layername}" source #{idx} has no import source path')
            if src is None:
                raise Exception(f'Output layer "{layername}" source #{idx} has no source path')
            # normalized source definition, expanded as keyword arguments to importSource()
            layerdef['source'] = { 'path': src, 'unar': unar }

    # set global GDAL/OGR configuration options
    for pszKey, pszValue in common.config.get('GDALconfig', {}).items():
        logging.debug('gdal.SetConfigOption(%s, %s)', pszKey, pszValue)
        gdal.SetConfigOption(pszKey, pszValue)

    # open output dataset (possibly create it first)
    dso = openOutputDS(common.config['dataset'])

    validateSchema(layers,
        drvo=dso.GetDriver(),
        lco_defaults=common.config['dataset'].get('create-layer-options', None))

    # get configured Spatial Reference System and extent
    srs = getSRS(common.config.get('SRS', None))
    extent = getExtent(common.config.get('extent', None), srs=srs)

    cachedir = Path(args.cachedir) if args.cachedir is not None else None
    rv = 0
    for layername, layerdef in layers.items():
        sources = layerdef.get('sources', None)
        if sources is None or len(sources) < 1:
            logging.warning('Output layer "%s" has no definition, skipping', layername)
            continue

        logging.info('Processing output layer "%s"', layername)
        transaction = False
        try:
            # get output layer
            outLayerIsNotEmpty = True
            lco = layerdef.get('create', None)
            lyr = dso.GetLayerByName(layername)
            if lyr is not None:
                # TODO dso.DeleteLayer(layername) if --overwrite and dso.TestCapability(ogr.ODsCDeleteLayer)
                # (Sets OVERWRITE=YES for PostgreSQL and GPKG.)
                validateOutputLayer(lyr, srs=srs, options=lco)
                # TODO bail out if all source files are older than lyr's last_change
            elif not dso.TestCapability(ogr.ODsCCreateLayer):
                raise Exception(f'Output driver {dso.GetDriver().ShortName} does not support layer creation')
            elif lco is None or len(lco) < 1:
                raise Exception(f'Missing schema for new output layer "{layername}"')
            else:
                lyr = createOutputLayer(dso, layername, srs=srs, options=lco)
                outLayerIsNotEmpty = False

            if not lyr.TestCapability(ogr.OLCSequentialWrite):
                raise Exception(f'Output layer "{layername}" has no working CreateFeature() method')

            # setup output field mapping in the sources dictionary
            setOutputFieldMap(lyr.GetLayerDefn(), sources)

            # start transaction if possible
            if lyr.TestCapability(ogr.OLCTransactions):
                logging.debug('Starting transaction')
                transaction = lyr.StartTransaction() == ogr.OGRERR_NONE
                if not transaction:
                    logging.warning('Could not start transaction, updating output layer "%s" without one',
                        layername)
            else:
                logging.warning('Unsafe update, output layer "%s" does not support transactions', layername)

            if outLayerIsNotEmpty:
                # clear non-empty output layer
                clearLayer(dso, lyr)

            description = layerdef.get('description', None)
            if description is not None and lyr.SetMetadataItem('DESCRIPTION', description) != GDAL_CE_None:
                logging.warning('Could not set description metadata')

            for source in sources:
                # import source layers
                importSource(lyr, **source['source'], args=source['import'],
                    cachedir=cachedir, extent=extent)

            if transaction:
                # commit transaction; cleared first so that a failing commit
                # is not followed by a rollback attempt
                logging.debug('Committing transaction')
                transaction = False
                if lyr.CommitTransaction() != ogr.OGRERR_NONE:
                    logging.error('Could not commit transaction')
                    rv = 1

        except Exception:
            if transaction:
                logging.error('Exception occured in transaction, rolling back')
                try:
                    if lyr.RollbackTransaction() != ogr.OGRERR_NONE:
                        logging.error('Could not rollback transaction')
                except RuntimeError:
                    logging.exception('Could not rollback transaction')
            logging.exception('Could not import layer "%s"', layername)
            rv = 1

        finally:
            # close output layer
            lyr = None

    dso = None
    srs = None
    extent = None
    # same as sys.exit(rv); the exit() builtin is installed by the site
    # module for interactive use and is not guaranteed to exist
    raise SystemExit(rv)