From ae4b9d3fe30ff129fd3bf0ff5ca20eeafc8d86a5 Mon Sep 17 00:00:00 2001 From: Lorenz Stechauner Date: Sun, 9 Apr 2023 13:28:22 +0200 Subject: [PATCH] Refactor csv.py --- wgmaster/csv.py | 104 ++++++++++++++++++++++++-------------------- wgmaster/import.py | 12 ++--- wgmaster/migrate.py | 26 +++++------ 3 files changed, 73 insertions(+), 69 deletions(-) diff --git a/wgmaster/csv.py b/wgmaster/csv.py index 28f2cd7..cea487e 100644 --- a/wgmaster/csv.py +++ b/wgmaster/csv.py @@ -1,64 +1,72 @@ #!/bin/env python3 # -*- coding: utf-8 -*- -from typing import Iterator, Dict, Any, Optional, Tuple +from typing import Iterator, Dict, Any, Tuple import re import datetime +RE_INT = re.compile(r'-?[0-9]+') +RE_FLOAT = re.compile(r'-?[0-9]+\.[0-9]+') -def parse(filename: str) -> Iterator[Dict[str, Any]]: - def parse_line(line_str: str) -> Iterator[str]: - w = None - s = False - for ch in line_str: - if w is None: - if ch == ';': - yield '' - continue - elif ch in (' ', '\t'): - continue - w = ch - s = ch == '"' + +def cast_value(value: str) -> Any: + if value == '': + return None + elif value[0] == '"' and value[-1] == '"': + return value[1:-1] + elif value == 'T': + return True + elif value == 'F': + return False + elif RE_INT.fullmatch(value): + return int(value) + elif RE_FLOAT.fullmatch(value): + return float(value) + elif len(value) == 10 and value[4] == '-' and value[7] == '-': + return datetime.datetime.strptime(value, '%Y-%m-%d').date() + elif len(value) == 8 and value[2] == ':' and value[5] == ':': + return datetime.time.fromisoformat(value) + else: + raise RuntimeError(f'unable to infer type of value "{value}"') + + +def parse_line(line_str: str) -> Iterator[str]: + w = None + s = False + for ch in line_str: + if w is None: + if ch == ';': + yield '' continue - elif not s and ch in (';', '\n'): - yield w.strip() - w = None + elif ch in (' ', '\t'): continue - elif s and ch == '"': - s = False - w += ch - if w is not None: + w = ch + s = ch == '"' + continue + elif not s and ch in (';', '\n'): yield w.strip() + w = None + continue + elif s and ch == '"': + s = False + w += ch + if w is not None: + yield w.strip() + +def parse(filename: str) -> Iterator[Tuple]: with open(filename, 'r', encoding='utf-8') as f: - header: Optional[Tuple[str]] = None - for line in f: - if header is None: - header = tuple([e.strip() for e in line.strip().split(';')]) - continue + lines = f.__iter__() + yield tuple([part.strip() for part in next(lines).split(';')]) + for line in lines: + yield tuple([cast_value(part) for part in parse_line(line)]) - obj = {} - for i, part in enumerate(parse_line(line)): - if part == '': - part = None - elif part[0] == '"' and part[-1] == '"': - part = part[1:-1] - elif part == 'T': - part = True - elif part == 'F': - part = False - elif re.fullmatch(r'-?[0-9]+', part): - part = int(part) - elif re.fullmatch(r'-?[0-9]+\.[0-9]+', part): - part = float(part) - elif len(part) == 10 and part[4] == '-' and part[7] == '-': - part = datetime.datetime.strptime(part, '%Y-%m-%d').date() - elif len(part) == 8 and part[2] == ':' and part[5] == ':': - part = datetime.time.fromisoformat(part) - else: - raise RuntimeError(f'unable to infer type of value "{part}"') - obj[header[i]] = part - yield obj + +def parse_dict(filename: str) -> Iterator[Dict[str, Any]]: + rows = parse(filename) + header = next(rows) + for row in rows: + yield {header[i]: part for i, part in enumerate(row)} def format_row(*values) -> str: diff --git a/wgmaster/import.py b/wgmaster/import.py index 775039d..bb92a0d 100755 --- a/wgmaster/import.py +++ b/wgmaster/import.py @@ -44,17 +44,13 @@ def sqlite_regexp(pattern: str, value: Optional[str]) -> Optional[bool]: def import_csv(cur: sqlite3.Cursor, table_name: str) -> None: - rows = list(csv.parse(f'{args.dir}/{table_name}.csv')) - if len(rows) == 0: - return - - names = tuple(rows[0].keys()) - values = [tuple(row.values()) for row in rows] + rows = csv.parse(f'{args.dir}/{table_name}.csv') + names = next(rows) sql = f'INSERT INTO {table_name} ({", ".join(names)}) VALUES ({", ".join(["?"] * len(names))})' print(sql) - cur.executemany(sql, values) - print(f'{len(values)} inserts') + cur.executemany(sql, rows) + print(f'{cur.rowcount} inserts') cur.close() diff --git a/wgmaster/migrate.py b/wgmaster/migrate.py index 34ed31c..3137a8d 100755 --- a/wgmaster/migrate.py +++ b/wgmaster/migrate.py @@ -168,7 +168,7 @@ def get_bev_gst_size(kgnr: int, gstnr: str) -> Optional[int]: def parse_flaechenbindungen(in_dir: str) -> Dict[int, Dict[int, Dict[str, Any]]]: - fbs = csv.parse(f'{in_dir}/TFlaechenbindungen.csv') + fbs = csv.parse_dict(f'{in_dir}/TFlaechenbindungen.csv') members = {} for f in fbs: if f['MGNR'] not in members: @@ -258,7 +258,7 @@ def lookup_kg_name(kgnr: int) -> str: def migrate_gradation(in_dir: str, out_dir: str) -> None: global GRADATION_MAP GRADATION_MAP = {} - for g in csv.parse(f'{in_dir}/TUmrechnung.csv'): + for g in csv.parse_dict(f'{in_dir}/TUmrechnung.csv'): GRADATION_MAP[g['Oechsle']] = g['KW'] @@ -268,7 +268,7 @@ def migrate_branches(in_dir: str, out_dir: str) -> None: with open(f'{out_dir}/branch.csv', 'w+') as f: f.write('zwstid;name;country;postal_dest;address;phone_nr\n') - for b in csv.parse(f'{in_dir}/TZweigstellen.csv'): + for b in csv.parse_dict(f'{in_dir}/TZweigstellen.csv'): BRANCH_MAP[b['ZNR']] = b['Kennbst'] address = b['Straße'] postal_dest = lookup_plz(int(b['PLZ']) if b['PLZ'] else None, b['Ort'], address) @@ -282,7 +282,7 @@ def migrate_grosslagen(in_dir: str, out_dir: str) -> None: glnr = 0 with open(f'{out_dir}/wb_gl.csv', 'w+') as f: f.write('glnr;name\n') - for gl in csv.parse(f'{in_dir}/TGrosslagen.csv'): + for gl in csv.parse_dict(f'{in_dir}/TGrosslagen.csv'): glnr += 1 GROSSLAGE_MAP[gl['GLNR']] = glnr f.write(csv.format_row(glnr, gl['Bezeichnung'])) @@ -294,7 +294,7 @@ def migrate_gemeinden(in_dir: str, out_dir: str) -> None: with open(f'{out_dir}/wb_kg.csv', 'w+') as f: f.write('kgnr;glnr\n') - for g in csv.parse(f'{in_dir}/TGemeinden.csv'): + for g in csv.parse_dict(f'{in_dir}/TGemeinden.csv'): gems = lookup_gem_name(g['Bezeichnung']) GEM_MAP[g['GNR']] = gems for kgnr, gkz in gems: @@ -307,7 +307,7 @@ def migrate_reeds(in_dir: str, out_dir: str) -> None: with open(f'{out_dir}/wb_rd.csv', 'w+') as f: f.write('kgnr;rdnr;name\n') - for r in csv.parse(f'{in_dir}/TRiede.csv'): + for r in csv.parse_dict(f'{in_dir}/TRiede.csv'): name: str = r['Bezeichnung'].strip() if name.isupper(): name = name.title() @@ -325,7 +325,7 @@ def migrate_reeds(in_dir: str, out_dir: str) -> None: def migrate_attributes(in_dir: str, out_dir: str) -> None: with open(f'{out_dir}/wine_attribute.csv', 'w+') as f: f.write('attrid;name;kg_per_ha\n') - for a in csv.parse(f'{in_dir}/TSortenAttribute.csv'): + for a in csv.parse_dict(f'{in_dir}/TSortenAttribute.csv'): f.write(csv.format_row(a['SANR'], a['Attribut'], int(a['KgProHa']))) @@ -335,7 +335,7 @@ def migrate_cultivations(in_dir: str, out_dir: str) -> None: with open(f'{out_dir}/wine_cultivation.csv', 'w+') as f: f.write('cultid;name\n') - for c in csv.parse(f'{in_dir}/TBewirtschaftungsarten.csv'): + for c in csv.parse_dict(f'{in_dir}/TBewirtschaftungsarten.csv'): name: str = c['Bezeichnung'] cultid = name[0].upper() if name.isupper(): @@ -350,7 +350,7 @@ def migrate_members(in_dir: str, out_dir: str) -> None: global MEMBER_MAP MEMBER_MAP = {} - members = csv.parse(f'{in_dir}/TMitglieder.csv') + members = csv.parse_dict(f'{in_dir}/TMitglieder.csv') fbs = parse_flaechenbindungen(in_dir) with open(f'{out_dir}/member.csv', 'w+') as f_m,\ @@ -643,7 +643,7 @@ def migrate_contracts(in_dir: str, out_dir: str) -> None: f_c.write('vnr;mgnr;year_from;year_to\n') f_fb.write('vnr;kgnr;gstnr;rdnr;area;sortid;attrid;cultid\n') - for fb in csv.parse(f'{in_dir}/TFlaechenbindungen.csv'): + for fb in csv.parse_dict(f'{in_dir}/TFlaechenbindungen.csv'): if fb['Von'] is None and fb['Bis'] is None: continue parz: str = fb['Parzellennummer'] @@ -726,12 +726,12 @@ def fix_deliveries(deliveries: Iterable[Dict[str, Any]]) -> Iterable[Tuple[str, def migrate_deliveries(in_dir: str, out_dir: str) -> None: - modifiers = {m['ASNR']: m for m in csv.parse(f'{in_dir}/TAbschlaege.csv') if m['Bezeichnung']} + modifiers = {m['ASNR']: m for m in csv.parse_dict(f'{in_dir}/TAbschlaege.csv') if m['Bezeichnung']} delivery_map = {} seasons = {} branches = {} - deliveries = list(csv.parse(f'{in_dir}/TLieferungen.csv')) + deliveries = list(csv.parse_dict(f'{in_dir}/TLieferungen.csv')) delivery_dict = {d['LINR']: d for d in deliveries} fixed = fix_deliveries(deliveries) @@ -831,7 +831,7 @@ def migrate_deliveries(in_dir: str, out_dir: str) -> None: with open(f'{out_dir}/delivery_part_modifier.csv', 'w+') as f_part_mod: f_part_mod.write('year;did;dpnr;mnr\n') - for m in csv.parse(f'{in_dir}/TLieferungAbschlag.csv'): + for m in csv.parse_dict(f'{in_dir}/TLieferungAbschlag.csv'): if m['LINR'] not in delivery_map: continue nid = delivery_map[m['LINR']]