#!/usr/bin/env python3
import argparse
import datetime
import email.utils as email_utils
import html.parser
import http.client
import io
import os
import shutil
import sys
import tempfile
import urllib.parse
import urllib.request

from typing import List, Optional, Iterable


class File:
    """One entry scraped from an Apache fancy-index table row.

    All fields start unset (None) and are filled in column by column by the
    HTML parser.
    """

    def __init__(self):
        self.name: Optional[str] = None
        self.mtime: Optional[datetime.datetime] = None
        self.size: Optional[int] = None

    def __str__(self):
        stamp = self.mtime.isoformat() if self.mtime else None
        return 'File({}, {}, {})'.format(repr(self.name), stamp, self.size)


class Parsey(html.parser.HTMLParser):
    """Extracts File entries from an Apache mod_autoindex HTML table.

    Expects each table row to carry the file name in the second cell, the
    modification time in the third and the human-readable size in the fourth.
    Completed rows are collected in ``self.files``.
    """

    def __init__(self):
        super().__init__()
        self.row = 0             # 1-based count of <tr> tags seen so far
        self.in_row = False      # type: bool  # currently inside a <tr>?
        self.td = 0              # type: int   # 1-based cell index within the row
        self.file = File()       # type: File  # entry accumulated for this row
        self.files = []          # type: List[File]

    def handle_starttag(self, tag: str, attrs):
        if 'tr' == tag:
            self.in_row = True
            self.row += 1
        elif 'td' == tag:
            self.td += 1

    def handle_endtag(self, tag: str):
        if 'tr' == tag:
            self.td = 0
            self.in_row = False
            # Keep the row only if every column was parsed.  The size is
            # compared against None explicitly: 0 is falsy but a perfectly
            # valid size, and zero-byte files must not be dropped.
            if self.file.name and self.file.mtime and self.file.size is not None:
                self.files.append(self.file)
            self.file = File()

    def handle_data(self, data: str):
        if not self.in_row:
            return

        # "parent directory" row
        if self.row == 3:
            return

        if 2 == self.td:
            self.file.name = data
        elif 3 == self.td:
            self.file.mtime = apache_time(data)
        elif 4 == self.td:
            self.file.size = apache_size(data)

    def error(self, message):
        # html.parser's (deprecated) error hook; fail loudly rather than
        # silently producing a partial listing.
        raise Exception(message)


def apache_time(text: str) -> datetime.datetime:
    """Parse the timestamp column of an Apache directory index.

    >>> apache_time('2017-01-15 13:53')
    datetime.datetime(2017, 1, 15, 13, 53)
    """
    stamp = text.strip()
    return datetime.datetime.strptime(stamp, '%Y-%m-%d %H:%M')


def apache_size(data: str) -> int:
    """Parse the size column of an Apache directory index into bytes.

    Sizes use single-letter binary suffixes (K/M/G); a trailing space marks
    a plain byte count.  Raises Exception for anything else.

    >>> apache_size('16K')
    16384
    """
    # NOTE: the previous doctest expected the literal text '16*1024', which
    # can never match doctest output; it now states the evaluated value.
    multiplier = {'G': 1024 ** 3, 'M': 1024 ** 2, 'K': 1024}
    suffix = data[-1].upper()
    if suffix in multiplier:
        return int(multiplier[suffix] * float(data[:-1]))
    if ' ' == suffix:
        return int(data)

    raise Exception("can't read " + data)


class Sponge:
    """Context manager that writes to a temp file and atomically publishes it.

    Data is written to ``self.fd``, a hidden temp file created in the
    destination's own directory (so the final rename never crosses
    filesystems).  On clean exit the temp file replaces ``dest``; on
    exception it is deleted, leaving any previous ``dest`` untouched.
    """

    def __init__(self, dest: str):
        self.fd = None  # type: io.BufferedRandom
        self.name = None  # type: str  # path of the temp file
        self.dest = dest

    def __enter__(self):
        # Dot-prefix hides the in-progress file from directory listings.
        fd_int, self.name = tempfile.mkstemp(
            dir=os.path.dirname(self.dest),
            prefix='.',
            suffix='.tmp~')
        self.fd = open(fd_int, 'w+b')
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.fd.close()
        if exc_type:
            os.unlink(self.name)
        else:
            # os.replace (unlike os.rename) also overwrites an existing
            # destination on Windows, keeping the swap atomic everywhere.
            os.replace(self.name, self.dest)


