Communication over RabbitMQ

bvn13 2024-11-01 23:35:45 +03:00
parent c6ab8791ca
commit 6c97e0c3f6
9 changed files with 155 additions and 15 deletions


@@ -40,13 +40,23 @@ ENV PYTHONPATH=/app
ARG DATASET
ARG PORT
ARG WEB_API_URL
ARG RABBITMQ_HOST
ARG RABBITMQ_PORT
ARG RABBITMQ_USER
ARG RABBITMQ_PASS
ARG RABBITMQ_QUEUE
ENV PORT=${PORT} \
    DATASET=${DATASET} \
    WORKING_DIR=/app/nltk_data \
    FUCKING_DIR=/usr/local/lib/python3.10/site-packages/spam_detector_ai/models \
    MODELS_DIR=/app/models \
    WEB_API_URL=${WEB_API_URL}
    WEB_API_URL=${WEB_API_URL} \
    RABBITMQ_HOST=${RABBITMQ_HOST} \
    RABBITMQ_PORT=${RABBITMQ_PORT} \
    RABBITMQ_USER=${RABBITMQ_USER} \
    RABBITMQ_PASS=${RABBITMQ_PASS} \
    RABBITMQ_QUEUE=${RABBITMQ_QUEUE}
RUN python3 -m src.preparer


@@ -1,6 +1,8 @@
networks:
  spam-detector-internal:
    driver: bridge
  rabbitmq:
    external: true

services:
@@ -11,6 +13,11 @@ services:
    environment:
      - PORT=${PORT:-8080}
      - TOKEN=token12345
      - RABBITMQ_HOST=${RABBITMQ_HOST}
      - RABBITMQ_PORT=${RABBITMQ_PORT:-5672}
      - RABBITMQ_USER=${RABBITMQ_USER}
      - RABBITMQ_PASS=${RABBITMQ_PASS}
      - RABBITMQ_QUEUE=${RABBITMQ_QUEUE}
    entrypoint: [ "python", "-m", "src.app", "-m" ]
    volumes:
      - type: bind
@@ -22,6 +29,7 @@ services:
      - "8080:${PORT:-8080}"
    networks:
      - spam-detector-internal
      - rabbitmq

  model-updater:
    build: ./
@@ -29,6 +37,11 @@ services:
    environment:
      - WEB_API_URL=http://spam-detector-decision-maker:${PORT:-8080}
      - TOKEN=token12345
      - RABBITMQ_HOST=${RABBITMQ_HOST}
      - RABBITMQ_PORT=${RABBITMQ_PORT:-5672}
      - RABBITMQ_USER=${RABBITMQ_USER}
      - RABBITMQ_PASS=${RABBITMQ_PASS}
      - RABBITMQ_QUEUE=${RABBITMQ_QUEUE}
    entrypoint: [ "python", "-m", "src.app", "-u", "-d", "/app/dataset.csv" ]
    # entrypoint: [ "bash", "-c", "while true; do sleep 1; done" ]
    volumes:
@@ -38,3 +51,4 @@ services:
      - ${MODELS_DIR}:/app/models/
    networks:
      - spam-detector-internal
      - rabbitmq
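Both services now also join an external network named rabbitmq, so a broker must already be running and attached to that network before the stack starts, and RABBITMQ_HOST must resolve to it. A minimal connectivity check, sketched here assuming only the environment variables introduced above (the script itself is not part of this commit):

import os
import pika

# hypothetical smoke test: fail fast if the broker host or credentials are wrong
credentials = pika.PlainCredentials(os.getenv("RABBITMQ_USER"), os.getenv("RABBITMQ_PASS"))
parameters = pika.ConnectionParameters(host=os.getenv("RABBITMQ_HOST"),
                                       port=int(os.getenv("RABBITMQ_PORT", "5672")),
                                       credentials=credentials)
connection = pika.BlockingConnection(parameters)  # raises AMQPConnectionError if unreachable
print("connected:", connection.is_open)
connection.close()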


@@ -24,6 +24,7 @@ tornado = "^6.4.1"
asyncio = "^3.4.3"
exceptiongroup = "^1.0.0rc8"
scheduler = "^0.8.7"
pika = "^1.3.2"
[tool.poetry.scripts]
app = "src.app:start"


@@ -135,9 +135,9 @@ h2==3.2.0 ; python_version >= "3.10" and python_version < "4.0" \
hpack==3.0.0 ; python_version >= "3.10" and python_version < "4.0" \
--hash=sha256:0edd79eda27a53ba5be2dfabf3b15780928a0dff6eb0c60a3d6767720e970c89 \
--hash=sha256:8eec9c1f4bfae3408a3f30500261f7e6a65912dc138526ea054f9ad98892e9d2
hstspreload==2024.10.1 ; python_version >= "3.10" and python_version < "4.0" \
--hash=sha256:2859a6b52253743ddddad468d8c9570ba650170ca49ac416336826915ee409b8 \
--hash=sha256:3ab481036cbdff095cb411dafe33ee7924492319cf6ddaf4e776a159537541b3
hstspreload==2024.11.1 ; python_version >= "3.10" and python_version < "4.0" \
--hash=sha256:1dc00fd6517284ec32ca0e0955bd5de9d1b1475c2ad196cb9e2933dc05a51d6e \
--hash=sha256:e0b18112e122e1cac8ca59c8972079f7c688912205f8c81f5ba7cb6c66e05dda
httpcore==0.9.1 ; python_version >= "3.10" and python_version < "4.0" \
--hash=sha256:9850fe97a166a794d7e920590d5ec49a05488884c9fc8b5dba8561effab0c2a0 \
--hash=sha256:ecc5949310d9dae4de64648a4ce529f86df1f232ce23dcfefe737c24d21dfbe9
@@ -268,6 +268,9 @@ pandas==2.2.3 ; python_version >= "3.10" and python_version < "4.0" \
--hash=sha256:f00d1345d84d8c86a63e476bb4955e46458b304b9575dcf71102b5c705320015 \
--hash=sha256:f3a255b2c19987fbbe62a9dfd6cff7ff2aa9ccab3fc75218fd4b7530f01efa24 \
--hash=sha256:fffb8ae78d8af97f849404f21411c95062db1496aeb3e56f146f0355c9989319
pika==1.3.2 ; python_version >= "3.10" and python_version < "4.0" \
--hash=sha256:0779a7c1fafd805672796085560d290213a465e4f6f76a6fb19e378d8041a14f \
--hash=sha256:b2a327ddddf8570b4965b3576ac77091b850262d34ce8c1d8cb4e4146aa4145f
pluggy==1.5.0 ; python_version >= "3.10" and python_version < "4.0" \
--hash=sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1 \
--hash=sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669


@@ -3,7 +3,7 @@ import os
import src.model.trainer as model_trainer
import src.web.server as web_server
import src.model.updater as model_updater
from src.transport.rabbitmq import RabbitMQ
parser = argparse.ArgumentParser(prog='app.py')
parser.add_argument('-i', '--init', action=argparse.BooleanOptionalAction, help='Initializing, must be run beforehand, --dataset is required')
@@ -21,6 +21,11 @@ _models_dir = os.getenv("MODELS_DIR")
_fucking_dir = os.getenv("FUCKING_DIR")
_web_api_url = os.getenv("WEB_API_URL")
_token = os.getenv("TOKEN")
_rabbitmq_host = os.getenv("RABBITMQ_HOST")
_rabbitmq_port = int(os.getenv("RABBITMQ_PORT", "5672"))  # fall back to the default AMQP port if unset
_rabbitmq_user = os.getenv("RABBITMQ_USER")
_rabbitmq_pass = os.getenv("RABBITMQ_PASS")
_rabbitmq_queue = os.getenv("RABBITMQ_QUEUE")
def start():
@@ -28,11 +33,24 @@ def start():
        assert args.dataset is not None, "Dataset is required, run --help"
        model_trainer.train(args.dataset, fucking_path=_fucking_dir, backup_path=_models_dir)
    elif args.decision_maker:
        rabbitmq = RabbitMQ(_rabbitmq_host, _rabbitmq_port, _rabbitmq_user, _rabbitmq_pass)
        model_trainer.apply_latest(fucking_path=_fucking_dir, backup_path=_models_dir)
        web_server.start(port=_port, token=_token, fucking_path=_fucking_dir, backup_path=_models_dir)
        web_server.start(port=_port,
                         token=_token,
                         fucking_path=_fucking_dir,
                         backup_path=_models_dir,
                         rabbitmq=rabbitmq,
                         queue=_rabbitmq_queue)
    elif args.model_updater:
        rabbitmq = RabbitMQ(_rabbitmq_host, _rabbitmq_port, _rabbitmq_user, _rabbitmq_pass)
        assert args.dataset is not None, "Dataset is required, run --help"
        model_updater.start(fucking_path=_fucking_dir, models_dir=_models_dir, dataset_path=args.dataset, web_api_url=_web_api_url, token=_token)
        model_updater.start(fucking_path=_fucking_dir,
                            models_dir=_models_dir,
                            dataset_path=args.dataset,
                            web_api_url=_web_api_url,
                            token=_token,
                            rabbitmq=rabbitmq,
                            queue=_rabbitmq_queue)
if __name__ == '__main__':


@@ -1,3 +1,4 @@
import csv
import datetime as dt
from src.l.logger import logger
import src.model.trainer as trainer
@@ -5,13 +6,24 @@ from scheduler import Scheduler
from os import listdir
import time
import requests
from src.transport.rabbitmq import RabbitMQ
from src.transport.train_dto import TrainDto
def _does_file_exist_in_dir(path):
    return len(listdir(path)) > 0
def start(fucking_path: str, models_dir: str, dataset_path: str, web_api_url: str, token: str) -> None:
def _listen_to_trainings(csv_file: str, rabbitmq: RabbitMQ, queue: str) -> None:
    def _callback(ch, method, properties, body):
        dto = TrainDto.from_json(body)
        with open(csv_file, "a") as f:
            writer = csv.writer(f, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
            writer.writerow([dto.is_spam, dto.text])
    rabbitmq.consume(queue_name=queue, callback=_callback, auto_ack=True)
def start(fucking_path: str, models_dir: str, dataset_path: str, web_api_url: str, token: str, rabbitmq: RabbitMQ, queue: str) -> None:
    logger.info("Starting...")
    _listen_to_trainings(csv_file=dataset_path, rabbitmq=rabbitmq, queue=queue)
    def _restart_web_api() -> None:
        headers = { 'Authorization': f"Bearer {token}" }
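The model-updater now consumes training records from the queue: each message is the JSON produced by TrainDto.to_json(), and _callback appends it as a row (is_spam, text) to the dataset CSV. Note that RabbitMQ.consume() blocks in start_consuming(), so start() does not return from _listen_to_trainings(). A short sketch of feeding the updater one record by hand, assuming the same environment variables as the services above:

import os
from src.transport.rabbitmq import RabbitMQ
from src.transport.train_dto import TrainDto

rabbitmq = RabbitMQ(os.getenv("RABBITMQ_HOST"),
                    int(os.getenv("RABBITMQ_PORT", "5672")),
                    os.getenv("RABBITMQ_USER"),
                    os.getenv("RABBITMQ_PASS"))
# the updater's _callback will append the row "True,some spam text" to the dataset CSV
rabbitmq.publish(os.getenv("RABBITMQ_QUEUE"), TrainDto(is_spam=True, text="some spam text").to_json())
rabbitmq.close()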

src/transport/rabbitmq.py (new file, 48 lines)

@@ -0,0 +1,48 @@
from typing import Optional
import pika
from pika.adapters.blocking_connection import BlockingChannel


class RabbitMQ:
    def __init__(self, host: str, port: int, user: str, passwd: str):
        self.host = host
        self.port = port
        self.user = user
        self.password = passwd
        self.connection = None
        self.channel: Optional[BlockingChannel] = None
        self.connect()

    def connect(self):
        credentials = pika.PlainCredentials(self.user, self.password)
        parameters = pika.ConnectionParameters(host=self.host, port=self.port, credentials=credentials)
        self.connection = pika.BlockingConnection(parameters)
        self.channel = self.connection.channel()

    def consumer_ack(self, delivery_tag: int):
        # acknowledge a specific delivery instead of a hardcoded tag
        self.channel.basic_ack(delivery_tag=delivery_tag)

    def close(self):
        if self.connection and not self.connection.is_closed:
            self.connection.close()

    def consume(self, queue_name, callback, auto_ack: bool = True):
        if not self.channel:
            raise Exception("Connection is not established.")
        # make sure the queue exists even if nothing has been published to it yet
        self.channel.queue_declare(queue=queue_name, durable=True)
        # blocking call: runs the consumer loop until the channel or connection is closed
        self.channel.basic_consume(queue=queue_name, on_message_callback=callback, auto_ack=auto_ack)
        self.channel.start_consuming()

    def publish(self, queue_name, message):
        if not self.channel:
            raise Exception("Connection is not established.")
        self.channel.queue_declare(queue=queue_name, durable=True)
        self.channel.basic_publish(exchange='',
                                   routing_key=queue_name,
                                   body=message,
                                   properties=pika.BasicProperties(
                                       delivery_mode=2,  # make message persistent
                                   ))
        print(f"Sent message to queue {queue_name}: {message}")
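A short usage sketch of the consumer side with manual acknowledgements (auto_ack=False); the connection settings are assumed to come from the environment, as elsewhere in this commit:

import os
from src.transport.rabbitmq import RabbitMQ

rabbitmq = RabbitMQ(os.getenv("RABBITMQ_HOST"),
                    int(os.getenv("RABBITMQ_PORT", "5672")),
                    os.getenv("RABBITMQ_USER"),
                    os.getenv("RABBITMQ_PASS"))

def on_message(ch, method, properties, body):
    print("received:", body)
    rabbitmq.consumer_ack(method.delivery_tag)  # explicit ack because auto_ack is off

# blocks in start_consuming() until the channel or connection is closed
rabbitmq.consume(queue_name=os.getenv("RABBITMQ_QUEUE"), callback=on_message, auto_ack=False)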


@@ -0,0 +1,23 @@
import json
from typing import TypeVar, Type

T = TypeVar('T', bound='TrainDto')


class TrainDto:
    def __init__(self, is_spam: bool, text: str):
        self.is_spam = is_spam
        self.text = text

    def to_json(self) -> str:
        return json.dumps({
            'is_spam': self.is_spam,
            'text': self.text
        })

    @classmethod
    def from_json(cls: Type[T], s: str) -> T:
        j = json.loads(s)
        if 'is_spam' not in j or 'text' not in j:
            raise ValueError("Wrong format")
        return cls(is_spam=j['is_spam'], text=j['text'])
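A quick round-trip, just to show the wire format the queue carries:

from src.transport.train_dto import TrainDto

payload = TrainDto(is_spam=True, text="win a free prize").to_json()
# payload == '{"is_spam": true, "text": "win a free prize"}'
restored = TrainDto.from_json(payload)
assert restored.is_spam is True and restored.text == "win a free prize"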


@@ -4,7 +4,8 @@ import os
import tornado
import src.model.trainer as model_trainer
from src.l.logger import logger
from src.transport.rabbitmq import RabbitMQ
from src.transport.train_dto import TrainDto
_spam_detector = None
@@ -12,7 +13,7 @@ _spam_detector = None
def _json(data) -> str:
    return json.dumps(data)
def start(port: int, token: str, fucking_path: str, backup_path: str) -> None:
def start(port: int, token: str, fucking_path: str, backup_path: str, rabbitmq: RabbitMQ, queue: str) -> None:
    global _spam_detector
    logger.info("Starting...")
@@ -28,11 +29,22 @@ def start(port: int, token: str, fucking_path: str, backup_path: str) -> None:
        def post(self):
            body = json.loads(self.request.body)
            if 'text' not in body:
                self.write_error(400, body=_json({"error": "text is not specified"}))
                self.set_status(400)
                self.write_error(400, body=_json({ "error": "text is not specified" }))
            else:
                r = json.dumps({"is_spam": _spam_detector.is_spam(body['text'])})
                r = _json({ "is_spam": _spam_detector.is_spam(body['text']) })
                self.write(r)
    class AdminTrainHandler(tornado.web.RequestHandler):
        @tornado.gen.coroutine
        def post(self):
            req = json.loads(self.request.body)
            if 'is_spam' not in req or 'text' not in req:
                self.set_status(400)
                self.write(_json({ 'status': 'fail', 'message': 'wrong format' }))
            else:
                rabbitmq.publish(queue, TrainDto(is_spam=bool(req['is_spam']), text=req['text']).to_json())
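For completeness, a sketch of exercising the new /admin/train endpoint with requests; the host and payload are illustrative, and the handler simply publishes the record to the queue for the model-updater to pick up:

import requests

# the compose file maps the decision maker's web API to port 8080
resp = requests.post("http://localhost:8080/admin/train",
                     json={"is_spam": True, "text": "buy cheap pills now"})
print(resp.status_code)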
    class AdminRestartHandler(tornado.web.RequestHandler):
        @tornado.gen.coroutine
        def post(self):
@@ -46,15 +58,14 @@ def start(port: int, token: str, fucking_path: str, backup_path: str) -> None:
                _spam_detector = _create_spam_detector()
            else:
                self.set_status(403)
                self.write({'status': 'fail', 'message': 'Invalid authentication token'})
                self.finish()
                return
                self.write(_json({ 'status': 'fail', 'message': 'Invalid authentication token' }))
    async def start_web_server():
        logger.info(f"Starting web server on port {port}")
        app = tornado.web.Application(
            [
                (r"/check-spam", CheckSpamHandler),
                (r"/admin/train", AdminTrainHandler),
                (r"/admin/restart", AdminRestartHandler)
            ],
            template_path=os.path.join(os.path.dirname(__file__), "templates"),