Check foreign key constraints on inport in import.py
This commit is contained in:
@ -59,6 +59,32 @@ def import_csv(cur: sqlite3.Cursor, table_name: str) -> None:
|
||||
cur.close()
|
||||
|
||||
|
||||
def check_foreign_keys(cur: sqlite3.Cursor) -> bool:
|
||||
cur.execute("PRAGMA foreign_key_check")
|
||||
rows = cur.fetchall()
|
||||
table_names = {r[0] for r in rows}
|
||||
tables = {}
|
||||
for n in table_names:
|
||||
cur.execute(f"PRAGMA foreign_key_list({n})")
|
||||
keys = cur.fetchall()
|
||||
tables[n] = {k[0]: k for k in keys}
|
||||
|
||||
cases = {}
|
||||
for row in rows:
|
||||
fk = tables[row[0]][row[3]]
|
||||
cur.execute(f"SELECT {fk[3]} FROM {row[0]} WHERE _ROWID_ = ?", (row[1],))
|
||||
value = cur.fetchall()
|
||||
string = f'{row[0]}({fk[3]}) -> {fk[2]}({fk[4]}) - {value[0][0]}'
|
||||
if string not in cases:
|
||||
cases[string] = 0
|
||||
cases[string] += 1
|
||||
for case, n in cases.items():
|
||||
print(case + (f' ({n} times)' if n > 1 else ''))
|
||||
|
||||
cur.close()
|
||||
return len(rows) == 0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('dir', type=str, metavar='DIR',
|
||||
@ -94,6 +120,8 @@ if __name__ == '__main__':
|
||||
DB_CNX.execute("BEGIN")
|
||||
for table in TABLES:
|
||||
import_csv(DB_CNX.cursor(), table)
|
||||
if not check_foreign_keys(DB_CNX.cursor()):
|
||||
raise RuntimeError('foreign key constraint failed')
|
||||
DB_CNX.execute("COMMIT")
|
||||
except Exception as err:
|
||||
DB_CNX.execute("ROLLBACK")
|
||||
|
Reference in New Issue
Block a user