diff --git a/src/elwig-backend b/src/elwig-backend index e59d5a7..a138c32 100755 --- a/src/elwig-backend +++ b/src/elwig-backend @@ -21,8 +21,32 @@ CNX: sqlite3.Cursor USER_FILE: str -class BadRequestError(Exception): - pass +class HttpError(Exception): + status_code: int + + def __init__(self, status_code: int, message: str): + self.status_code = status_code + super().__init__(message) + + +class BadRequestError(HttpError): + def __init__(self, message: str = 'Bad request'): + super().__init__(400, message) + + +class UnauthorizedError(HttpError): + def __init__(self, message: str = 'Unauthorized'): + super().__init__(401, message) + + +class NotFoundError(HttpError): + def __init__(self, message: str = 'Not found'): + super().__init__(404, message) + + +class MethodNotAllowedError(HttpError): + def __init__(self, message: str = 'Method not allowed'): + super().__init__(405, message) class Filter: @@ -99,17 +123,17 @@ def get_delivery_schedule_filter_clauses(filters: list[Filter]) -> list[str]: class ElwigApi(BaseHTTPRequestHandler): - def send(self, data: str, status_code: int = 200, url: str = None): + def send(self, data: str, status_code: int = 200, url: str = None) -> None: raw = data.encode('utf-8') self.send_response(status_code) self.send_header('Access-Control-Allow-Origin', '*') self.send_header('Access-Control-Allow-Headers', 'Authorization') self.send_header('Access-Control-Allow-Methods', 'HEAD, GET, OPTIONS') - if 300 <= status_code < 400 and status_code != 304: + if 300 <= status_code < 400 and status_code != 304 and url: self.send_header('Location', url) elif status_code == 401: self.send_header('WWW-Authenticate', 'Basic realm=Elwig') - if self.headers.get('Accept-Encoding') and len(data) > 64: + if 'Accept-Encoding' in self.headers and len(data) > 64: accept_encoding = [e.strip() for e in self.headers.get('Accept-Encoding').split(',')] if 'gzip' in accept_encoding: raw = gzip.compress(raw) @@ -120,41 +144,34 @@ class ElwigApi(BaseHTTPRequestHandler): if self.request.type != 'HEAD' and self.request.type != 'OPTIONS': self.wfile.write(raw) - def error(self, status_code: int, message: str = None): + def error(self, status_code: int, message: str = None) -> None: self.send(f'{{"message":{jdmp(message)}}}\n', status_code=status_code) - def see_other(self, url: str): + def see_other(self, url: str) -> None: self.send(f'{{"url": {jdmp(url)}}}\n', status_code=303, url=url) - def authorize(self) -> (str or None, str or None): + def authorize(self) -> tuple[str, str]: try: auth = self.headers.get('Authorization') if auth is None or not auth.startswith('Basic '): - self.error(401, 'Unauthorized') - return None, None + raise UnauthorizedError() auth = base64.b64decode(auth[6:]).split(b':', 1) if len(auth) != 2: - self.error(401, 'Invalid Authorization header') - return None, None + raise UnauthorizedError('Invalid Authorization header') username, password = auth[0].decode('utf-8'), auth[1].decode('utf-8') except: - self.error(401, 'Invalid Authorization header') - return None, None + raise UnauthorizedError('Invalid Authorization header') with open(USER_FILE, 'r') as file: for line in file: (u, r, p) = line.strip().split(':', 2) if u == username: if p == password: return u, r - else: - self.error(401, 'Unauthorized') - return None, None - self.error(401, 'Unauthorized') - return None, None + raise UnauthorizedError() def exec_collection(self, sql_query: str, fmt: Callable, filters: list[Filter], offset: int = None, limit: int = None, - distinct: tuple[[str], [int]] = None, sub_fmt: Callable = None): + distinct: tuple[[str], [int]] = None, sub_fmt: Callable = None) -> None: with_clause = re.findall(r'(WITH .*?\))[\s\n]*SELECT', sql_query, flags=re.DOTALL) if len(with_clause) > 0 and '.*' not in with_clause[0]: with_clause = with_clause[0] @@ -211,7 +228,9 @@ class ElwigApi(BaseHTTPRequestHandler): data += '\n]}\n' self.send(data) - def do_GET_delivery_schedules(self, filters: list[Filter], offset: int = None, limit: int = None, order: str = None): + def do_GET_delivery_schedules(self, filters: list[Filter], + offset: int = None, limit: int = None, + order: str = None) -> None: clauses = get_delivery_schedule_filter_clauses(filters) sql = f""" WITH announcements @@ -269,13 +288,13 @@ class ElwigApi(BaseHTTPRequestHandler): f'"announcement_to":{jdmp(r[9])}}}', filters, offset, limit, distinct=(['s.year', 's.dsnr'], [1, 2])) - def do_HEAD(self): + def do_HEAD(self) -> None: self.do_GET() - def do_OPTIONS(self): + def do_OPTIONS(self) -> None: self.send('') - def do_GET(self): + def do_GET(self) -> None: try: if self.path == '/': openapi_json = f'https://{self.headers.get("Host", "localhost")}/elwig/api/v1/openapi.json' @@ -286,8 +305,6 @@ class ElwigApi(BaseHTTPRequestHandler): return username, role = self.authorize() - if not username or not role: - return parts = self.path.split('?', 1) if len(parts) == 1: @@ -331,9 +348,9 @@ class ElwigApi(BaseHTTPRequestHandler): elif path == '/delivery_schedules': self.do_GET_delivery_schedules(filters, offset, limit, order) else: - self.error(404, 'Invalid path') - except BadRequestError as e: - self.error(400, str(e)) + raise NotFoundError('Invalid path') + except HttpError as e: + self.error(e.status_code, str(e)) except Exception as e: traceback.print_exception(e) self.error(500, str(e)) @@ -360,9 +377,7 @@ def main() -> None: print(f'Listening on http://localhost:{args.port}') try: server.serve_forever() - except InterruptedError: - print() - except KeyboardInterrupt: + except InterruptedError | KeyboardInterrupt: print() server.server_close() print('Good bye!')