elwig-backend: Minor refactoring

This commit is contained in:
2025-04-28 01:37:17 +02:00
parent aafea96faf
commit bc17b842a4

View File

@ -21,8 +21,32 @@ CNX: sqlite3.Cursor
USER_FILE: str USER_FILE: str
class BadRequestError(Exception): class HttpError(Exception):
pass status_code: int
def __init__(self, status_code: int, message: str):
self.status_code = status_code
super().__init__(message)
class BadRequestError(HttpError):
def __init__(self, message: str = 'Bad request'):
super().__init__(400, message)
class UnauthorizedError(HttpError):
def __init__(self, message: str = 'Unauthorized'):
super().__init__(401, message)
class NotFoundError(HttpError):
def __init__(self, message: str = 'Not found'):
super().__init__(404, message)
class MethodNotAllowedError(HttpError):
def __init__(self, message: str = 'Method not allowed'):
super().__init__(405, message)
class Filter: class Filter:
@ -99,17 +123,17 @@ def get_delivery_schedule_filter_clauses(filters: list[Filter]) -> list[str]:
class ElwigApi(BaseHTTPRequestHandler): class ElwigApi(BaseHTTPRequestHandler):
def send(self, data: str, status_code: int = 200, url: str = None): def send(self, data: str, status_code: int = 200, url: str = None) -> None:
raw = data.encode('utf-8') raw = data.encode('utf-8')
self.send_response(status_code) self.send_response(status_code)
self.send_header('Access-Control-Allow-Origin', '*') self.send_header('Access-Control-Allow-Origin', '*')
self.send_header('Access-Control-Allow-Headers', 'Authorization') self.send_header('Access-Control-Allow-Headers', 'Authorization')
self.send_header('Access-Control-Allow-Methods', 'HEAD, GET, OPTIONS') self.send_header('Access-Control-Allow-Methods', 'HEAD, GET, OPTIONS')
if 300 <= status_code < 400 and status_code != 304: if 300 <= status_code < 400 and status_code != 304 and url:
self.send_header('Location', url) self.send_header('Location', url)
elif status_code == 401: elif status_code == 401:
self.send_header('WWW-Authenticate', 'Basic realm=Elwig') self.send_header('WWW-Authenticate', 'Basic realm=Elwig')
if self.headers.get('Accept-Encoding') and len(data) > 64: if 'Accept-Encoding' in self.headers and len(data) > 64:
accept_encoding = [e.strip() for e in self.headers.get('Accept-Encoding').split(',')] accept_encoding = [e.strip() for e in self.headers.get('Accept-Encoding').split(',')]
if 'gzip' in accept_encoding: if 'gzip' in accept_encoding:
raw = gzip.compress(raw) raw = gzip.compress(raw)
@ -120,41 +144,34 @@ class ElwigApi(BaseHTTPRequestHandler):
if self.request.type != 'HEAD' and self.request.type != 'OPTIONS': if self.request.type != 'HEAD' and self.request.type != 'OPTIONS':
self.wfile.write(raw) self.wfile.write(raw)
def error(self, status_code: int, message: str = None): def error(self, status_code: int, message: str = None) -> None:
self.send(f'{{"message":{jdmp(message)}}}\n', status_code=status_code) self.send(f'{{"message":{jdmp(message)}}}\n', status_code=status_code)
def see_other(self, url: str): def see_other(self, url: str) -> None:
self.send(f'{{"url": {jdmp(url)}}}\n', status_code=303, url=url) self.send(f'{{"url": {jdmp(url)}}}\n', status_code=303, url=url)
def authorize(self) -> (str or None, str or None): def authorize(self) -> tuple[str, str]:
try: try:
auth = self.headers.get('Authorization') auth = self.headers.get('Authorization')
if auth is None or not auth.startswith('Basic '): if auth is None or not auth.startswith('Basic '):
self.error(401, 'Unauthorized') raise UnauthorizedError()
return None, None
auth = base64.b64decode(auth[6:]).split(b':', 1) auth = base64.b64decode(auth[6:]).split(b':', 1)
if len(auth) != 2: if len(auth) != 2:
self.error(401, 'Invalid Authorization header') raise UnauthorizedError('Invalid Authorization header')
return None, None
username, password = auth[0].decode('utf-8'), auth[1].decode('utf-8') username, password = auth[0].decode('utf-8'), auth[1].decode('utf-8')
except: except:
self.error(401, 'Invalid Authorization header') raise UnauthorizedError('Invalid Authorization header')
return None, None
with open(USER_FILE, 'r') as file: with open(USER_FILE, 'r') as file:
for line in file: for line in file:
(u, r, p) = line.strip().split(':', 2) (u, r, p) = line.strip().split(':', 2)
if u == username: if u == username:
if p == password: if p == password:
return u, r return u, r
else: raise UnauthorizedError()
self.error(401, 'Unauthorized')
return None, None
self.error(401, 'Unauthorized')
return None, None
def exec_collection(self, sql_query: str, fmt: Callable, filters: list[Filter], def exec_collection(self, sql_query: str, fmt: Callable, filters: list[Filter],
offset: int = None, limit: int = None, offset: int = None, limit: int = None,
distinct: tuple[[str], [int]] = None, sub_fmt: Callable = None): distinct: tuple[[str], [int]] = None, sub_fmt: Callable = None) -> None:
with_clause = re.findall(r'(WITH .*?\))[\s\n]*SELECT', sql_query, flags=re.DOTALL) with_clause = re.findall(r'(WITH .*?\))[\s\n]*SELECT', sql_query, flags=re.DOTALL)
if len(with_clause) > 0 and '.*' not in with_clause[0]: if len(with_clause) > 0 and '.*' not in with_clause[0]:
with_clause = with_clause[0] with_clause = with_clause[0]
@ -211,7 +228,9 @@ class ElwigApi(BaseHTTPRequestHandler):
data += '\n]}\n' data += '\n]}\n'
self.send(data) self.send(data)
def do_GET_delivery_schedules(self, filters: list[Filter], offset: int = None, limit: int = None, order: str = None): def do_GET_delivery_schedules(self, filters: list[Filter],
offset: int = None, limit: int = None,
order: str = None) -> None:
clauses = get_delivery_schedule_filter_clauses(filters) clauses = get_delivery_schedule_filter_clauses(filters)
sql = f""" sql = f"""
WITH announcements WITH announcements
@ -269,13 +288,13 @@ class ElwigApi(BaseHTTPRequestHandler):
f'"announcement_to":{jdmp(r[9])}}}', f'"announcement_to":{jdmp(r[9])}}}',
filters, offset, limit, distinct=(['s.year', 's.dsnr'], [1, 2])) filters, offset, limit, distinct=(['s.year', 's.dsnr'], [1, 2]))
def do_HEAD(self): def do_HEAD(self) -> None:
self.do_GET() self.do_GET()
def do_OPTIONS(self): def do_OPTIONS(self) -> None:
self.send('') self.send('')
def do_GET(self): def do_GET(self) -> None:
try: try:
if self.path == '/': if self.path == '/':
openapi_json = f'https://{self.headers.get("Host", "localhost")}/elwig/api/v1/openapi.json' openapi_json = f'https://{self.headers.get("Host", "localhost")}/elwig/api/v1/openapi.json'
@ -286,8 +305,6 @@ class ElwigApi(BaseHTTPRequestHandler):
return return
username, role = self.authorize() username, role = self.authorize()
if not username or not role:
return
parts = self.path.split('?', 1) parts = self.path.split('?', 1)
if len(parts) == 1: if len(parts) == 1:
@ -331,9 +348,9 @@ class ElwigApi(BaseHTTPRequestHandler):
elif path == '/delivery_schedules': elif path == '/delivery_schedules':
self.do_GET_delivery_schedules(filters, offset, limit, order) self.do_GET_delivery_schedules(filters, offset, limit, order)
else: else:
self.error(404, 'Invalid path') raise NotFoundError('Invalid path')
except BadRequestError as e: except HttpError as e:
self.error(400, str(e)) self.error(e.status_code, str(e))
except Exception as e: except Exception as e:
traceback.print_exception(e) traceback.print_exception(e)
self.error(500, str(e)) self.error(500, str(e))
@ -360,9 +377,7 @@ def main() -> None:
print(f'Listening on http://localhost:{args.port}') print(f'Listening on http://localhost:{args.port}')
try: try:
server.serve_forever() server.serve_forever()
except InterruptedError: except InterruptedError | KeyboardInterrupt:
print()
except KeyboardInterrupt:
print() print()
server.server_close() server.server_close()
print('Good bye!') print('Good bye!')