#!/usr/bin/python3

#----------------------------------------------------------------------
# Backend utilities for the Klimatanalys Norr project (download common layers)
# Copyright © 2024 Guilhem Moulin <info@guilhem.se>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#----------------------------------------------------------------------

from os import O_RDONLY, O_WRONLY, O_CREAT, O_TRUNC, O_CLOEXEC, O_PATH, O_DIRECTORY, O_TMPFILE
import os, sys
from fcntl import flock, LOCK_EX
import logging
from time import time, monotonic as time_monotonic
import argparse
import itertools
from importlib import import_module
from pathlib import Path
from email.utils import parsedate_to_datetime, formatdate
from hashlib import sha256
import requests

import common

def download_trystream(url, **kwargs):
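    """GET url, with up to 10 attempts on timeouts and connection errors.

    Keyword arguments are passed through to requests.get(); the optional
    'session' keyword selects a requests.Session to use instead of the
    module-level requests API.  The response is streamed, and an HTTP
    error status raises requests.HTTPError.
    """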
    max_tries = 10
    f = kwargs.pop('session', requests)
    for i in itertools.count(1):
        try:
            r = f.get(url, **kwargs, stream=True)
        except (requests.Timeout, requests.ConnectionError):
            if i < max_tries:
                logging.error('Connection error or timeout (attempt %d/%d), retrying', i, max_tries)
                continue
            raise
        else:
            r.raise_for_status()
            return r
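
# A hypothetical usage sketch for download_trystream() (URL illustrative only):
#     with requests.Session() as s:
#         r = download_trystream('https://example.org/data.zip', session=s, timeout=30)
#         print(r.headers.get('Content-Length'))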

def download(dl, dest, dir_fd=None, headers={}, session=requests, progress=None):
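    """Download dl['url'] to dest (a path relative to dir_fd), atomically.

    The payload is streamed to an anonymous O_TMPFILE and only linked in
    and renamed over dest once fully received, so dest is never left
    partially written.  Payloads larger than dl['max-size'] (64MiB by
    default) are rejected, and a 304 Not Modified response (cf. the
    caller-supplied If-Modified-Since header) leaves dest untouched.
    """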
    url = None if dl is None else dl.get('url', None)
    if url is None:
        logging.error('%s has no source URL, ignoring', dest)
        return
    max_size = dl.get('max-size', 2**26) # 64MiB
    logging.info('Downloading %s…', url)
    destPath = Path(dest)
    dest_tmp = str(destPath.with_stem(f'.{destPath.stem}.new'))
    try:
        # delete any leftover
        os.unlink(dest_tmp, dir_fd=dir_fd)
    except FileNotFoundError:
        pass

    start = time_monotonic()
    r = download_trystream(url, headers=headers, session=session, timeout=30)
    if r.status_code == requests.codes.not_modified:
        # XXX shouldn't we call os.utime(dest) to bump its ctime here?
        # otherwise we'll make several queries and get multiple 304
        # replies if the file is used by multiple layers
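        # a minimal sketch of that idea (an assumption, not tested; it would
        # make the max-age check in the caller treat the file as fresh again):
        #     now = time()
        #     os.utime(dest, times=(now, now), dir_fd=dir_fd)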
        logging.info('%s: %d Not Modified', dest, r.status_code)
        return

    body_size = r.headers.get('Content-Length', None)
    last_modified = r.headers.get('Last-Modified', None)
    if last_modified is not None:
        try:
            last_modified = parsedate_to_datetime(last_modified)
            last_modified = last_modified.timestamp()
        except ValueError:
            logging.exception('Could not parse Last-Modified value')
            last_modified = None

    size = 0
    pbar = None

    # XXX we can't use TemporaryFile as it uses O_EXCL, cf.
    # https://discuss.python.org/t/temporaryfile-contextmanager-that-allows-creating-a-directory-entry-on-success/19094/2
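    # O_TMPFILE creates an unnamed file in the destination directory: it only
    # becomes visible once link()ed in below, so an interrupted download never
    # leaves a partially-written file behind.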
    destdir = os.path.dirname(dest) or os.curdir # dest may have no directory part
    fd = os.open(destdir, O_WRONLY|O_CLOEXEC|O_TMPFILE, mode=0o644, dir_fd=dir_fd)
    try:
        if progress is not None:
            pbar = progress(
                total=int(body_size) if body_size is not None else float('inf'),
                leave=False,
                unit_scale=True,
                unit_divisor=1024,
                unit='B'
            )
        with os.fdopen(fd, mode='wb', closefd=False) as fp:
            for chunk in r.iter_content(chunk_size=2**16):
                chunk_size = len(chunk)
                if pbar is not None:
                    pbar.update(chunk_size)
                size += chunk_size
                if max_size is not None and size > max_size:
                    raise Exception(f'Payload exceeds max-size ({max_size})')
                fp.write(chunk)
        r = None

        if last_modified is not None:
            os.utime(fd, times=(last_modified, last_modified), follow_symlinks=True)

        # XXX unfortunately there is no way for linkat() to clobber the destination,
        # so we go through a temporary file; it's racy, but thanks to O_TMPFILE the
        # race is shorter than if we were dumping chunks into a named file
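        # follow_symlinks=True is required here: link() must dereference the
        # /proc/self/fd/* magic symlink to reach the anonymous inode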
        os.link(f'/proc/self/fd/{fd}', dest_tmp, dst_dir_fd=dir_fd, follow_symlinks=True)
    finally:
        os.close(fd)
        if pbar is not None:
            pbar.close()

    try:
        # atomic rename (ensures output is never partially written)
        os.rename(dest_tmp, dest, src_dir_fd=dir_fd, dst_dir_fd=dir_fd)
    except (OSError, ValueError) as e:
        try:
            os.unlink(dest_tmp, dir_fd=dir_fd)
        finally:
            raise e

    elapsed = time_monotonic() - start
    logging.info("%s: Downloaded %s in %s (%s/s)", dest, common.format_bytes(size),
        common.format_time(elapsed), common.format_bytes(int(size/elapsed)))
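
# A hypothetical stand-alone use of download() (directory and URL are
# illustrative only):
#     fd = os.open('/srv/cache', O_RDONLY|O_CLOEXEC|O_PATH|O_DIRECTORY)
#     download({'url': 'https://example.org/data.zip'}, 'sub/data.zip', dir_fd=fd)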

if __name__ == '__main__':
    common.init_logger(app=os.path.basename(__file__), level=logging.INFO)

    parser = argparse.ArgumentParser(description='Download or update GIS layers.')
    parser.add_argument('--cachedir', default=os.curdir,
        help=f'destination directory for downloaded files (default: {os.curdir})')
    parser.add_argument('--lockdir', default=None,
        help='optional directory for lock files')
    parser.add_argument('--quiet', action='store_true',
        help='skip progress bars even when stderr is a TTY')
    parser.add_argument('--debug', action='count', default=0,
        help=argparse.SUPPRESS)
    parser.add_argument('--exit-code', default=True, action=argparse.BooleanOptionalAction,
        help='whether to exit with status 1 in case of download failures')
    parser.add_argument('groupname', nargs='*', help='group layer name(s) to process')
    args = parser.parse_args()

    if args.debug > 0:
        logging.getLogger().setLevel(logging.DEBUG)
    if args.debug > 1:
        from http.client import HTTPConnection
        HTTPConnection.debuglevel = 1
        requests_log = logging.getLogger("urllib3")
        requests_log.setLevel(logging.DEBUG)
        requests_log.propagate = True

    common.load_config(groupnames=None if args.groupname == [] else args.groupname)

    sources = []
    for name, layerdef in common.config.get('layers', {}).items():
        for src in layerdef['sources']:
            sourcedef = src.get('source', {})
            sourcedef['layername'] = name
            sources.append(sourcedef)
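
    # the configuration is assumed to look roughly like this (hypothetical
    # example; the actual schema is whatever common.load_config() accepts):
    #   layers:
    #     my-layer:
    #       sources:
    #         - source:
    #             download: { url: 'https://example.org/data.zip' }
    #             cache: { path: 'sub/data.zip', max-age: 21600 }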

    if args.quiet or not sys.stderr.isatty():
        pbar = None
    else:
        from tqdm import tqdm
        pbar = tqdm

    # intentionally leave the dirfd open until the program terminates
    opendir_args = O_RDONLY|O_CLOEXEC|O_PATH|O_DIRECTORY
    destdir_fd = os.open(args.cachedir, opendir_args)
    lockdir_fd = None if args.lockdir is None else os.open(args.lockdir, opendir_args)

    sessionRequests = requests.Session()

    rv = 0
    downloads = set()
    for source in sources:
        dl = source.get('download', None)
        dl_module = None if dl is None else dl.get('module', None)
        if dl_module is None:
            fetch = download
        else:
            # import_module() (unlike __import__) returns the named module
            # itself even for dotted names
            dl_module = import_module(dl_module)
            fetch = dl_module.download
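        # a custom download module is expected to expose a download() callable
        # accepting the same arguments as the default download() above
        # (cf. the fetch() call below)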

        cache = source.get('cache', None)
        dest = None if cache is None else cache.get('path', None)
        if dest is None:
            continue

        dest = str(dest) # convert from Path()
        if dest in downloads:
            logging.info('%s was already downloaded, skipping', dest)
            continue

        headers = {}
        user_agent = common.config.get('User-Agent', None)
        if user_agent is not None:
            headers['User-Agent'] = user_agent

        try:
            # create parent directories
            destdir = os.path.dirname(dest)
            common.makedirs(destdir, mode=0o755, dir_fd=destdir_fd, exist_ok=True, logging=logging)

            # place an exclusive lock on a lockfile: the destination can be shared by
            # several layers, hence might be updated in parallel
            if lockdir_fd is not None:
                lockfile = sha256(dest.encode('utf-8')).hexdigest() + '.lck'
                # use O_TRUNC to bump lockfile's mtime
                lock_fd = os.open(lockfile, O_WRONLY|O_CREAT|O_TRUNC|O_CLOEXEC, mode=0o644, dir_fd=lockdir_fd)
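                # flock(2) locks are advisory and tied to the open file
                # description, so closing lock_fd in the finally clause
                # below releases the lock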
            try:
                if lockdir_fd is not None:
                    logging.debug('flock("%s", LOCK_EX)', lockfile)
                    flock(lock_fd, LOCK_EX)
                try:
                    st = os.stat(dest, dir_fd=destdir_fd)
                except (OSError, ValueError):
                    # the file doesn't exist, or stat() failed for some reason
                    pass
                else:
                    max_age = cache.get('max-age', 6*3600) # 6h
                    if max_age is not None:
                        s = max_age + max(st.st_ctime, st.st_mtime) - time()
                        if s > 0:
                            logging.info('%s: Too young, try again in %s',
                                dest, common.format_time(s))
                            continue
                    headers['If-Modified-Since'] = formatdate(timeval=st.st_mtime, localtime=False, usegmt=True)
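                    # with If-Modified-Since set, an unchanged upstream file
                    # yields 304 Not Modified, which download() treats as a no-op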
                fetch(dl, dest, dir_fd=destdir_fd,
                    headers=headers, session=sessionRequests,
                    progress=pbar)
                downloads.add(dest)
            finally:
                if lockdir_fd is not None:
                    os.close(lock_fd)
        except Exception:
            logging.exception('Could not download %s as %s',
                              source['layername'] if dl is None else dl.get('url', source['layername']),
                              dest)
            if args.exit_code:
                rv = 1
    sys.exit(rv)