77 lines
2.8 KiB
Python
77 lines
2.8 KiB
Python
import asyncio
|
|
import json
|
|
import os
|
|
import tornado
|
|
import src.model.trainer as model_trainer
|
|
from src.l.logger import logger
|
|
from src.transport.rabbitmq import RabbitMQ
|
|
from src.transport.train_dto import TrainDto
|
|
|
|
# Module-level detector instance shared by the request handlers. It is None
# until start() assigns it, and it is replaced when POST /admin/restart
# succeeds.
_spam_detector = None
|
|
|
|
|
|
def _json(data) -> str:
|
|
return json.dumps(data)
|
|
|
|
def _parse_json_object(raw_body):
    """Decode *raw_body* as JSON and return it only if it is a JSON object.

    Returns ``None`` for malformed JSON, undecodable bytes, or any JSON value
    that is not an object (list, string, number, ``null``). Handlers can then
    answer 400 instead of failing with a 500.
    """
    try:
        parsed = json.loads(raw_body)
    except ValueError:  # covers json.JSONDecodeError and UnicodeDecodeError
        return None
    return parsed if isinstance(parsed, dict) else None


def _auth_token_from_header(auth_header):
    """Return the second space-separated field of an ``Authorization`` header.

    This keeps the original ``header.split(" ")[1]`` parsing (the scheme word
    is not checked). It returns ``''`` instead of raising ``IndexError`` when
    the header is missing or has no second field.
    """
    if not auth_header:
        return ''
    parts = auth_header.split(" ")
    return parts[1] if len(parts) > 1 else ''


def _token_matches(supplied, expected):
    """Return True if *supplied* equals *expected*, compared in constant time.

    An empty supplied or expected token never matches. Otherwise an
    unconfigured (empty) token would let a request without credentials pass.
    """
    import hmac

    if not supplied or not expected:
        return False
    return hmac.compare_digest(supplied.encode("utf-8"), expected.encode("utf-8"))


def start(port: int, token: str, fucking_path: str, backup_path: str, rabbitmq: RabbitMQ, queue: str) -> None:
    """Build the spam detector and serve the HTTP API on *port*.

    This call does not return: it runs the event loop until the process is
    stopped.

    Routes (all POST):
      * ``/check-spam``: body ``{"text": ...}``, replies ``{"is_spam": ...}``.
        Replies 400 with ``{"error": ...}`` when the body is not a JSON object
        or ``text`` is missing.
      * ``/admin/train``: body ``{"is_spam": ..., "text": ...}``. Publishes a
        ``TrainDto`` JSON message to *queue* via *rabbitmq*.
      * ``/admin/restart``: requires ``Authorization: <scheme> <token>`` and
        rebuilds the detector.

    Args:
        port: TCP port the Tornado application listens on.
        token: shared secret for ``/admin/restart``. If it is empty, every
            restart request is rejected with 403.
        fucking_path: forwarded to ``model_trainer.apply_latest``.
        backup_path: forwarded to ``model_trainer.apply_latest``.
        rabbitmq: transport used to publish training samples.
        queue: queue name passed to ``rabbitmq.publish``.
    """
    global _spam_detector

    logger.info("Starting...")

    def _create_spam_detector():
        """Apply the latest model files, then construct a new VotingSpamDetector."""
        model_trainer.apply_latest(fucking_path=fucking_path, backup_path=backup_path)
        # NOTE(review): imported after apply_latest(), presumably so the first
        # import sees the applied model files. Later calls reuse the cached module.
        from spam_detector_ai.prediction.predict import VotingSpamDetector
        return VotingSpamDetector()

    _spam_detector = _create_spam_detector()

    class CheckSpamHandler(tornado.web.RequestHandler):
        """POST /check-spam: classify ``body["text"]`` with the current detector."""

        def post(self):
            body = _parse_json_object(self.request.body)
            if body is None:
                self.set_status(400)
                self.write(_json({"error": "request body must be a JSON object"}))
                return
            if 'text' not in body:
                self.set_status(400)
                # write_error() would discard this JSON and render Tornado's
                # HTML error page instead, so the error body is written directly.
                self.write(_json({"error": "text is not specified"}))
                return
            self.write(_json({"is_spam": _spam_detector.is_spam(body['text'])}))

    class AdminTrainHandler(tornado.web.RequestHandler):
        """POST /admin/train: publish a labelled training sample to *queue*."""

        # NOTE(review): unlike /admin/restart, this endpoint checks no token, so
        # any client can publish training samples. Confirm it is protected upstream.
        def post(self):
            req = _parse_json_object(self.request.body)
            if req is None or 'is_spam' not in req or 'text' not in req:
                self.set_status(400)
                self.write(_json({'status': 'fail', 'message': 'wrong format'}))
                return
            # bool() uses truthiness: a JSON string such as "false" becomes True,
            # so clients are expected to send a JSON boolean.
            dto = TrainDto(is_spam=bool(req['is_spam']), text=req['text'])
            rabbitmq.publish(queue, dto.to_json())

    class AdminRestartHandler(tornado.web.RequestHandler):
        """POST /admin/restart: rebuild the shared detector when the token matches."""

        def post(self):
            global _spam_detector

            supplied = _auth_token_from_header(self.request.headers.get('Authorization'))
            if not _token_matches(supplied, token):
                self.set_status(403)
                self.write(_json({'status': 'fail', 'message': 'Invalid authentication token'}))
                return
            # NOTE(review): this runs synchronously inside the handler, so the
            # IOLoop serves no other requests until the model is reloaded.
            _spam_detector = _create_spam_detector()

    async def start_web_server():
        """Register the routes, listen on *port*, and wait forever."""
        logger.info(f"Starting web server on port {port}")
        app = tornado.web.Application(
            [
                (r"/check-spam", CheckSpamHandler),
                (r"/admin/train", AdminTrainHandler),
                (r"/admin/restart", AdminRestartHandler),
            ],
            template_path=os.path.join(os.path.dirname(__file__), "templates"),
            static_path=os.path.join(os.path.dirname(__file__), "static"),
        )
        app.listen(port)
        # Nothing ever sets this Event. Awaiting it keeps the loop (and the
        # server) running until the process is stopped.
        await asyncio.Event().wait()

    asyncio.run(start_web_server())