import asyncio
import hmac
import json
import os
from typing import Any, Optional

import tornado
import tornado.web  # `import tornado` alone does not guarantee the `web` submodule is loaded

import src.model.trainer as model_trainer
from src.l.logger import logger
from src.transport.rabbitmq import RabbitMQ
from src.transport.train_dto import TrainDto

# Currently loaded detector; created in start() and replaced by /admin/restart.
_spam_detector = None

# String spellings accepted for the ``is_spam`` training flag (compared lower-cased, stripped).
_TRUE_STRINGS = frozenset({"true", "1", "yes", "on"})
_FALSE_STRINGS = frozenset({"false", "0", "no", "off", ""})


def _json(data) -> str:
    """Serialize *data* to a JSON string."""
    return json.dumps(data)


def _parse_json_object(raw: bytes) -> Optional[dict]:
    """Decode a request body, returning it only if it is a JSON object.

    Returns ``None`` for malformed JSON, undecodable bytes, an empty body, or a
    JSON value that is not an object, so handlers can answer 400 instead of
    failing with a 500.
    """
    try:
        parsed = json.loads(raw)
    except ValueError:  # covers json.JSONDecodeError and UnicodeDecodeError
        return None
    return parsed if isinstance(parsed, dict) else None


def _parse_is_spam(value: Any) -> Optional[bool]:
    """Interpret the ``is_spam`` field of a training request.

    Plain ``bool(value)`` turns the string ``"false"`` into ``True``, so strings
    are matched against common boolean spellings instead. Non-string values keep
    the previous ``bool(value)`` behaviour.

    Returns:
        ``True``/``False``, or ``None`` for an unrecognised string.
    """
    if isinstance(value, str):
        lowered = value.strip().lower()
        if lowered in _TRUE_STRINGS:
            return True
        if lowered in _FALSE_STRINGS:
            return False
        return None
    return bool(value)


def _extract_auth_token(auth_header: Optional[str]) -> str:
    """Return the second space-separated word of an Authorization header.

    ``"Bearer abc"`` yields ``"abc"``. A missing/empty header, or one with no
    space, yields ``""`` instead of raising IndexError.
    """
    if not auth_header:
        return ''
    parts = auth_header.split(" ")
    return parts[1] if len(parts) > 1 else ''


def _token_matches(provided: str, expected: str) -> bool:
    """Compare *provided* with *expected* in constant time.

    An empty *expected* token never matches. Otherwise a request with no
    Authorization header (provided == "") would be authorised.
    """
    if not expected:
        return False
    return hmac.compare_digest(provided.encode("utf-8"), expected.encode("utf-8"))


def start(port: int, token: str, fucking_path: str, backup_path: str, rabbitmq: RabbitMQ, queue: str) -> None:
    """Load the spam detector, then serve the HTTP API on *port* until the process exits.

    Args:
        port: TCP port the Tornado application listens on.
        token: admin secret expected as the second word of the
            ``Authorization`` header for ``/admin/restart``.
        fucking_path: forwarded to ``model_trainer.apply_latest``.
        backup_path: forwarded to ``model_trainer.apply_latest``.
        rabbitmq: transport used by ``/admin/train`` to publish training samples.
        queue: queue name training samples are published to.

    This call blocks: it runs ``asyncio.run`` on a coroutine that awaits an
    Event that is never set.
    """
    global _spam_detector
    logger.info("Starting...")

    def _create_spam_detector():
        """Apply the latest model files, then construct a new VotingSpamDetector."""
        model_trainer.apply_latest(fucking_path=fucking_path, backup_path=backup_path)
        # Imported only after apply_latest has run -- presumably so the library
        # sees the freshly applied model files on first import; confirm.
        from spam_detector_ai.prediction.predict import VotingSpamDetector
        return VotingSpamDetector()

    _spam_detector = _create_spam_detector()

    class _JsonHandler(tornado.web.RequestHandler):
        """Base handler providing a helper that writes JSON with the right Content-Type."""

        def _write_json(self, data, status: int = 200) -> None:
            """Set *status*, label the response as JSON, and write *data* serialized."""
            self.set_status(status)
            self.set_header("Content-Type", "application/json; charset=UTF-8")
            self.write(_json(data))

    class CheckSpamHandler(_JsonHandler):
        """POST /check-spam with ``{"text": ...}``; responds ``{"is_spam": ...}``."""

        def post(self):
            body = _parse_json_object(self.request.body)
            if body is None:
                self._write_json({"error": "request body must be a JSON object"}, status=400)
                return
            if 'text' not in body:
                # write_error() renders Tornado's HTML error page and ignores extra
                # kwargs, so the JSON message must be written directly.
                self._write_json({"error": "text is not specified"}, status=400)
                return
            self._write_json({"is_spam": _spam_detector.is_spam(body['text'])})

    class AdminTrainHandler(_JsonHandler):
        """POST /admin/train with ``{"is_spam": ..., "text": ...}``; publishes a TrainDto to *queue*."""

        def post(self):
            # NOTE(review): unlike /admin/restart this endpoint performs no token
            # check, so any client can enqueue training data. Consider requiring it.
            req = _parse_json_object(self.request.body)
            is_spam = None
            if req is not None and 'is_spam' in req and 'text' in req:
                is_spam = _parse_is_spam(req['is_spam'])
            if is_spam is None:
                self._write_json({'status': 'fail', 'message': 'wrong format'}, status=400)
                return
            rabbitmq.publish(queue, TrainDto(is_spam=is_spam, text=req['text']).to_json())

    class AdminRestartHandler(_JsonHandler):
        """POST /admin/restart; rebuilds the detector when the admin token matches, else 403."""

        def post(self):
            global _spam_detector
            auth_token = _extract_auth_token(self.request.headers.get('Authorization'))
            if not _token_matches(auth_token, token):
                self._write_json({'status': 'fail', 'message': 'Invalid authentication token'}, status=403)
                return
            # NOTE(review): the rebuild runs synchronously on the event loop, so other
            # requests wait until it finishes. If it raises, the previous detector is kept.
            _spam_detector = _create_spam_detector()

    async def start_web_server():
        """Build the Tornado application, listen on *port*, and wait forever."""
        logger.info(f"Starting web server on port {port}")
        app = tornado.web.Application(
            [
                (r"/check-spam", CheckSpamHandler),
                (r"/admin/train", AdminTrainHandler),
                (r"/admin/restart", AdminRestartHandler),
            ],
            template_path=os.path.join(os.path.dirname(__file__), "templates"),
            static_path=os.path.join(os.path.dirname(__file__), "static"),
        )
        app.listen(port)
        # The Event is never set: this keeps the loop (and server) alive until the process is killed.
        await asyncio.Event().wait()

    asyncio.run(start_web_server())