Compare commits

...

3 Commits

View File

@ -6,6 +6,7 @@ from typing import Callable, Optional
from http.server import BaseHTTPRequestHandler, HTTPServer from http.server import BaseHTTPRequestHandler, HTTPServer
import argparse import argparse
import datetime import datetime
import time
import traceback import traceback
import re import re
import base64 import base64
@ -13,12 +14,510 @@ import json
import sqlite3 import sqlite3
import urllib.parse import urllib.parse
import gzip import gzip
import hashlib
import hmac
VERSION: str = '0.0.3' VERSION: str = '0.0.3'
CNX: sqlite3.Cursor CNX: sqlite3.Cursor
USER_FILE: str USER_FILE: str
JWT_SECRET: bytes
JWT_ISSUER: str
JWT_INVALIDATE_BEFORE: int = 0
JWT_USER_INVALIDATE_BEFORE: dict[str, int] = {}
class HttpError(BaseException):
status_code: int
def __init__(self, status_code: int, message: str):
self.status_code = status_code
super().__init__(message)
class BadRequestError(HttpError):
def __init__(self, message: str = 'Bad request'):
super().__init__(400, message)
class UnauthorizedError(HttpError):
def __init__(self, message: str = 'Unauthorized'):
super().__init__(401, message)
class ForbiddenError(HttpError):
def __init__(self, message: str = 'Forbidden'):
super().__init__(403, message)
class NotFoundError(HttpError):
def __init__(self, message: str = 'Not found'):
super().__init__(404, message)
class MethodNotAllowedError(HttpError):
def __init__(self, message: str = 'Method not allowed'):
super().__init__(405, message)
class Filter:
def __init__(self, name: str, values: list[int] or list[str] = None):
self.name = name
self.values = values
def is_int(self) -> bool:
return type(self.values[0]) is int
def is_str(self) -> bool:
return type(self.values[0]) is str
def is_single(self) -> bool:
return self.values is None
def to_sql_list(self) -> str:
if self.is_int():
return ', '.join(str(v) for v in self.values)
else:
return ', '.join(f"'{v}'" for v in self.values)
def __repr__(self) -> str:
if self.is_single():
return self.name
elif self.name == 'kgnr':
return f'kgnr={";".join(f"{v:05}" for v in self.values)}'
return f'{self.name}={";".join(str(v) for v in self.values)}'
def __str__(self) -> str:
return self.__repr__()
def __eq__(self, other) -> bool:
return self.__repr__() == other.__repr__()
@staticmethod
def from_str(string: str) -> Filter:
f = string.split('=', 1)
if len(f) == 2:
ps = f[1].split(';')
is_digit = all(p.isdigit() for p in ps)
return Filter(f[0], [int(p) for p in ps] if is_digit else ps)
return Filter(f[0])
def sqlite_regexp(pattern: str, value: Optional[str]) -> Optional[bool]:
return re.match(pattern, value) is not None if value is not None else None
def kmw_to_oe(kmw: float) -> float:
return kmw * (4.54 + 0.022 * kmw) if kmw is not None else None
def jdmp(value, is_bool: bool = False) -> str:
if is_bool and value:
return ' true'
elif is_bool and not value:
return 'false'
return json.dumps(value, ensure_ascii=False)
def check_password(stored_pwd: str, check_pwd: str) -> bool:
return stored_pwd == check_pwd
def check_user_password(username: str, password: str) -> tuple[str, str]:
with open(USER_FILE, 'r') as file:
for line in file:
(u, r, i, p) = line.strip().split(':', 3)
if u == username and check_password(p, password):
return u, r
raise UnauthorizedError()
def get_delivery_schedule_filter_clauses(filters: list[Filter]) -> list[str]:
clauses = []
for f in filters:
if f.name == 'year' and f.is_int():
clauses.append(f"s.year IN ({f.to_sql_list()})")
elif f.name == 'sortid' and f.is_str() and all(len(v) == 2 and v.isalpha() and v.isupper() for v in f.values):
clauses.append(f"v.sortid IN ({f.to_sql_list()})")
elif f.name == 'date' and f.is_str() and all(re.match(r'[0-9]{4}-[0-9]{2}-[0-9]{2}', v) is not None for v in f.values):
clauses.append(f"s.date IN ({f.to_sql_list()})")
else:
raise BadRequestError(f"Invalid filter '{f}'")
return clauses
class ElwigApi(BaseHTTPRequestHandler):
def send(self, data: str, status_code: int = 200, url: str = None) -> None:
raw = data.encode('utf-8')
self.send_response(status_code)
self.send_header('Access-Control-Allow-Origin', '*')
self.send_header('Access-Control-Allow-Headers', 'Authorization')
if self.path in ('/auth',):
self.send_header('Access-Control-Allow-Methods', 'HEAD, GET, POST, OPTIONS')
else:
self.send_header('Access-Control-Allow-Methods', 'HEAD, GET, OPTIONS')
if 300 <= status_code < 400 and status_code != 304 and url:
self.send_header('Location', url)
elif status_code == 401:
self.send_header('WWW-Authenticate', 'Basic realm=Elwig')
if 'Accept-Encoding' in self.headers and len(data) > 64:
accept_encoding = [e.strip() for e in self.headers.get('Accept-Encoding').split(',')]
if 'gzip' in accept_encoding:
raw = gzip.compress(raw)
self.send_header('Content-Encoding', 'gzip')
self.send_header('Content-Type', 'application/json; charset=UTF-8')
self.send_header('Content-Length', str(len(raw)))
self.end_headers()
if self.request.type != 'HEAD' and self.request.type != 'OPTIONS':
self.wfile.write(raw)
def error(self, status_code: int, message: str = None) -> None:
self.send(f'{{"message":{jdmp(message)}}}\n', status_code=status_code)
def see_other(self, url: str) -> None:
self.send(f'{{"url": {jdmp(url)}}}\n', status_code=303, url=url)
def authorize(self) -> tuple[str, str, str]:
auth = self.headers.get('Authorization')
if auth and auth.startswith('Basic '):
u, r = ElwigApi.authorize_basic(auth[6:])
return u, r, 'Basic'
elif auth and auth.startswith('Bearer '):
u, r = ElwigApi.authorize_bearer(auth[7:])
return u, r, 'Bearer'
raise UnauthorizedError()
@staticmethod
def authorize_basic(auth: str) -> tuple[str, str]:
try:
username, password = base64.b64decode(auth.strip() + '==').decode('utf-8').split(':', 1)
except:
raise BadRequestError('Invalid Authorization header')
return check_user_password(username, password)
@staticmethod
def authorize_bearer(token: str) -> tuple[str, str]:
try:
hdr_r, payload_r, sig = token.strip().split('.')
hdr = json.loads(base64.urlsafe_b64decode(hdr_r + '==').decode('utf-8'))
payload = json.loads(base64.urlsafe_b64decode(payload_r + '==').decode('utf-8'))
if hdr['typ'] != 'JWT':
raise ValueError()
mac = hmac.new(JWT_SECRET, digestmod={
'HS224': hashlib.sha224,
'HS256': hashlib.sha256,
'HS384': hashlib.sha384,
'HS512': hashlib.sha512,
}[hdr['alg']])
mac.update((hdr_r + '.' + payload_r).encode('ascii'))
digest = mac.digest()
except Exception:
raise BadRequestError('Invalid Authorization header')
try:
if digest != base64.urlsafe_b64decode(sig + '=='):
raise UnauthorizedError('Invalid JWT signature')
elif payload['iss'] != JWT_ISSUER:
raise UnauthorizedError('Invalid JWT issuer')
elif 'exp' in payload and payload['exp'] < int(time.time()):
raise UnauthorizedError('JWT token expired')
elif 'nbf' in payload and payload['nbf'] > int(time.time()):
raise UnauthorizedError('JWT token not yet valid')
elif payload['iat'] < JWT_INVALIDATE_BEFORE:
raise UnauthorizedError('Invalidated JWT token')
elif payload['iat'] < JWT_USER_INVALIDATE_BEFORE.get(payload['sub'], 0):
raise UnauthorizedError('Invalidated JWT token')
return payload['sub'], payload['rol']
except Exception:
raise UnauthorizedError('Invalid JWT token')
@staticmethod
def issue_jwt(username: str, role: str) -> str:
hdr = base64.urlsafe_b64encode(b'{"typ":"JWT","alg":"HS256"}').strip(b'=')
payload = base64.urlsafe_b64encode(json.dumps({
'iss': JWT_ISSUER,
'sub': username,
'rol': role,
'iat': int(time.time()),
}, ensure_ascii=False, separators=(',', ':')).encode('utf-8')).strip(b'=')
mac = hmac.new(JWT_SECRET, digestmod=hashlib.sha256)
mac.update(hdr + b'.' + payload)
sig = base64.urlsafe_b64encode(mac.digest()).strip(b'=')
return (hdr + b'.' + payload + b'.' + sig).decode('ascii')
def exec_collection(self, sql_query: str, fmt: Callable, filters: list[Filter],
offset: int = None, limit: int = None,
distinct: tuple[[str], [int]] = None, sub_fmt: Callable = None) -> None:
with_clause = re.findall(r'(WITH .*?\))[\s\n]*SELECT', sql_query, flags=re.DOTALL)
if len(with_clause) > 0 and '.*' not in with_clause[0]:
with_clause = with_clause[0]
count_query = sql_query.replace(with_clause, '')
else:
with_clause = None
count_query = sql_query
count = f"""SELECT COUNT(DISTINCT {" || '|' || ".join(distinct[0])}) FROM""" if distinct else "SELECT COUNT(*) FROM"
count_query = re.sub(r"SELECT [^*]+? FROM", count, count_query, count=1, flags=re.DOTALL)
count_query = re.sub(r"(OFFSET|LIMIT) [0-9-]+", '', count_query)
count_query = re.sub(r"GROUP BY .*", '', count_query)
if with_clause:
count_query = with_clause + ' ' + count_query
count = CNX.execute(count_query).fetchone()
count = count[0] if count is not None else 0
if limit is not None:
if "LIMIT " in sql_query:
sql_query = re.sub(r"LIMIT [0-9-]+", f"LIMIT {limit}", sql_query)
else:
sql_query += f" LIMIT {limit}"
if offset is not None:
if "OFFSET " in sql_query:
sql_query = re.sub(r"OFFSET [0-9-]+", f"OFFSET {offset}", sql_query)
else:
sql_query += f" OFFSET {offset}"
rows = CNX.execute(sql_query)
data = (f'''{{"filters":[{','.join(f'{{"filter":{jdmp(str(f))}}}' for f in filters)}],'''
f'"total":{count},"offset":{offset or 0},"limit":{jdmp(limit)},'
f'"data":[')
first, first_, cur, last = True, True, None, None
for r in rows or []:
cur = tuple([r[i] for i in distinct[1]]) if distinct else None
if not distinct or cur != last:
first_ = True
if first:
first = False
else:
if distinct and sub_fmt:
data += '\n ]}'
data += ','
data += f'\n ' + fmt(r)
if distinct and sub_fmt:
if first_:
data += '['
first_ = False
else:
data += ','
data += f'\n ' + sub_fmt(r)
last = cur
if distinct and sub_fmt and not first:
data += '\n ]}'
data += '\n]}\n'
self.send(data)
def do_GET_delivery_schedules(self, filters: list[Filter],
offset: int = None, limit: int = None,
order: str = None) -> None:
clauses = get_delivery_schedule_filter_clauses(filters)
sql = f"""
WITH announcements
AS (SELECT year, dsnr, SUM(weight) AS weight
FROM delivery_announcement
GROUP BY year, dsnr)
SELECT s.year, s.dsnr, s.date, s.description, s.max_weight, s.cancelled,
COALESCE(a.weight, 0) AS announced_weight,
COALESCE(SUM(p.weight), 0) AS delivered_weight,
STRFTIME('%Y-%m-%dT%H:%M:%SZ', DATETIME(s.ancmt_from, 'unixepoch')),
STRFTIME('%Y-%m-%dT%H:%M:%SZ', DATETIME(s.ancmt_to, 'unixepoch')),
b.zwstid, b.name,
s.attrid, s.cultid
FROM delivery_schedule s
LEFT JOIN branch b ON b.zwstid = s.zwstid
LEFT JOIN announcements a ON (a.year, a.dsnr) = (s.year, s.dsnr)
LEFT JOIN delivery_schedule_wine_variety v ON (v.year, v.dsnr) = (s.year, s.dsnr)
LEFT JOIN delivery d ON (d.date, d.zwstid) = (s.date, s.zwstid)
LEFT JOIN delivery_part p ON (p.year, p.did) = (d.year, d.did) AND p.sortid = v.sortid
"""
if len(clauses) > 0:
sql += f" WHERE {' AND '.join(clauses)}"
sql += " GROUP BY s.year, s.dsnr"
sql += " ORDER BY s.year, s.date, s.zwstid, s.description, s.dsnr"
rows1 = CNX.execute("""
SELECT date, zwstid, cultid, SUM(weight)
FROM delivery d
JOIN delivery_part p ON (p.year, p.did) = (d.year, d.did)
WHERE (d.date, d.zwstid, COALESCE(p.cultid, '')) IN
(SELECT date, zwstid, COALESCE(cultid, '') FROM delivery_schedule GROUP BY date, zwstid, cultid HAVING COUNT(*) = 1)
GROUP BY date, zwstid, cultid
""")
days1 = {(r[0], r[1], r[2]): r[3] for r in rows1}
rows2 = CNX.execute("""
SELECT date, zwstid, attrid, SUM(weight)
FROM delivery d
JOIN delivery_part p ON (p.year, p.did) = (d.year, d.did)
WHERE (d.date, d.zwstid, COALESCE(p.attrid, '')) IN
(SELECT date, zwstid, COALESCE(attrid, '') FROM delivery_schedule GROUP BY date, zwstid, attrid HAVING COUNT(*) = 1)
GROUP BY date, zwstid, attrid
""")
days2 = {(r[0], r[1], r[2]): r[3] for r in rows2}
self.exec_collection(
sql,
lambda r: f'{{"year":{r[0]:4},"dsnr":{r[1]:2},"date":"{r[2]}",'
f'"branch":{{"zwstid":{jdmp(r[10])},"name":{jdmp(r[11]):20}}},'
f'"description":{jdmp(r[3]):50},'
f'"max_weight":{jdmp(r[4]):>6},'
f'"is_cancelled":{jdmp(r[5], is_bool=True)},'
f'"announced_weight":{r[6]:6},'
f'"delivered_weight":{days1.get((r[2], r[10], r[13]), days2.get((r[2], r[10], r[12]), r[7] or 0)):6},'
f'"announcement_from":{jdmp(r[8])},'
f'"announcement_to":{jdmp(r[9])}}}',
filters, offset, limit, distinct=(['s.year', 's.dsnr'], [1, 2]))
def do_HEAD(self) -> None:
self.do_GET()
def do_OPTIONS(self) -> None:
self.send('')
def do_GET(self) -> None:
try:
if self.path == '/':
openapi_json = f'https://{self.headers.get("Host", "localhost")}/elwig/api/v1/openapi.json'
self.see_other(f'https://validator.swagger.io/?url={openapi_json}')
return
elif self.path == '/openapi.json':
self.send(OPEN_API_DOC)
return
username, role, auth_method = self.authorize()
parts = self.path.split('?', 1)
if len(parts) == 1:
path, query = parts[0], {}
else:
path, query = parts[0], {urllib.parse.unquote(s[0]): urllib.parse.unquote(s[-1])
for s in [p.split('=', 1) for p in parts[1].split('&')]}
filters = [Filter.from_str(f) for f in (query['filters'].split(',') if 'filters' in query else [])]
try:
offset = int(query['offset']) if 'offset' in query else None
limit = int(query['limit']) if 'limit' in query else None
except ValueError:
raise BadRequestError('Invalid integer value in query')
order = query['order'] if 'order' in query else None
if path == '/auth':
if auth_method == 'Bearer':
raise ForbiddenError('Tokens must not be renewed')
self.send( f'{{"token":{jdmp(ElwigApi.issue_jwt(username, role))}}}\n')
elif path == '/wine/varieties':
self.exec_collection(
"SELECT sortid, type, name, comment FROM wine_variety",
lambda r: f'{{"sortid":{jdmp(r[0])},"type":{jdmp(r[1])},"name":{jdmp(r[2])},"comment":{jdmp(r[3])}}}',
[], offset, limit)
elif path == '/wine/quality_levels':
self.exec_collection(
"SELECT qualid, name, min_kmw, predicate FROM wine_quality_level",
lambda r: f'{{"qualid":{jdmp(r[0])},"name":{jdmp(r[1]):22},"min_kmw":{jdmp(r[2])},"is_predicate":{jdmp(r[3], is_bool=True)}}}',
[], offset, limit)
elif path == '/wine/attributes':
self.exec_collection(
"SELECT attrid, name FROM wine_attribute",
lambda r: f'{{"attrid":{jdmp(r[0]):4},"name":{jdmp(r[1])}}}',
[], offset, limit)
elif path == '/wine/cultivations':
self.exec_collection(
"SELECT cultid, name, description FROM wine_cultivation",
lambda r: f'{{"cultid":{jdmp(r[0]):5},"name":{jdmp(r[1])},"description":{jdmp(r[2])}}}',
[], offset, limit)
elif path == '/modifiers':
self.exec_collection(
"SELECT year, modid, name, ordering FROM modifier",
lambda r: f'{{"year":{jdmp(r[0])},"modid":{jdmp(r[1]):5},"name":{jdmp(r[2]):18},"ordering":{jdmp(r[3])}}}',
[], offset, limit)
elif path == '/delivery_schedules':
self.do_GET_delivery_schedules(filters, offset, limit, order)
else:
raise NotFoundError('Invalid path')
except HttpError as e:
self.error(e.status_code, str(e))
except Exception as e:
traceback.print_exception(e)
self.error(500, str(e))
def do_POST(self) -> None:
try:
parts = self.path.split('?', 1)
path = parts[0]
if path == '/auth':
content_type = self.headers.get('Content-Type')
content_len = self.headers.get('Content-Length')
if content_len is None:
raise HttpError(411, 'Length required')
content_len = int(content_len)
if content_len > 4096 or content_len < 0:
raise HttpError(413, 'Content too large')
elif content_type == 'application/x-www-form-urlencoded':
payload = self.rfile.read(content_len)
try:
data = {urllib.parse.unquote(s[0]): urllib.parse.unquote(s[-1])
for s in [p.split(b'=', 1) for p in payload.split(b'&')]}
username, password = data['username'], data['password']
except Exception:
raise BadRequestError('Invalid URL encoded payload')
elif content_type == 'application/json':
payload = self.rfile.read(content_len)
try:
data = json.loads(payload.decode('utf-8'))
username, password = data['username'], data['password']
except Exception:
raise BadRequestError('Invalid JSON object')
else:
raise HttpError(415, 'Unsupported media type')
username, role = check_user_password(username, password)
self.send(f'{{"token":{jdmp(ElwigApi.issue_jwt(username, role))}}}\n')
elif path in ('/wine/varieties', '/wine/quality_levels', '/wine/attributes', '/wine/cultivations', '/modifiers', '/delivery_schedules'):
raise MethodNotAllowedError()
else:
raise NotFoundError('Invalid path')
except HttpError as e:
self.error(e.status_code, str(e))
except Exception as e:
traceback.print_exception(e)
self.error(500, str(e))
def main() -> None:
global CNX
global USER_FILE
global JWT_ISSUER
global JWT_SECRET
global JWT_INVALIDATE_BEFORE
sqlite3.register_adapter(datetime.date, lambda d: d.strftime('%Y-%m-%d'))
sqlite3.register_adapter(datetime.time, lambda t: t.strftime('%H:%M:%S'))
parser = argparse.ArgumentParser()
parser.add_argument('db', type=str, metavar='DB')
parser.add_argument('jwt_file', type=str, metavar='JWT-FILE')
parser.add_argument('user_file', type=str, metavar='USER-FILE')
parser.add_argument('-p', '--port', type=int, default=8080)
args = parser.parse_args()
jwt_file = args.jwt_file
with open(jwt_file, 'rb') as file:
iss, dt, JWT_SECRET = file.readline().strip().split(b':', 2)
JWT_ISSUER = iss.decode('utf-8')
JWT_INVALIDATE_BEFORE = int(datetime.datetime.fromisoformat(dt.decode('ascii')).timestamp()) if dt else 0
USER_FILE = args.user_file
with open(USER_FILE, 'r') as file:
for line in file:
(u, r, i, p) = line.strip().split(':', 3)
JWT_USER_INVALIDATE_BEFORE[u.strip()] = int(datetime.datetime.fromisoformat(i).timestamp()) if i else 0
CNX = sqlite3.connect(f'file:{args.db}?mode=ro', uri=True)
CNX.create_function('REGEXP', 2, sqlite_regexp, deterministic=True)
server = HTTPServer(('localhost', args.port), ElwigApi)
print(f'Listening on http://localhost:{args.port}')
try:
server.serve_forever()
except InterruptedError | KeyboardInterrupt:
print()
server.server_close()
print('Good bye!')
OPEN_API_DOC: str = '''{ OPEN_API_DOC: str = '''{
"openapi": "3.1.0", "openapi": "3.1.0",
@ -33,9 +532,65 @@ OPEN_API_DOC: str = '''{
"url": "https://wgm.elwig.at/elwig/api/v1", "url": "https://wgm.elwig.at/elwig/api/v1",
"description": "WG Matzen" "description": "WG Matzen"
}], }],
"components": {"securitySchemes": {"basicAuth": {"type": "http", "scheme": "basic"}}}, "components": {
"security": [{"basicAuth": []}], "securitySchemes": {
"basicAuth": {"type": "http", "scheme": "basic"},
"bearerAuth": {"type": "http", "scheme": "bearer", "bearerFormat": "JWT"}
}
},
"security": [{"basicAuth": []}, {"bearerAuth": []}],
"paths": { "paths": {
"/auth": {
"get": {
"tags": ["Authentication"],
"summary": "Authentication",
"security": [{"basicAuth": []}],
"responses": {
"200": {
"description": "Success",
"content": {
"application/json": {
"schema": {"type": "object", "required": ["token"], "properties": {"token": {"type": "string"}}},
"examples": {"simple": {"value": {"token": "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJleGFtcGxlLmNvbSIsInN1YiI6InRlc3QiLCJyb2wiOiJleHRlcm5hbCIsImlhdCI6MTc0NTgzMTQzN30.VMLz20aWI8nSd7ocT2W750Cy80OJs8OMiQCtq5Df0rE"}}}
}
}
},
"400": {"description": "Bad Request", "content": {"application/json": {"schema": {"type": "object", "required": ["message"], "properties": {"message": {"oneOf": [{"type": "string"}, {"type": "null"}]}}}, "examples": {"simple": {"value": {"message": "Invalid Authorization header"}}}}}},
"401": {"description": "Unauthorized", "content": {"application/json": {"schema": {"type": "object", "required": ["message"], "properties": {"message": {"oneOf": [{"type": "string"}, {"type": "null"}]}}}, "examples": {"simple": {"value": {"message": "Unauthorized"}}}}}},
"500": {"description": "Internal Server Error", "content": {"application/json": {"schema": {"type": "object", "required": ["message"], "properties": {"message": {"oneOf": [{"type": "string"}, {"type": "null"}]}}}, "examples": {"simple": {"value": {"message": "Unknown error"}}}}}}
}
},
"post": {
"tags": ["Authentication"],
"summary": "Authentication",
"security": [],
"requestBody": {
"required": true,
"content": {
"application/json": {
"schema": {"type": "object", "required": ["username", "password"], "properties": {"username": {"type": "string"}, "password": {"type": "string"}}}
},
"application/x-www-form-urlencoded": {
"schema": {"type": "object", "required": ["username", "password"], "properties": {"username": {"type": "string"}, "password": {"type": "string"}}}
}
}
},
"responses": {
"200": {
"description": "Success",
"content": {
"application/json": {
"schema": {"type": "object", "required": ["token"], "properties": {"token": {"type": "string"}}},
"examples": {"simple": {"value": {"token": "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJleGFtcGxlLmNvbSIsInN1YiI6InRlc3QiLCJyb2wiOiJleHRlcm5hbCIsImlhdCI6MTc0NTgzMTQzN30.VMLz20aWI8nSd7ocT2W750Cy80OJs8OMiQCtq5Df0rE"}}}
}
}
},
"400": {"description": "Bad Request", "content": {"application/json": {"schema": {"type": "object", "required": ["message"], "properties": {"message": {"oneOf": [{"type": "string"}, {"type": "null"}]}}}, "examples": {"simple": {"value": {"message": "Invalid Authorization header"}}}}}},
"401": {"description": "Unauthorized", "content": {"application/json": {"schema": {"type": "object", "required": ["message"], "properties": {"message": {"oneOf": [{"type": "string"}, {"type": "null"}]}}}, "examples": {"simple": {"value": {"message": "Unauthorized"}}}}}},
"500": {"description": "Internal Server Error", "content": {"application/json": {"schema": {"type": "object", "required": ["message"], "properties": {"message": {"oneOf": [{"type": "string"}, {"type": "null"}]}}}, "examples": {"simple": {"value": {"message": "Unknown error"}}}}}}
}
}
},
"/wine/varieties": { "/wine/varieties": {
"get": { "get": {
"tags": ["Base Data"], "tags": ["Base Data"],
@ -413,353 +968,6 @@ OPEN_API_DOC: str = '''{
'''.replace('[VERSION]', VERSION) '''.replace('[VERSION]', VERSION)
class BadRequestError(Exception):
pass
class Filter:
def __init__(self, name: str, values: list[int] or list[str] = None):
self.name = name
self.values = values
def is_int(self) -> bool:
return type(self.values[0]) is int
def is_str(self) -> bool:
return type(self.values[0]) is str
def is_single(self) -> bool:
return self.values is None
def to_sql_list(self) -> str:
if self.is_int():
return ', '.join(str(v) for v in self.values)
else:
return ', '.join(f"'{v}'" for v in self.values)
def __repr__(self) -> str:
if self.is_single():
return self.name
elif self.name == 'kgnr':
return f'kgnr={";".join(f"{v:05}" for v in self.values)}'
return f'{self.name}={";".join(str(v) for v in self.values)}'
def __str__(self) -> str:
return self.__repr__()
def __eq__(self, other) -> bool:
return self.__repr__() == other.__repr__()
@staticmethod
def from_str(string: str) -> Filter:
f = string.split('=', 1)
if len(f) == 2:
ps = f[1].split(';')
is_digit = all(p.isdigit() for p in ps)
return Filter(f[0], [int(p) for p in ps] if is_digit else ps)
return Filter(f[0])
def sqlite_regexp(pattern: str, value: Optional[str]) -> Optional[bool]:
return re.match(pattern, value) is not None if value is not None else None
def kmw_to_oe(kmw: float) -> float:
return kmw * (4.54 + 0.022 * kmw) if kmw is not None else None
def jdmp(value, is_bool: bool = False) -> str:
if is_bool and value:
return ' true'
elif is_bool and not value:
return 'false'
return json.dumps(value, ensure_ascii=False)
def get_delivery_schedule_filter_clauses(filters: list[Filter]) -> list[str]:
clauses = []
for f in filters:
if f.name == 'year' and f.is_int():
clauses.append(f"s.year IN ({f.to_sql_list()})")
elif f.name == 'sortid' and f.is_str() and all(len(v) == 2 and v.isalpha() and v.isupper() for v in f.values):
clauses.append(f"v.sortid IN ({f.to_sql_list()})")
elif f.name == 'date' and f.is_str() and all(re.match(r'[0-9]{4}-[0-9]{2}-[0-9]{2}', v) is not None for v in f.values):
clauses.append(f"s.date IN ({f.to_sql_list()})")
else:
raise BadRequestError(f"Invalid filter '{f}'")
return clauses
class ElwigApi(BaseHTTPRequestHandler):
def send(self, data: str, status_code: int = 200, url: str = None):
raw = data.encode('utf-8')
self.send_response(status_code)
self.send_header('Access-Control-Allow-Origin', '*')
self.send_header('Access-Control-Allow-Headers', 'Authorization')
self.send_header('Access-Control-Allow-Methods', 'HEAD, GET, OPTIONS')
if 300 <= status_code < 400 and status_code != 304:
self.send_header('Location', url)
elif status_code == 401:
self.send_header('WWW-Authenticate', 'Basic realm=Elwig')
if self.headers.get('Accept-Encoding') and len(data) > 64:
accept_encoding = [e.strip() for e in self.headers.get('Accept-Encoding').split(',')]
if 'gzip' in accept_encoding:
raw = gzip.compress(raw)
self.send_header('Content-Encoding', 'gzip')
self.send_header('Content-Type', 'application/json; charset=UTF-8')
self.send_header('Content-Length', str(len(raw)))
self.end_headers()
if self.request.type != 'HEAD' and self.request.type != 'OPTIONS':
self.wfile.write(raw)
def error(self, status_code: int, message: str = None):
self.send(f'{{"message":{jdmp(message)}}}\n', status_code=status_code)
def see_other(self, url: str):
self.send(f'{{"url": {jdmp(url)}}}\n', status_code=303, url=url)
def authorize(self) -> (str or None, str or None):
try:
auth = self.headers.get('Authorization')
if auth is None or not auth.startswith('Basic '):
self.error(401, 'Unauthorized')
return None, None
auth = base64.b64decode(auth[6:]).split(b':', 1)
if len(auth) != 2:
self.error(401, 'Invalid Authorization header')
return None, None
username, password = auth[0].decode('utf-8'), auth[1].decode('utf-8')
except:
self.error(401, 'Invalid Authorization header')
return None, None
with open(USER_FILE, 'r') as file:
for line in file:
(u, r, p) = line.strip().split(':', 2)
if u == username:
if p == password:
return u, r
else:
self.error(401, 'Unauthorized')
return None, None
self.error(401, 'Unauthorized')
return None, None
def exec_collection(self, sql_query: str, fmt: Callable, filters: list[Filter],
offset: int = None, limit: int = None,
distinct: tuple[[str], [int]] = None, sub_fmt: Callable = None):
with_clause = re.findall(r'(WITH .*?\))[\s\n]*SELECT', sql_query, flags=re.DOTALL)
if len(with_clause) > 0 and '.*' not in with_clause[0]:
with_clause = with_clause[0]
count_query = sql_query.replace(with_clause, '')
else:
with_clause = None
count_query = sql_query
count = f"""SELECT COUNT(DISTINCT {" || '|' || ".join(distinct[0])}) FROM""" if distinct else "SELECT COUNT(*) FROM"
count_query = re.sub(r"SELECT [^*]+? FROM", count, count_query, count=1, flags=re.DOTALL)
count_query = re.sub(r"(OFFSET|LIMIT) [0-9-]+", '', count_query)
count_query = re.sub(r"GROUP BY .*", '', count_query)
if with_clause:
count_query = with_clause + ' ' + count_query
count = CNX.execute(count_query).fetchone()
count = count[0] if count is not None else 0
if limit is not None:
if "LIMIT " in sql_query:
sql_query = re.sub(r"LIMIT [0-9-]+", f"LIMIT {limit}", sql_query)
else:
sql_query += f" LIMIT {limit}"
if offset is not None:
if "OFFSET " in sql_query:
sql_query = re.sub(r"OFFSET [0-9-]+", f"OFFSET {offset}", sql_query)
else:
sql_query += f" OFFSET {offset}"
rows = CNX.execute(sql_query)
data = (f'''{{"filters":[{','.join(f'{{"filter":{jdmp(str(f))}}}' for f in filters)}],'''
f'"total":{count},"offset":{offset or 0},"limit":{jdmp(limit)},'
f'"data":[')
first, first_, cur, last = True, True, None, None
for r in rows or []:
cur = tuple([r[i] for i in distinct[1]]) if distinct else None
if not distinct or cur != last:
first_ = True
if first:
first = False
else:
if distinct and sub_fmt:
data += '\n ]}'
data += ','
data += f'\n ' + fmt(r)
if distinct and sub_fmt:
if first_:
data += '['
first_ = False
else:
data += ','
data += f'\n ' + sub_fmt(r)
last = cur
if distinct and sub_fmt and not first:
data += '\n ]}'
data += '\n]}\n'
self.send(data)
def do_GET_delivery_schedules(self, filters: list[Filter], offset: int = None, limit: int = None, order: str = None):
clauses = get_delivery_schedule_filter_clauses(filters)
sql = f"""
WITH announcements
AS (SELECT year, dsnr, SUM(weight) AS weight
FROM delivery_announcement
GROUP BY year, dsnr)
SELECT s.year, s.dsnr, s.date, s.description, s.max_weight, s.cancelled,
COALESCE(a.weight, 0) AS announced_weight,
COALESCE(SUM(p.weight), 0) AS delivered_weight,
STRFTIME('%Y-%m-%dT%H:%M:%SZ', DATETIME(s.ancmt_from, 'unixepoch')),
STRFTIME('%Y-%m-%dT%H:%M:%SZ', DATETIME(s.ancmt_to, 'unixepoch')),
b.zwstid, b.name,
s.attrid, s.cultid
FROM delivery_schedule s
LEFT JOIN branch b ON b.zwstid = s.zwstid
LEFT JOIN announcements a ON (a.year, a.dsnr) = (s.year, s.dsnr)
LEFT JOIN delivery_schedule_wine_variety v ON (v.year, v.dsnr) = (s.year, s.dsnr)
LEFT JOIN delivery d ON (d.date, d.zwstid) = (s.date, s.zwstid)
LEFT JOIN delivery_part p ON (p.year, p.did) = (d.year, d.did) AND p.sortid = v.sortid
"""
if len(clauses) > 0:
sql += f" WHERE {' AND '.join(clauses)}"
sql += " GROUP BY s.year, s.dsnr"
sql += " ORDER BY s.year, s.date, s.zwstid, s.description, s.dsnr"
rows1 = CNX.execute("""
SELECT date, zwstid, cultid, SUM(weight)
FROM delivery d
JOIN delivery_part p ON (p.year, p.did) = (d.year, d.did)
WHERE (d.date, d.zwstid, COALESCE(p.cultid, '')) IN
(SELECT date, zwstid, COALESCE(cultid, '') FROM delivery_schedule GROUP BY date, zwstid, cultid HAVING COUNT(*) = 1)
GROUP BY date, zwstid, cultid
""")
days1 = {(r[0], r[1], r[2]): r[3] for r in rows1}
rows2 = CNX.execute("""
SELECT date, zwstid, attrid, SUM(weight)
FROM delivery d
JOIN delivery_part p ON (p.year, p.did) = (d.year, d.did)
WHERE (d.date, d.zwstid, COALESCE(p.attrid, '')) IN
(SELECT date, zwstid, COALESCE(attrid, '') FROM delivery_schedule GROUP BY date, zwstid, attrid HAVING COUNT(*) = 1)
GROUP BY date, zwstid, attrid
""")
days2 = {(r[0], r[1], r[2]): r[3] for r in rows2}
self.exec_collection(
sql,
lambda r: f'{{"year":{r[0]:4},"dsnr":{r[1]:2},"date":"{r[2]}",'
f'"branch":{{"zwstid":{jdmp(r[10])},"name":{jdmp(r[11]):20}}},'
f'"description":{jdmp(r[3]):50},'
f'"max_weight":{jdmp(r[4]):>6},'
f'"is_cancelled":{jdmp(r[5], is_bool=True)},'
f'"announced_weight":{r[6]:6},'
f'"delivered_weight":{days1.get((r[2], r[10], r[13]), days2.get((r[2], r[10], r[12]), r[7] or 0)):6},'
f'"announcement_from":{jdmp(r[8])},'
f'"announcement_to":{jdmp(r[9])}}}',
filters, offset, limit, distinct=(['s.year', 's.dsnr'], [1, 2]))
def do_HEAD(self):
self.do_GET()
def do_OPTIONS(self):
self.send('')
def do_GET(self):
try:
if self.path == '/':
openapi_json = f'https://{self.headers.get("Host", "localhost")}/elwig/api/v1/openapi.json'
self.see_other(f'https://validator.swagger.io/?url={openapi_json}')
return
elif self.path == '/openapi.json':
self.send(OPEN_API_DOC)
return
username, role = self.authorize()
if not username or not role:
return
parts = self.path.split('?', 1)
if len(parts) == 1:
path, query = parts[0], {}
else:
path, query = parts[0], {urllib.parse.unquote(s[0]): urllib.parse.unquote(s[-1])
for s in [p.split('=', 1) for p in parts[1].split('&')]}
filters = [Filter.from_str(f) for f in (query['filters'].split(',') if 'filters' in query else [])]
try:
offset = int(query['offset']) if 'offset' in query else None
limit = int(query['limit']) if 'limit' in query else None
except ValueError:
raise BadRequestError('Invalid integer value in query')
order = query['order'] if 'order' in query else None
if path == '/wine/varieties':
self.exec_collection(
"SELECT sortid, type, name, comment FROM wine_variety",
lambda r: f'{{"sortid":{jdmp(r[0])},"type":{jdmp(r[1])},"name":{jdmp(r[2])},"comment":{jdmp(r[3])}}}',
[], offset, limit)
elif path == '/wine/quality_levels':
self.exec_collection(
"SELECT qualid, name, min_kmw, predicate FROM wine_quality_level",
lambda r: f'{{"qualid":{jdmp(r[0])},"name":{jdmp(r[1]):22},"min_kmw":{jdmp(r[2])},"is_predicate":{jdmp(r[3], is_bool=True)}}}',
[], offset, limit)
elif path == '/wine/attributes':
self.exec_collection(
"SELECT attrid, name FROM wine_attribute",
lambda r: f'{{"attrid":{jdmp(r[0]):4},"name":{jdmp(r[1])}}}',
[], offset, limit)
elif path == '/wine/cultivations':
self.exec_collection(
"SELECT cultid, name, description FROM wine_cultivation",
lambda r: f'{{"cultid":{jdmp(r[0]):5},"name":{jdmp(r[1])},"description":{jdmp(r[2])}}}',
[], offset, limit)
elif path == '/modifiers':
self.exec_collection(
"SELECT year, modid, name, ordering FROM modifier",
lambda r: f'{{"year":{jdmp(r[0])},"modid":{jdmp(r[1]):5},"name":{jdmp(r[2]):18},"ordering":{jdmp(r[3])}}}',
[], offset, limit)
elif path == '/delivery_schedules':
self.do_GET_delivery_schedules(filters, offset, limit, order)
else:
self.error(404, 'Invalid path')
except BadRequestError as e:
self.error(400, str(e))
except Exception as e:
traceback.print_exception(e)
self.error(500, str(e))
def main() -> None:
global CNX
global USER_FILE
sqlite3.register_adapter(datetime.date, lambda d: d.strftime('%Y-%m-%d'))
sqlite3.register_adapter(datetime.time, lambda t: t.strftime('%H:%M:%S'))
parser = argparse.ArgumentParser()
parser.add_argument('db', type=str, metavar='DB')
parser.add_argument('user_file', type=str, metavar='USER-FILE')
parser.add_argument('-p', '--port', type=int, default=8080)
args = parser.parse_args()
USER_FILE = args.user_file
CNX = sqlite3.connect(f'file:{args.db}?mode=ro', uri=True)
CNX.create_function('REGEXP', 2, sqlite_regexp, deterministic=True)
server = HTTPServer(('localhost', args.port), ElwigApi)
print(f'Listening on http://localhost:{args.port}')
try:
server.serve_forever()
except InterruptedError:
print()
except KeyboardInterrupt:
print()
server.server_close()
print('Good bye!')
if __name__ == '__main__': if __name__ == '__main__':
main() main()