spam-detector/src/web/server.py
2024-11-01 23:46:57 +03:00

77 lines
2.8 KiB
Python

import asyncio
import json
import os
import tornado
import src.model.trainer as model_trainer
from src.l.logger import logger
from src.transport.rabbitmq import RabbitMQ
from src.transport.train_dto import TrainDto
# Module-wide detector instance. It stays None until start() assigns it, and
# the /admin/restart handler replaces it with a freshly created detector.
_spam_detector = None
def _json(data) -> str:
return json.dumps(data)
def start(port: int, token: str, fucking_path: str, backup_path: str, rabbitmq: RabbitMQ, queue: str) -> None:
    """Load the spam detector and serve the HTTP API until the process stops.

    Routes (all POST, JSON request bodies):
        /check-spam     body ``{"text": ...}``; responds ``{"is_spam": ...}``.
        /admin/train    body ``{"is_spam": ..., "text": ...}``; publishes a
                        ``TrainDto`` to ``queue`` through ``rabbitmq``.
        /admin/restart  requires ``Authorization: <scheme> <token>``; rebuilds
                        the detector via ``model_trainer.apply_latest``.

    Args:
        port: TCP port the Tornado application listens on.
        token: Secret expected as the second word of the Authorization header
            on /admin/restart.
        fucking_path: Forwarded unchanged to ``model_trainer.apply_latest``.
        backup_path: Forwarded unchanged to ``model_trainer.apply_latest``.
        rabbitmq: Transport whose ``publish`` receives training samples.
        queue: Queue name passed to ``rabbitmq.publish``.

    This call blocks: it runs its own asyncio event loop, which waits on an
    event that is never set.
    """
    # Local import keeps this function self-contained; used for the
    # constant-time token comparison in AdminRestartHandler.
    import hmac

    global _spam_detector
    logger.info("Starting...")

    def _create_spam_detector():
        """Apply the latest model via the trainer, then build a new detector."""
        model_trainer.apply_latest(fucking_path=fucking_path, backup_path=backup_path)
        # NOTE(review): the import is kept after apply_latest, as in the
        # original; presumably the package must not be loaded earlier -- confirm.
        from spam_detector_ai.prediction.predict import VotingSpamDetector
        return VotingSpamDetector()

    def _read_json_object(handler):
        """Return the request body parsed as a dict, or None if it is not one."""
        try:
            parsed = json.loads(handler.request.body)
        except ValueError:
            # Covers json.JSONDecodeError and UnicodeDecodeError (both are
            # ValueError subclasses): empty, malformed or non-UTF-8 bodies.
            return None
        return parsed if isinstance(parsed, dict) else None

    def _write_json(handler, payload, status=200):
        """Send *payload* as a JSON response with the given HTTP status."""
        handler.set_status(status)
        # write() only sets a JSON content type for dicts; we pass a string.
        handler.set_header("Content-Type", "application/json; charset=UTF-8")
        handler.write(_json(payload))

    _spam_detector = _create_spam_detector()

    class CheckSpamHandler(tornado.web.RequestHandler):
        """POST /check-spam: classify the supplied text."""

        def post(self) -> None:
            body = _read_json_object(self)
            if body is None:
                _write_json(self, {"error": "request body must be a JSON object"}, status=400)
                return
            if 'text' not in body:
                # Previously sent through write_error(), whose default
                # implementation ignores the body and renders an HTML page.
                _write_json(self, {"error": "text is not specified"}, status=400)
                return
            _write_json(self, {"is_spam": _spam_detector.is_spam(body['text'])})

    class AdminTrainHandler(tornado.web.RequestHandler):
        """POST /admin/train: queue a labelled sample for training."""

        # NOTE(review): unlike /admin/restart this endpoint performs no token
        # check -- confirm that is intentional before exposing it publicly.
        def post(self) -> None:
            req = _read_json_object(self)
            if req is None or 'is_spam' not in req or 'text' not in req:
                _write_json(self, {'status': 'fail', 'message': 'wrong format'}, status=400)
                return
            # NOTE(review): bool() treats any non-empty string (including
            # "false") as True; callers are expected to send a JSON boolean.
            sample = TrainDto(is_spam=bool(req['is_spam']), text=req['text'])
            rabbitmq.publish(queue, sample.to_json())

    class AdminRestartHandler(tornado.web.RequestHandler):
        """POST /admin/restart: rebuild the detector if the token matches."""

        def post(self) -> None:
            global _spam_detector
            auth_header = self.request.headers.get('Authorization') or ''
            # Expected form is "<scheme> <token>". A header without a space
            # yields an empty token instead of raising IndexError.
            header_parts = auth_header.split(" ")
            auth_token = header_parts[1] if len(header_parts) > 1 else ''
            # NOTE(review): an empty configured token still matches a missing
            # header, as before -- make sure a non-empty token is deployed.
            authorized = isinstance(token, str) and hmac.compare_digest(
                auth_token.encode("utf-8"), token.encode("utf-8")
            )
            if not authorized:
                _write_json(
                    self,
                    {'status': 'fail', 'message': 'Invalid authentication token'},
                    status=403,
                )
                return
            _spam_detector = _create_spam_detector()

    async def start_web_server() -> None:
        """Create the Tornado application, bind the port and wait forever."""
        logger.info(f"Starting web server on port {port}")
        app = tornado.web.Application(
            [
                (r"/check-spam", CheckSpamHandler),
                (r"/admin/train", AdminTrainHandler),
                (r"/admin/restart", AdminRestartHandler),
            ],
            template_path=os.path.join(os.path.dirname(__file__), "templates"),
            static_path=os.path.join(os.path.dirname(__file__), "static"),
        )
        app.listen(port)
        # Nothing ever sets this event; awaiting it keeps the loop alive.
        await asyncio.Event().wait()

    asyncio.run(start_web_server())