def fetch(url: str, path: str):
    """Download *url* into *path*, publishing it atomically via Sponge."""
    with urllib.request.urlopen(url) as resp:
        status = resp.status
        if status != 200:
            raise Exception('fetching {} failed ({}): {}'.format(url, status, resp))

        with Sponge(path) as dest:
            shutil.copyfileobj(resp, dest.fd)


class Store:
    """Local mirror of one folder of an Apache fancy-index served over HTTPS.

    Paths come in three flavours:
      * server path -- absolute, e.g. '/debian/rbuild/unstable/amd64/x.log'
      * local path  -- the server path re-rooted under ``local_root``
      * url         -- ``base`` + server path
    """

    def __init__(self,
                 local_root: str,
                 folder: str,
                 base='https://tests.reproducible-builds.org/',
                 verbose: bool = False):
        assert base.endswith('/')

        self.out = os.path.realpath(local_root) + '/'
        self.folder = os.path.realpath(folder) + '/'
        self.base = base
        self.files = []  # type: List[File]
        self.verbose = verbose

        os.makedirs(self.to_local_path(self.folder), exist_ok=True)

    def mtime(self, path: str) -> Optional[datetime.datetime]:
        """Local modification time of a server path, or None if not mirrored."""
        try:
            return datetime.datetime.fromtimestamp(os.path.getmtime(self.to_local_path(path)))
        except FileNotFoundError:
            return None

    def to_local_path(self, path: str) -> str:
        """Re-root an absolute server path under the local output directory."""
        return self.out + path[1:]

    def to_url(self, path: str) -> str:
        assert path.startswith('/')
        return self.base + path[1:]

    def file_named(self, name: str) -> str:
        """Server path of a bare file name inside the mirrored folder."""
        assert not name.startswith('/')
        return self.folder + name

    def load_index(self, max_age: datetime.timedelta):
        """(Re)download the folder's index page if stale, then parse it."""
        index = self.file_named('index.html')
        index_fetched = self.mtime(index)

        if not index_fetched or index_fetched < (datetime.datetime.now() - max_age):
            if self.verbose:
                print("info: index out of date, downloading...")
            # empty file name: not index.html on the server
            fetch(self.to_url(self.file_named('')), self.to_local_path(index))

        self.load_from(index)

    def load_from(self, path: str):
        """Parse a saved index page into ``self.files``."""
        with open(self.to_local_path(path)) as f:
            parsey = Parsey()
            if self.verbose:
                sys.stdout.write("info: parsing HTML... ")
                sys.stdout.flush()
            parsey.feed(f.read())
            if self.verbose:
                print("done.")
        self.files = parsey.files

    def exists_locally(self, file: 'File'):
        """True if the file already exists in the local mirror.

        BUG FIX: the server path must be converted with to_local_path();
        previously the raw server path was checked against the filesystem
        root, so this always answered False in practice.
        """
        return os.path.exists(self.to_local_path(self.file_named(file.name)))

    def up_to_date(self, file: 'File') -> bool:
        """True if the mirrored copy is at least as new as the server's."""
        mtime = self.mtime(self.file_named(file.name))
        if not mtime:
            return False
        return mtime >= file.mtime

    def download_many(self, files: Iterable['File']):
        """Fetch each file over one reused HTTPS connection.

        Writes a '.' per downloaded file, raises on any non-200/304 status,
        and stamps each local file with the server's Last-Modified time so
        up_to_date() works on later runs.
        """
        parse = urllib.parse.urlparse(self.base)  # type: urllib.parse.ParseResult
        assert 'https' == parse.scheme
        conn = http.client.HTTPSConnection(parse.netloc)
        try:
            for file in files:
                path = self.file_named(file.name)

                if self.verbose:
                    print('info: fetching {}'.format(path))

                conn.request('GET', path, None, headers={
                    'User-Agent': 'fetch-logs',
                })

                resp = conn.getresponse()
                if 304 == resp.status:  # not modified
                    # Drain the (empty) body so http.client can reuse the
                    # connection for the next request.
                    resp.read()
                    continue

                if 200 != resp.status:
                    raise Exception('fetching {} failed ({}): {}'.format(path, resp.status, resp))

                written_file = self.to_local_path(path)
                with Sponge(written_file) as dest:
                    shutil.copyfileobj(resp, dest.fd)

                # NOTE(review): assumes the server always sends
                # Last-Modified; a missing header would raise here.
                mtime = email_utils.parsedate_to_datetime(resp.headers['Last-Modified']).timestamp()
                os.utime(written_file, times=(mtime, mtime))

                sys.stdout.write('.')
                sys.stdout.flush()

        finally:
            conn.close()

        print()

    def remove_unmatched_files(self):
        """Delete local files that no longer appear in the parsed index."""
        allowed_names = {file.name for file in self.files}
        for name in os.listdir(self.to_local_path(self.folder)):
            if name not in allowed_names:
                self.delete_local_file(name)

    def delete_local_file(self, name: str):
        os.unlink(self.to_local_path(self.file_named(name)))


