elwig-backend: Minor refactoring
This commit is contained in:
@ -21,8 +21,32 @@ CNX: sqlite3.Cursor
|
||||
USER_FILE: str
|
||||
|
||||
|
||||
class BadRequestError(Exception):
|
||||
pass
|
||||
class HttpError(Exception):
|
||||
status_code: int
|
||||
|
||||
def __init__(self, status_code: int, message: str):
|
||||
self.status_code = status_code
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BadRequestError(HttpError):
|
||||
def __init__(self, message: str = 'Bad request'):
|
||||
super().__init__(400, message)
|
||||
|
||||
|
||||
class UnauthorizedError(HttpError):
|
||||
def __init__(self, message: str = 'Unauthorized'):
|
||||
super().__init__(401, message)
|
||||
|
||||
|
||||
class NotFoundError(HttpError):
|
||||
def __init__(self, message: str = 'Not found'):
|
||||
super().__init__(404, message)
|
||||
|
||||
|
||||
class MethodNotAllowedError(HttpError):
|
||||
def __init__(self, message: str = 'Method not allowed'):
|
||||
super().__init__(405, message)
|
||||
|
||||
|
||||
class Filter:
|
||||
@ -99,17 +123,17 @@ def get_delivery_schedule_filter_clauses(filters: list[Filter]) -> list[str]:
|
||||
|
||||
|
||||
class ElwigApi(BaseHTTPRequestHandler):
|
||||
def send(self, data: str, status_code: int = 200, url: str = None):
|
||||
def send(self, data: str, status_code: int = 200, url: str = None) -> None:
|
||||
raw = data.encode('utf-8')
|
||||
self.send_response(status_code)
|
||||
self.send_header('Access-Control-Allow-Origin', '*')
|
||||
self.send_header('Access-Control-Allow-Headers', 'Authorization')
|
||||
self.send_header('Access-Control-Allow-Methods', 'HEAD, GET, OPTIONS')
|
||||
if 300 <= status_code < 400 and status_code != 304:
|
||||
if 300 <= status_code < 400 and status_code != 304 and url:
|
||||
self.send_header('Location', url)
|
||||
elif status_code == 401:
|
||||
self.send_header('WWW-Authenticate', 'Basic realm=Elwig')
|
||||
if self.headers.get('Accept-Encoding') and len(data) > 64:
|
||||
if 'Accept-Encoding' in self.headers and len(data) > 64:
|
||||
accept_encoding = [e.strip() for e in self.headers.get('Accept-Encoding').split(',')]
|
||||
if 'gzip' in accept_encoding:
|
||||
raw = gzip.compress(raw)
|
||||
@ -120,41 +144,34 @@ class ElwigApi(BaseHTTPRequestHandler):
|
||||
if self.request.type != 'HEAD' and self.request.type != 'OPTIONS':
|
||||
self.wfile.write(raw)
|
||||
|
||||
def error(self, status_code: int, message: str = None):
|
||||
def error(self, status_code: int, message: str = None) -> None:
|
||||
self.send(f'{{"message":{jdmp(message)}}}\n', status_code=status_code)
|
||||
|
||||
def see_other(self, url: str):
|
||||
def see_other(self, url: str) -> None:
|
||||
self.send(f'{{"url": {jdmp(url)}}}\n', status_code=303, url=url)
|
||||
|
||||
def authorize(self) -> (str or None, str or None):
|
||||
def authorize(self) -> tuple[str, str]:
|
||||
try:
|
||||
auth = self.headers.get('Authorization')
|
||||
if auth is None or not auth.startswith('Basic '):
|
||||
self.error(401, 'Unauthorized')
|
||||
return None, None
|
||||
raise UnauthorizedError()
|
||||
auth = base64.b64decode(auth[6:]).split(b':', 1)
|
||||
if len(auth) != 2:
|
||||
self.error(401, 'Invalid Authorization header')
|
||||
return None, None
|
||||
raise UnauthorizedError('Invalid Authorization header')
|
||||
username, password = auth[0].decode('utf-8'), auth[1].decode('utf-8')
|
||||
except:
|
||||
self.error(401, 'Invalid Authorization header')
|
||||
return None, None
|
||||
raise UnauthorizedError('Invalid Authorization header')
|
||||
with open(USER_FILE, 'r') as file:
|
||||
for line in file:
|
||||
(u, r, p) = line.strip().split(':', 2)
|
||||
if u == username:
|
||||
if p == password:
|
||||
return u, r
|
||||
else:
|
||||
self.error(401, 'Unauthorized')
|
||||
return None, None
|
||||
self.error(401, 'Unauthorized')
|
||||
return None, None
|
||||
raise UnauthorizedError()
|
||||
|
||||
def exec_collection(self, sql_query: str, fmt: Callable, filters: list[Filter],
|
||||
offset: int = None, limit: int = None,
|
||||
distinct: tuple[[str], [int]] = None, sub_fmt: Callable = None):
|
||||
distinct: tuple[[str], [int]] = None, sub_fmt: Callable = None) -> None:
|
||||
with_clause = re.findall(r'(WITH .*?\))[\s\n]*SELECT', sql_query, flags=re.DOTALL)
|
||||
if len(with_clause) > 0 and '.*' not in with_clause[0]:
|
||||
with_clause = with_clause[0]
|
||||
@ -211,7 +228,9 @@ class ElwigApi(BaseHTTPRequestHandler):
|
||||
data += '\n]}\n'
|
||||
self.send(data)
|
||||
|
||||
def do_GET_delivery_schedules(self, filters: list[Filter], offset: int = None, limit: int = None, order: str = None):
|
||||
def do_GET_delivery_schedules(self, filters: list[Filter],
|
||||
offset: int = None, limit: int = None,
|
||||
order: str = None) -> None:
|
||||
clauses = get_delivery_schedule_filter_clauses(filters)
|
||||
sql = f"""
|
||||
WITH announcements
|
||||
@ -269,13 +288,13 @@ class ElwigApi(BaseHTTPRequestHandler):
|
||||
f'"announcement_to":{jdmp(r[9])}}}',
|
||||
filters, offset, limit, distinct=(['s.year', 's.dsnr'], [1, 2]))
|
||||
|
||||
def do_HEAD(self):
|
||||
def do_HEAD(self) -> None:
|
||||
self.do_GET()
|
||||
|
||||
def do_OPTIONS(self):
|
||||
def do_OPTIONS(self) -> None:
|
||||
self.send('')
|
||||
|
||||
def do_GET(self):
|
||||
def do_GET(self) -> None:
|
||||
try:
|
||||
if self.path == '/':
|
||||
openapi_json = f'https://{self.headers.get("Host", "localhost")}/elwig/api/v1/openapi.json'
|
||||
@ -286,8 +305,6 @@ class ElwigApi(BaseHTTPRequestHandler):
|
||||
return
|
||||
|
||||
username, role = self.authorize()
|
||||
if not username or not role:
|
||||
return
|
||||
|
||||
parts = self.path.split('?', 1)
|
||||
if len(parts) == 1:
|
||||
@ -331,9 +348,9 @@ class ElwigApi(BaseHTTPRequestHandler):
|
||||
elif path == '/delivery_schedules':
|
||||
self.do_GET_delivery_schedules(filters, offset, limit, order)
|
||||
else:
|
||||
self.error(404, 'Invalid path')
|
||||
except BadRequestError as e:
|
||||
self.error(400, str(e))
|
||||
raise NotFoundError('Invalid path')
|
||||
except HttpError as e:
|
||||
self.error(e.status_code, str(e))
|
||||
except Exception as e:
|
||||
traceback.print_exception(e)
|
||||
self.error(500, str(e))
|
||||
@ -360,9 +377,7 @@ def main() -> None:
|
||||
print(f'Listening on http://localhost:{args.port}')
|
||||
try:
|
||||
server.serve_forever()
|
||||
except InterruptedError:
|
||||
print()
|
||||
except KeyboardInterrupt:
|
||||
except InterruptedError | KeyboardInterrupt:
|
||||
print()
|
||||
server.server_close()
|
||||
print('Good bye!')
|
||||
|
Reference in New Issue
Block a user