Update import.py
This commit is contained in:
@ -1,10 +1,88 @@
|
|||||||
#!/bin/env python3
|
#!/bin/env python3
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
import argparse
|
import argparse
|
||||||
|
import sqlite3
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import csv
|
||||||
|
|
||||||
|
|
||||||
|
TABLES = ['branch', 'wb_gl', 'wb_kg', 'wb_rd', 'wine_attribute', 'wine_cultivation',
|
||||||
|
'member', 'member_billing_address', 'contract', 'area_commitment',
|
||||||
|
'season', 'modifier', 'delivery', 'delivery_part', 'delivery_part_modifier',
|
||||||
|
'payment_variant', 'delivery_payment', 'member_payment']
|
||||||
|
|
||||||
|
|
||||||
|
def get_sql_files() -> List[str]:
|
||||||
|
base_dir = '..'
|
||||||
|
entries_0 = os.listdir(f'{base_dir}/sql')
|
||||||
|
dir_name = [e for e in entries_0 if e.startswith("v") and len(e) == 3][-1]
|
||||||
|
entries_data = os.listdir(f'{base_dir}/data')
|
||||||
|
|
||||||
|
files = [f'{base_dir}/sql/{dir_name}/{e}'
|
||||||
|
for e in os.listdir(f'{base_dir}/sql/{dir_name}')
|
||||||
|
if e.endswith('.sql')] + \
|
||||||
|
[f'{base_dir}/data/{e}'
|
||||||
|
for e in entries_data
|
||||||
|
if e.endswith('.sql')] + \
|
||||||
|
[f'{base_dir}/sql/{e}'
|
||||||
|
for e in entries_0
|
||||||
|
if e.endswith('.sql') and not e.startswith('sample')]
|
||||||
|
|
||||||
|
files.sort(key=lambda f: f.split('/')[-1])
|
||||||
|
return files
|
||||||
|
|
||||||
|
|
||||||
|
def sqlite_regexp(pattern: str, value: Optional[str]) -> Optional[bool]:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
return re.match(pattern, value) is not None
|
||||||
|
|
||||||
|
|
||||||
|
def import_csv(cur: sqlite3.Cursor, table_name: str) -> None:
|
||||||
|
rows = list(csv.parse(f'{args.dir}/{table_name}.csv'))
|
||||||
|
if len(rows) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
names = tuple(rows[0].keys())
|
||||||
|
values = [tuple(row.values()) for row in rows]
|
||||||
|
|
||||||
|
sql = f'INSERT INTO {table_name} ({", ".join(names)}) VALUES ({", ".join(["?"] * len(names))})'
|
||||||
|
print(sql)
|
||||||
|
cur.executemany(sql, values)
|
||||||
|
|
||||||
|
cur.close()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('dir', type=str, metavar='DIR',
|
||||||
|
help='The directory where the migrated csv files are stored')
|
||||||
|
parser.add_argument('db', type=str, metavar='DB',
|
||||||
|
help='The sqlite database file')
|
||||||
|
parser.add_argument('-k', '--keep', action='store_true', default=False,
|
||||||
|
help='Whether the database file should be overwritten or kept')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# TODO
|
if not args.keep:
|
||||||
|
try:
|
||||||
|
os.remove(args.db)
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
DB_CNX = sqlite3.connect(args.db)
|
||||||
|
DB_CNX.create_function('REGEXP', 2, sqlite_regexp)
|
||||||
|
|
||||||
|
if not args.keep:
|
||||||
|
for file_name in get_sql_files():
|
||||||
|
with open(file_name) as sql_file:
|
||||||
|
print(f'Executing {file_name}')
|
||||||
|
DB_CNX.executescript(sql_file.read())
|
||||||
|
|
||||||
|
try:
|
||||||
|
for table in TABLES:
|
||||||
|
import_csv(DB_CNX.cursor(), table)
|
||||||
|
DB_CNX.commit()
|
||||||
|
finally:
|
||||||
|
DB_CNX.close()
|
||||||
|
Reference in New Issue
Block a user