def package_name(name: str) -> str:
    """Return the source-package part of a log file name (before '_').

    Raises ValueError if the name contains no underscore.
    """
    underscore_at = name.index('_')
    return name[:underscore_at]


def print_package_set(files: Iterable['File'], name: str):
    """Print a titled, sorted bullet list of the file names in *files*.

    The first parameter was renamed from ``set`` so it no longer shadows the
    builtin; all callers in this file pass it positionally.
    """
    print(name)
    print("---")
    print()
    for file_name in sorted(f.name for f in files):
        print(' * ' + file_name)
    print()


def main():
    """Command-line entry point: mirror reproducible-builds log files locally."""
    allowed_types = {'dbdtxt', 'rbuild'}
    allowed_suites = {'unstable', 'testing'}
    allowed_arches = {'amd64', 'i386', 'arm64', 'armhf'}

    parser = argparse.ArgumentParser()
    parser.add_argument('-v', '--verbose', action='store_true')
    parser.add_argument('-d', '--download', action='store_true')
    parser.add_argument('-u', '--delete-outdated', action='store_true',
                        help='remove *all* files which are older than the server versions')
    parser.add_argument('-t', '--type', type=str, help='|'.join(allowed_types), default='rbuild')
    parser.add_argument('-s', '--suite', type=str, help='|'.join(allowed_suites), default='unstable')
    parser.add_argument('-a', '--arch', type=str, help='|'.join(allowed_arches), default='amd64')
    parser.add_argument('-o', '--output-dir', type=str, help='output directory path', default='logs')
    parser.add_argument('-m', '--max-size', type=str, default='1M',
                        help='maximum file size to fetch (suffixes allowed)')
    parser.add_argument('packages', nargs='*')
    args = parser.parse_args()

    if not args.download:
        parser.print_help()
        print("--download required")
        return

    if args.type not in allowed_types:
        parser.print_help()
        print("invalid type")
        return

    if args.suite not in allowed_suites:
        parser.print_help()
        print("invalid suite")
        # BUG FIX: previously fell through and ran with an invalid suite.
        return

    if args.arch not in allowed_arches:
        parser.print_help()
        print("invalid arch")
        # BUG FIX: previously fell through and ran with an invalid arch.
        return

    # --max-size accepts either a plain byte count or an Apache-style suffix.
    if args.max_size.isdigit():
        max_size = int(args.max_size)
    else:
        max_size = apache_size(args.max_size.strip())

    # BUG FIX: honour --output-dir; it used to be parsed but ignored in
    # favour of a hard-coded 'logs/'.
    store = Store(args.output_dir, '/debian/{}/{}/{}/'.format(args.type, args.suite, args.arch),
                  verbose=args.verbose)

    store.load_index(datetime.timedelta(hours=1))

    # No explicit packages means "everything listed in the index".
    wanted = set(args.packages)
    if not wanted:
        wanted = set(package_name(file.name) for file in store.files)

    to_fetch = set()
    too_big = set()
    up_to_date = set()
    for file in store.files:
        name = package_name(file.name)

        if args.delete_outdated and store.exists_locally(file) and not store.up_to_date(file):
            store.delete_local_file(file.name)

        if name not in wanted:
            continue
        if file.size > max_size:
            too_big.add(file)
            continue
        if store.up_to_date(file):
            up_to_date.add(file)
            continue
        to_fetch.add(file)

    if args.delete_outdated:
        store.remove_unmatched_files()

    if args.verbose:
        print_package_set(too_big, "too big")
        print_package_set(up_to_date, "up to date")
        print_package_set(to_fetch, "to fetch")

    print('skipped: {} (size), {} (not needed); total to download: {}'
          .format(len(too_big), len(up_to_date), len(to_fetch)))

    store.download_many(to_fetch)


if __name__ == '__main__':
    # Script entry point: run only when executed directly, not on import.
    main()
