diff --git a/Dockerfile b/Dockerfile index 6975d47..6327b8c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -40,13 +40,23 @@ ENV PYTHONPATH=/app ARG DATASET ARG PORT ARG WEB_API_URL +ARG RABBITMQ_HOST +ARG RABBITMQ_PORT +ARG RABBITMQ_USER +ARG RABBITMQ_PASS +ARG RABBITMQ_QUEUE ENV PORT=${PORT} \ DATASET=${DATASET} \ WORKING_DIR=/app/nltk_data \ FUCKING_DIR=/usr/local/lib/python3.10/site-packages/spam_detector_ai/models \ MODELS_DIR=/app/models \ - WEB_API_URL=${WEB_API_URL} + WEB_API_URL=${WEB_API_URL} \ + RABBITMQ_HOST=${RABBITMQ_HOST} \ + RABBITMQ_PORT=${RABBITMQ_PORT} \ + RABBITMQ_USER=${RABBITMQ_USER} \ + RABBITMQ_PASS=${RABBITMQ_PASS} \ + RABBITMQ_QUEUE=${RABBITMQ_QUEUE} RUN python3 -m src.preparer diff --git a/docker-compose.yaml b/docker-compose.yaml index 6c21d70..394b6d3 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -1,6 +1,8 @@ networks: spam-detector-internal: driver: bridge + rabbitmq: + external: true services: @@ -11,6 +13,11 @@ services: environment: - PORT=${PORT:-8080} - TOKEN=token12345 + - RABBITMQ_HOST=${RABBITMQ_HOST} + - RABBITMQ_PORT=${RABBITMQ_PORT:-5672} + - RABBITMQ_USER=${RABBITMQ_USER} + - RABBITMQ_PASS=${RABBITMQ_PASS} + - RABBITMQ_QUEUE=${RABBITMQ_QUEUE} entrypoint: [ "python", "-m", "src.app", "-m" ] volumes: - type: bind @@ -22,6 +29,7 @@ services: - "8080:${PORT:-8080}" networks: - spam-detector-internal + - rabbitmq model-updater: build: ./ @@ -29,6 +37,11 @@ services: environment: - WEB_API_URL=http://spam-detector-decision-maker:${PORT:-8080} - TOKEN=token12345 + - RABBITMQ_HOST=${RABBITMQ_HOST} + - RABBITMQ_PORT=${RABBITMQ_PORT:-5672} + - RABBITMQ_USER=${RABBITMQ_USER} + - RABBITMQ_PASS=${RABBITMQ_PASS} + - RABBITMQ_QUEUE=${RABBITMQ_QUEUE} entrypoint: [ "python", "-m", "src.app", "-u", "-d", "/app/dataset.csv" ] # entrypoint: [ "bash", "-c", "while true; do sleep 1; done" ] volumes: @@ -38,3 +51,4 @@ services: - ${MODELS_DIR}:/app/models/ networks: - spam-detector-internal + - rabbitmq diff --git a/pyproject.toml 
b/pyproject.toml index 327ce9a..d9cca59 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ tornado = "^6.4.1" asyncio = "^3.4.3" exceptiongroup = "^1.0.0rc8" scheduler = "^0.8.7" +pika = "^1.3.2" [tool.poetry.scripts] app = "src.app:start" diff --git a/requirements.txt b/requirements.txt index 0fbd1ed..d48b887 100644 --- a/requirements.txt +++ b/requirements.txt @@ -135,9 +135,9 @@ h2==3.2.0 ; python_version >= "3.10" and python_version < "4.0" \ hpack==3.0.0 ; python_version >= "3.10" and python_version < "4.0" \ --hash=sha256:0edd79eda27a53ba5be2dfabf3b15780928a0dff6eb0c60a3d6767720e970c89 \ --hash=sha256:8eec9c1f4bfae3408a3f30500261f7e6a65912dc138526ea054f9ad98892e9d2 -hstspreload==2024.10.1 ; python_version >= "3.10" and python_version < "4.0" \ - --hash=sha256:2859a6b52253743ddddad468d8c9570ba650170ca49ac416336826915ee409b8 \ - --hash=sha256:3ab481036cbdff095cb411dafe33ee7924492319cf6ddaf4e776a159537541b3 +hstspreload==2024.11.1 ; python_version >= "3.10" and python_version < "4.0" \ + --hash=sha256:1dc00fd6517284ec32ca0e0955bd5de9d1b1475c2ad196cb9e2933dc05a51d6e \ + --hash=sha256:e0b18112e122e1cac8ca59c8972079f7c688912205f8c81f5ba7cb6c66e05dda httpcore==0.9.1 ; python_version >= "3.10" and python_version < "4.0" \ --hash=sha256:9850fe97a166a794d7e920590d5ec49a05488884c9fc8b5dba8561effab0c2a0 \ --hash=sha256:ecc5949310d9dae4de64648a4ce529f86df1f232ce23dcfefe737c24d21dfbe9 @@ -268,6 +268,9 @@ pandas==2.2.3 ; python_version >= "3.10" and python_version < "4.0" \ --hash=sha256:f00d1345d84d8c86a63e476bb4955e46458b304b9575dcf71102b5c705320015 \ --hash=sha256:f3a255b2c19987fbbe62a9dfd6cff7ff2aa9ccab3fc75218fd4b7530f01efa24 \ --hash=sha256:fffb8ae78d8af97f849404f21411c95062db1496aeb3e56f146f0355c9989319 +pika==1.3.2 ; python_version >= "3.10" and python_version < "4.0" \ + --hash=sha256:0779a7c1fafd805672796085560d290213a465e4f6f76a6fb19e378d8041a14f \ + --hash=sha256:b2a327ddddf8570b4965b3576ac77091b850262d34ce8c1d8cb4e4146aa4145f pluggy==1.5.0 ; 
python_version >= "3.10" and python_version < "4.0" \ --hash=sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1 \ --hash=sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669 diff --git a/src/app.py b/src/app.py index c651a8b..87f7f33 100644 --- a/src/app.py +++ b/src/app.py @@ -3,7 +3,7 @@ import os import src.model.trainer as model_trainer import src.web.server as web_server import src.model.updater as model_updater - +from src.transport.rabbitmq import RabbitMQ parser = argparse.ArgumentParser(prog='app.py') parser.add_argument('-i', '--init', action=argparse.BooleanOptionalAction, help='Initializing, must be run beforehand, --dataset is required') @@ -21,6 +21,11 @@ _models_dir = os.getenv("MODELS_DIR") _fucking_dir = os.getenv("FUCKING_DIR") _web_api_url = os.getenv("WEB_API_URL") _token = os.getenv("TOKEN") +_rabbitmq_host = os.getenv("RABBITMQ_HOST") +_rabbitmq_port = int(os.getenv("RABBITMQ_PORT")) +_rabbitmq_user = os.getenv("RABBITMQ_USER") +_rabbitmq_pass = os.getenv("RABBITMQ_PASS") +_rabbitmq_queue = os.getenv("RABBITMQ_QUEUE") def start(): @@ -28,11 +33,24 @@ def start(): assert args.dataset is not None, "Dataset is required, run --help" model_trainer.train(args.dataset, fucking_path=_fucking_dir, backup_path=_models_dir) elif args.decision_maker: + rabbitmq = RabbitMQ(_rabbitmq_host, _rabbitmq_port, _rabbitmq_user, _rabbitmq_pass) model_trainer.apply_latest(fucking_path=_fucking_dir, backup_path=_models_dir) - web_server.start(port=_port, token=_token, fucking_path=_fucking_dir, backup_path=_models_dir) + web_server.start(port=_port, + token=_token, + fucking_path=_fucking_dir, + backup_path=_models_dir, + rabbitmq=rabbitmq, + queue=_rabbitmq_queue) elif args.model_updater: + rabbitmq = RabbitMQ(_rabbitmq_host, _rabbitmq_port, _rabbitmq_user, _rabbitmq_pass) assert args.dataset is not None, "Dataset is required, run --help" - model_updater.start(fucking_path=_fucking_dir, models_dir=_models_dir, 
def _listen_to_trainings(csv_file: str, rabbitmq: RabbitMQ, queue: str) -> None:
    """Append every training sample received on ``queue`` to ``csv_file``.

    Each message body is decoded with ``TrainDto.from_json`` and written as
    one CSV row ``[is_spam, text]``. A message is acknowledged only after its
    row has been written, so a failed write leaves it on the broker instead
    of silently losing the sample.

    NOTE(review): ``RabbitMQ.consume`` finishes with ``start_consuming()``,
    so this call blocks the calling thread; code placed after it in ``start``
    will not run while the consumer is alive -- confirm that is intended.

    Args:
        csv_file: Path of the dataset CSV that rows are appended to.
        rabbitmq: Connected ``RabbitMQ`` transport used to consume messages.
        queue: Name of the queue carrying serialized ``TrainDto`` messages.
    """

    def _callback(ch, method, properties, body):
        try:
            dto = TrainDto.from_json(body)
        except Exception as exc:
            # from_json signals a bad payload with a plain Exception, so the
            # catch cannot be narrower. One malformed message must not kill
            # the consumer; drop it without requeueing so it is not
            # redelivered forever.
            logger.error(f"Dropping malformed training message: {exc}")
            ch.basic_reject(delivery_tag=method.delivery_tag, requeue=False)
            return

        # newline="" is required by the csv module so that newlines embedded
        # in the message text are quoted correctly and no extra "\r" is added.
        # NOTE(review): utf-8 is assumed for the dataset file -- confirm it
        # matches how the trainer reads it.
        with open(csv_file, "a", newline="", encoding="utf-8") as f:
            writer = csv.writer(f, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
            writer.writerow([dto.is_spam, dto.text])

        # Acknowledge only once the row is safely on disk.
        ch.basic_ack(delivery_tag=method.delivery_tag)

    rabbitmq.consume(queue_name=queue, callback=_callback, auto_ack=False)
import json
from typing import Type, TypeVar, Union


T = TypeVar('T', bound='TrainDto')


class TrainDto:
    """A single labelled training sample exchanged over the message queue.

    Attributes are plain instance fields:
        is_spam: label of the sample.
        text: the message text the label applies to.
    """

    def __init__(self, is_spam: bool, text: str):
        self.is_spam = is_spam
        self.text = text

    def to_json(self) -> str:
        """Serialize the sample to a JSON object with ``is_spam`` and ``text``."""
        return json.dumps({
            'is_spam': self.is_spam,
            'text': self.text
        })

    @classmethod
    def from_json(cls: Type[T], s: Union[str, bytes, bytearray]) -> T:
        """Build an instance from the JSON produced by :meth:`to_json`.

        Args:
            s: JSON document as text or as raw bytes (message bodies read
                from the queue arrive as bytes).

        Returns:
            An instance of ``cls`` (so subclasses get their own type back).

        Raises:
            ValueError: if ``s`` is not valid JSON, is not a JSON object, or
                lacks the ``is_spam`` or ``text`` key.
        """
        j = json.loads(s)
        # A bare number or a list is valid JSON but not a valid sample; without
        # the isinstance check such payloads would surface as a TypeError.
        if not isinstance(j, dict) or 'is_spam' not in j or 'text' not in j:
            raise ValueError("Wrong format")
        return cls(is_spam=j['is_spam'], text=j['text'])
self.write(_json({ 'status': 'fail', 'message': 'Invalid authentication token' })) async def start_web_server(): logger.info(f"Starting web server on port {port}") app = tornado.web.Application( [ (r"/check-spam", CheckSpamHandler), + (r"/admin/train", AdminTrainHandler), (r"/admin/restart", AdminRestartHandler) ], template_path=os.path.join(os.path.dirname(__file__), "templates"),