#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Import migrated CSV files into a SQLite database.

Unless ``--keep`` is given, the database file is deleted and its schema is
recreated from the SQL scripts found by :func:`get_sql_files`. Afterwards
every table listed in :data:`TABLES` is filled from ``DIR/<table>.csv`` inside
a single transaction, which is only committed if no foreign key violations
are found.
"""

from typing import List
import argparse
import collections
import contextlib
import sqlite3
import os
import datetime

import utils


# Directory containing the CSV files; set by main() from the command line.
DIR: str

# Tables to import, in import order. Each one is read from DIR/<table>.csv.
TABLES = ['client_parameter', 'branch', 'wb_gl', 'wb_kg', 'wb_rd',
          'wine_attribute', 'wine_cultivation', 'area_commitment_type',
          'member', 'member_billing_address', 'member_telephone_number',
          'member_email_address', 'area_commitment', 'season', 'modifier',
          'delivery', 'delivery_part', 'delivery_part_modifier',
          'delivery_part_bucket', 'payment_variant', 'payment_delivery_part']


def get_sql_files() -> List[str]:
    """Collect the SQL scripts that create the database.

    The scripts are taken from three places below ``..`` (relative to the
    current working directory):

    * the latest version directory ``sql/vNN`` (a three character name
      starting with ``v``),
    * ``data``,
    * ``sql`` itself, skipping files whose name starts with ``sample``.

    Returns:
        The script paths, ordered by file name (ignoring the directory).

    Raises:
        FileNotFoundError: If ``sql`` contains no version directory.
    """
    base_dir = '..'
    entries_0 = os.listdir(f'{base_dir}/sql')
    # os.listdir() returns entries in arbitrary order, so sort explicitly
    # before picking the last (= highest) version.
    version_dirs = sorted(e for e in entries_0
                          if e.startswith("v") and len(e) == 3)
    if not version_dirs:
        raise FileNotFoundError(
            f'no version directory found in {base_dir}/sql')
    dir_name = version_dirs[-1]
    entries_data = os.listdir(f'{base_dir}/data')
    files = [f'{base_dir}/sql/{dir_name}/{e}'
             for e in os.listdir(f'{base_dir}/sql/{dir_name}')
             if e.endswith('.sql')] + \
            [f'{base_dir}/data/{e}'
             for e in entries_data
             if e.endswith('.sql')] + \
            [f'{base_dir}/sql/{e}'
             for e in entries_0
             if e.endswith('.sql') and not e.startswith('sample')]
    files.sort(key=lambda f: f.split('/')[-1])
    return files


def import_csv(cur: sqlite3.Cursor, table_name: str, verbose: bool = False) -> None:
    """Insert all rows of ``DIR/<table_name>.csv`` into ``table_name``.

    The first CSV row supplies the column names; every following row is
    inserted as-is. The cursor is closed before returning, also on failure.

    Args:
        cur: Cursor used for the inserts; closed by this function.
        table_name: Name of the target table and of the CSV file (without
            extension).
        verbose: If true, print every row and insert them one by one.
    """
    with contextlib.closing(cur):
        rows = utils.csv_parse(f'{DIR}/{table_name}.csv')
        names = next(rows)

        sql = f'INSERT INTO {table_name} ({", ".join(names)}) VALUES ({", ".join(["?"] * len(names))})'
        print(sql)
        if verbose:
            inserted = 0
            for row in rows:
                print(row)
                cur.execute(sql, row)
                inserted += 1
            print(f'{inserted} inserts')
        else:
            cur.executemany(sql, rows)
            print(f'{cur.rowcount} inserts')


def check_foreign_keys(cur: sqlite3.Cursor) -> bool:
    """Report all foreign key violations in the database.

    Each distinct violation is printed once as
    ``table(column) -> parent(column) - value``, followed by the number of
    occurrences if it appears more than once. The cursor is closed before
    returning, also on failure.

    Args:
        cur: Cursor used for the checks; closed by this function.

    Returns:
        True if no violations were found, False otherwise.
    """
    with contextlib.closing(cur):
        # Row layout: (table, rowid, parent table, foreign key id)
        cur.execute("PRAGMA foreign_key_check")
        rows = cur.fetchall()

        # Look up the foreign key definitions of every affected table.
        # Row layout: (id, seq, parent table, from, to, ...), keyed by id.
        tables = {}
        for n in {r[0] for r in rows}:
            cur.execute(f"PRAGMA foreign_key_list({n})")
            tables[n] = {k[0]: k for k in cur.fetchall()}

        # Counter keeps insertion order, so cases are printed in the order
        # in which they were first encountered.
        cases = collections.Counter()
        for row in rows:
            fk = tables[row[0]][row[3]]
            # Fetch the offending value of the referencing column.
            cur.execute(f"SELECT {fk[3]} FROM {row[0]} WHERE _ROWID_ = ?", (row[1],))
            value = cur.fetchall()
            cases[f'{row[0]}({fk[3]}) -> {fk[2]}({fk[4]}) - {value[0][0]}'] += 1

        for case, n in cases.items():
            print(case + (f' ({n} times)' if n > 1 else ''))

        return len(rows) == 0


def main() -> None:
    """Parse the command line and run the import.

    Raises:
        RuntimeError: If the imported data violates a foreign key constraint.
            The import transaction is rolled back in that case, as for any
            other exception raised while it is active.
    """
    global DIR

    parser = argparse.ArgumentParser()
    parser.add_argument('dir', type=str, metavar='DIR',
                        help='The directory where the migrated csv files are stored')
    parser.add_argument('db', type=str, metavar='DB',
                        help='The sqlite database file')
    parser.add_argument('-k', '--keep', action='store_true', default=False,
                        help='Whether the database file should be overwritten or kept')
    parser.add_argument('-v', '--verbose', action='store_true', default=False,
                        help='Log every inserted row')
    args = parser.parse_args()

    DIR = args.dir

    if not args.keep:
        # Start from scratch; a missing file is fine.
        with contextlib.suppress(FileNotFoundError):
            os.remove(args.db)

    sqlite3.register_adapter(datetime.date, lambda d: d.strftime('%Y-%m-%d'))
    sqlite3.register_adapter(datetime.time, lambda t: t.strftime('%H:%M:%S'))

    # closing() guarantees the connection is released even if creating the
    # schema fails.
    with contextlib.closing(sqlite3.connect(args.db)) as cnx:
        cnx.create_function('REGEXP', 2, utils.sqlite_regexp, deterministic=True)

        if not args.keep:
            for file_name in get_sql_files():
                with open(file_name, encoding='utf-8') as sql_file:
                    print(f'Executing {file_name}')
                    cnx.executescript(sql_file.read())

        # Remember the schema version so it can be restored after VACUUM.
        ver = cnx.execute("PRAGMA schema_version").fetchall()[0][0]

        try:
            # Manage the transaction manually via BEGIN/COMMIT/ROLLBACK.
            cnx.isolation_level = None
            # Member predecessors may refer to a higher MgNr
            cnx.execute("PRAGMA foreign_keys = OFF")
            cnx.execute("BEGIN")
            for table in TABLES:
                import_csv(cnx.cursor(), table, args.verbose)
            if not check_foreign_keys(cnx.cursor()):
                raise RuntimeError('foreign key constraint failed')
            cnx.execute("COMMIT")
            cnx.execute("VACUUM")
            cnx.execute(f"PRAGMA schema_version = {ver}")
        except Exception:
            # Only roll back if a transaction is actually open. A ROLLBACK
            # before BEGIN or after COMMIT would raise on its own and hide
            # the original error.
            if cnx.in_transaction:
                cnx.execute("ROLLBACK")
            raise
        finally:
            cnx.execute("PRAGMA foreign_keys = ON")


if __name__ == '__main__':
    main()