elwig-backend: Minor refactoring
This commit is contained in:
@ -21,8 +21,32 @@ CNX: sqlite3.Cursor
|
|||||||
USER_FILE: str
|
USER_FILE: str
|
||||||
|
|
||||||
|
|
||||||
class BadRequestError(Exception):
|
class HttpError(Exception):
|
||||||
pass
|
status_code: int
|
||||||
|
|
||||||
|
def __init__(self, status_code: int, message: str):
|
||||||
|
self.status_code = status_code
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class BadRequestError(HttpError):
|
||||||
|
def __init__(self, message: str = 'Bad request'):
|
||||||
|
super().__init__(400, message)
|
||||||
|
|
||||||
|
|
||||||
|
class UnauthorizedError(HttpError):
|
||||||
|
def __init__(self, message: str = 'Unauthorized'):
|
||||||
|
super().__init__(401, message)
|
||||||
|
|
||||||
|
|
||||||
|
class NotFoundError(HttpError):
|
||||||
|
def __init__(self, message: str = 'Not found'):
|
||||||
|
super().__init__(404, message)
|
||||||
|
|
||||||
|
|
||||||
|
class MethodNotAllowedError(HttpError):
|
||||||
|
def __init__(self, message: str = 'Method not allowed'):
|
||||||
|
super().__init__(405, message)
|
||||||
|
|
||||||
|
|
||||||
class Filter:
|
class Filter:
|
||||||
@ -99,17 +123,17 @@ def get_delivery_schedule_filter_clauses(filters: list[Filter]) -> list[str]:
|
|||||||
|
|
||||||
|
|
||||||
class ElwigApi(BaseHTTPRequestHandler):
|
class ElwigApi(BaseHTTPRequestHandler):
|
||||||
def send(self, data: str, status_code: int = 200, url: str = None):
|
def send(self, data: str, status_code: int = 200, url: str = None) -> None:
|
||||||
raw = data.encode('utf-8')
|
raw = data.encode('utf-8')
|
||||||
self.send_response(status_code)
|
self.send_response(status_code)
|
||||||
self.send_header('Access-Control-Allow-Origin', '*')
|
self.send_header('Access-Control-Allow-Origin', '*')
|
||||||
self.send_header('Access-Control-Allow-Headers', 'Authorization')
|
self.send_header('Access-Control-Allow-Headers', 'Authorization')
|
||||||
self.send_header('Access-Control-Allow-Methods', 'HEAD, GET, OPTIONS')
|
self.send_header('Access-Control-Allow-Methods', 'HEAD, GET, OPTIONS')
|
||||||
if 300 <= status_code < 400 and status_code != 304:
|
if 300 <= status_code < 400 and status_code != 304 and url:
|
||||||
self.send_header('Location', url)
|
self.send_header('Location', url)
|
||||||
elif status_code == 401:
|
elif status_code == 401:
|
||||||
self.send_header('WWW-Authenticate', 'Basic realm=Elwig')
|
self.send_header('WWW-Authenticate', 'Basic realm=Elwig')
|
||||||
if self.headers.get('Accept-Encoding') and len(data) > 64:
|
if 'Accept-Encoding' in self.headers and len(data) > 64:
|
||||||
accept_encoding = [e.strip() for e in self.headers.get('Accept-Encoding').split(',')]
|
accept_encoding = [e.strip() for e in self.headers.get('Accept-Encoding').split(',')]
|
||||||
if 'gzip' in accept_encoding:
|
if 'gzip' in accept_encoding:
|
||||||
raw = gzip.compress(raw)
|
raw = gzip.compress(raw)
|
||||||
@ -120,41 +144,34 @@ class ElwigApi(BaseHTTPRequestHandler):
|
|||||||
if self.request.type != 'HEAD' and self.request.type != 'OPTIONS':
|
if self.request.type != 'HEAD' and self.request.type != 'OPTIONS':
|
||||||
self.wfile.write(raw)
|
self.wfile.write(raw)
|
||||||
|
|
||||||
def error(self, status_code: int, message: str = None):
|
def error(self, status_code: int, message: str = None) -> None:
|
||||||
self.send(f'{{"message":{jdmp(message)}}}\n', status_code=status_code)
|
self.send(f'{{"message":{jdmp(message)}}}\n', status_code=status_code)
|
||||||
|
|
||||||
def see_other(self, url: str):
|
def see_other(self, url: str) -> None:
|
||||||
self.send(f'{{"url": {jdmp(url)}}}\n', status_code=303, url=url)
|
self.send(f'{{"url": {jdmp(url)}}}\n', status_code=303, url=url)
|
||||||
|
|
||||||
def authorize(self) -> (str or None, str or None):
|
def authorize(self) -> tuple[str, str]:
|
||||||
try:
|
try:
|
||||||
auth = self.headers.get('Authorization')
|
auth = self.headers.get('Authorization')
|
||||||
if auth is None or not auth.startswith('Basic '):
|
if auth is None or not auth.startswith('Basic '):
|
||||||
self.error(401, 'Unauthorized')
|
raise UnauthorizedError()
|
||||||
return None, None
|
|
||||||
auth = base64.b64decode(auth[6:]).split(b':', 1)
|
auth = base64.b64decode(auth[6:]).split(b':', 1)
|
||||||
if len(auth) != 2:
|
if len(auth) != 2:
|
||||||
self.error(401, 'Invalid Authorization header')
|
raise UnauthorizedError('Invalid Authorization header')
|
||||||
return None, None
|
|
||||||
username, password = auth[0].decode('utf-8'), auth[1].decode('utf-8')
|
username, password = auth[0].decode('utf-8'), auth[1].decode('utf-8')
|
||||||
except:
|
except:
|
||||||
self.error(401, 'Invalid Authorization header')
|
raise UnauthorizedError('Invalid Authorization header')
|
||||||
return None, None
|
|
||||||
with open(USER_FILE, 'r') as file:
|
with open(USER_FILE, 'r') as file:
|
||||||
for line in file:
|
for line in file:
|
||||||
(u, r, p) = line.strip().split(':', 2)
|
(u, r, p) = line.strip().split(':', 2)
|
||||||
if u == username:
|
if u == username:
|
||||||
if p == password:
|
if p == password:
|
||||||
return u, r
|
return u, r
|
||||||
else:
|
raise UnauthorizedError()
|
||||||
self.error(401, 'Unauthorized')
|
|
||||||
return None, None
|
|
||||||
self.error(401, 'Unauthorized')
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
def exec_collection(self, sql_query: str, fmt: Callable, filters: list[Filter],
|
def exec_collection(self, sql_query: str, fmt: Callable, filters: list[Filter],
|
||||||
offset: int = None, limit: int = None,
|
offset: int = None, limit: int = None,
|
||||||
distinct: tuple[[str], [int]] = None, sub_fmt: Callable = None):
|
distinct: tuple[[str], [int]] = None, sub_fmt: Callable = None) -> None:
|
||||||
with_clause = re.findall(r'(WITH .*?\))[\s\n]*SELECT', sql_query, flags=re.DOTALL)
|
with_clause = re.findall(r'(WITH .*?\))[\s\n]*SELECT', sql_query, flags=re.DOTALL)
|
||||||
if len(with_clause) > 0 and '.*' not in with_clause[0]:
|
if len(with_clause) > 0 and '.*' not in with_clause[0]:
|
||||||
with_clause = with_clause[0]
|
with_clause = with_clause[0]
|
||||||
@ -211,7 +228,9 @@ class ElwigApi(BaseHTTPRequestHandler):
|
|||||||
data += '\n]}\n'
|
data += '\n]}\n'
|
||||||
self.send(data)
|
self.send(data)
|
||||||
|
|
||||||
def do_GET_delivery_schedules(self, filters: list[Filter], offset: int = None, limit: int = None, order: str = None):
|
def do_GET_delivery_schedules(self, filters: list[Filter],
|
||||||
|
offset: int = None, limit: int = None,
|
||||||
|
order: str = None) -> None:
|
||||||
clauses = get_delivery_schedule_filter_clauses(filters)
|
clauses = get_delivery_schedule_filter_clauses(filters)
|
||||||
sql = f"""
|
sql = f"""
|
||||||
WITH announcements
|
WITH announcements
|
||||||
@ -269,13 +288,13 @@ class ElwigApi(BaseHTTPRequestHandler):
|
|||||||
f'"announcement_to":{jdmp(r[9])}}}',
|
f'"announcement_to":{jdmp(r[9])}}}',
|
||||||
filters, offset, limit, distinct=(['s.year', 's.dsnr'], [1, 2]))
|
filters, offset, limit, distinct=(['s.year', 's.dsnr'], [1, 2]))
|
||||||
|
|
||||||
def do_HEAD(self):
|
def do_HEAD(self) -> None:
|
||||||
self.do_GET()
|
self.do_GET()
|
||||||
|
|
||||||
def do_OPTIONS(self):
|
def do_OPTIONS(self) -> None:
|
||||||
self.send('')
|
self.send('')
|
||||||
|
|
||||||
def do_GET(self):
|
def do_GET(self) -> None:
|
||||||
try:
|
try:
|
||||||
if self.path == '/':
|
if self.path == '/':
|
||||||
openapi_json = f'https://{self.headers.get("Host", "localhost")}/elwig/api/v1/openapi.json'
|
openapi_json = f'https://{self.headers.get("Host", "localhost")}/elwig/api/v1/openapi.json'
|
||||||
@ -286,8 +305,6 @@ class ElwigApi(BaseHTTPRequestHandler):
|
|||||||
return
|
return
|
||||||
|
|
||||||
username, role = self.authorize()
|
username, role = self.authorize()
|
||||||
if not username or not role:
|
|
||||||
return
|
|
||||||
|
|
||||||
parts = self.path.split('?', 1)
|
parts = self.path.split('?', 1)
|
||||||
if len(parts) == 1:
|
if len(parts) == 1:
|
||||||
@ -331,9 +348,9 @@ class ElwigApi(BaseHTTPRequestHandler):
|
|||||||
elif path == '/delivery_schedules':
|
elif path == '/delivery_schedules':
|
||||||
self.do_GET_delivery_schedules(filters, offset, limit, order)
|
self.do_GET_delivery_schedules(filters, offset, limit, order)
|
||||||
else:
|
else:
|
||||||
self.error(404, 'Invalid path')
|
raise NotFoundError('Invalid path')
|
||||||
except BadRequestError as e:
|
except HttpError as e:
|
||||||
self.error(400, str(e))
|
self.error(e.status_code, str(e))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exception(e)
|
traceback.print_exception(e)
|
||||||
self.error(500, str(e))
|
self.error(500, str(e))
|
||||||
@ -360,9 +377,7 @@ def main() -> None:
|
|||||||
print(f'Listening on http://localhost:{args.port}')
|
print(f'Listening on http://localhost:{args.port}')
|
||||||
try:
|
try:
|
||||||
server.serve_forever()
|
server.serve_forever()
|
||||||
except InterruptedError:
|
except InterruptedError | KeyboardInterrupt:
|
||||||
print()
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
print()
|
print()
|
||||||
server.server_close()
|
server.server_close()
|
||||||
print('Good bye!')
|
print('Good bye!')
|
||||||
|
Reference in New Issue
Block a user