communication over rabbitmq
This commit is contained in:
parent
c6ab8791ca
commit
6c97e0c3f6
12
Dockerfile
12
Dockerfile
@ -40,13 +40,23 @@ ENV PYTHONPATH=/app
|
||||
ARG DATASET
|
||||
ARG PORT
|
||||
ARG WEB_API_URL
|
||||
ARG RABBITMQ_HOST
|
||||
ARG RABBITMQ_PORT
|
||||
ARG RABBITMQ_USER
|
||||
ARG RABBITMQ_PASS
|
||||
ARG RABBITMQ_QUEUE
|
||||
|
||||
ENV PORT=${PORT} \
|
||||
DATASET=${DATASET} \
|
||||
WORKING_DIR=/app/nltk_data \
|
||||
FUCKING_DIR=/usr/local/lib/python3.10/site-packages/spam_detector_ai/models \
|
||||
MODELS_DIR=/app/models \
|
||||
WEB_API_URL=${WEB_API_URL}
|
||||
WEB_API_URL=${WEB_API_URL} \
|
||||
RABBITMQ_HOST=${RABBITMQ_HOST} \
|
||||
RABBITMQ_PORT=${RABBITMQ_PORT} \
|
||||
RABBITMQ_USER=${RABBITMQ_USER} \
|
||||
RABBITMQ_PASS=${RABBITMQ_PASS} \
|
||||
RABBITMQ_QUEUE=${RABBITMQ_QUEUE}
|
||||
|
||||
RUN python3 -m src.preparer
|
||||
|
||||
|
@ -1,6 +1,8 @@
|
||||
networks:
|
||||
spam-detector-internal:
|
||||
driver: bridge
|
||||
rabbitmq:
|
||||
external: true
|
||||
|
||||
|
||||
services:
|
||||
@ -11,6 +13,11 @@ services:
|
||||
environment:
|
||||
- PORT=${PORT:-8080}
|
||||
- TOKEN=token12345
|
||||
- RABBITMQ_HOST=${RABBITMQ_HOST}
|
||||
- RABBITMQ_PORT=${RABBITMQ_PORT:-5672}
|
||||
- RABBITMQ_USER=${RABBITMQ_USER}
|
||||
- RABBITMQ_PASS=${RABBITMQ_PASS}
|
||||
- RABBITMQ_QUEUE=${RABBITMQ_QUEUE}
|
||||
entrypoint: [ "python", "-m", "src.app", "-m" ]
|
||||
volumes:
|
||||
- type: bind
|
||||
@ -22,6 +29,7 @@ services:
|
||||
- "8080:${PORT:-8080}"
|
||||
networks:
|
||||
- spam-detector-internal
|
||||
- rabbitmq
|
||||
|
||||
model-updater:
|
||||
build: ./
|
||||
@ -29,6 +37,11 @@ services:
|
||||
environment:
|
||||
- WEB_API_URL=http://spam-detector-decision-maker:${PORT:-8080}
|
||||
- TOKEN=token12345
|
||||
- RABBITMQ_HOST=${RABBITMQ_HOST}
|
||||
- RABBITMQ_PORT=${RABBITMQ_PORT:-5672}
|
||||
- RABBITMQ_USER=${RABBITMQ_USER}
|
||||
- RABBITMQ_PASS=${RABBITMQ_PASS}
|
||||
- RABBITMQ_QUEUE=${RABBITMQ_QUEUE}
|
||||
entrypoint: [ "python", "-m", "src.app", "-u", "-d", "/app/dataset.csv" ]
|
||||
# entrypoint: [ "bash", "-c", "while true; do sleep 1; done" ]
|
||||
volumes:
|
||||
@ -38,3 +51,4 @@ services:
|
||||
- ${MODELS_DIR}:/app/models/
|
||||
networks:
|
||||
- spam-detector-internal
|
||||
- rabbitmq
|
||||
|
@ -24,6 +24,7 @@ tornado = "^6.4.1"
|
||||
asyncio = "^3.4.3"
|
||||
exceptiongroup = "^1.0.0rc8"
|
||||
scheduler = "^0.8.7"
|
||||
pika = "^1.3.2"
|
||||
|
||||
[tool.poetry.scripts]
|
||||
app = "src.app:start"
|
||||
|
@ -135,9 +135,9 @@ h2==3.2.0 ; python_version >= "3.10" and python_version < "4.0" \
|
||||
hpack==3.0.0 ; python_version >= "3.10" and python_version < "4.0" \
|
||||
--hash=sha256:0edd79eda27a53ba5be2dfabf3b15780928a0dff6eb0c60a3d6767720e970c89 \
|
||||
--hash=sha256:8eec9c1f4bfae3408a3f30500261f7e6a65912dc138526ea054f9ad98892e9d2
|
||||
hstspreload==2024.10.1 ; python_version >= "3.10" and python_version < "4.0" \
|
||||
--hash=sha256:2859a6b52253743ddddad468d8c9570ba650170ca49ac416336826915ee409b8 \
|
||||
--hash=sha256:3ab481036cbdff095cb411dafe33ee7924492319cf6ddaf4e776a159537541b3
|
||||
hstspreload==2024.11.1 ; python_version >= "3.10" and python_version < "4.0" \
|
||||
--hash=sha256:1dc00fd6517284ec32ca0e0955bd5de9d1b1475c2ad196cb9e2933dc05a51d6e \
|
||||
--hash=sha256:e0b18112e122e1cac8ca59c8972079f7c688912205f8c81f5ba7cb6c66e05dda
|
||||
httpcore==0.9.1 ; python_version >= "3.10" and python_version < "4.0" \
|
||||
--hash=sha256:9850fe97a166a794d7e920590d5ec49a05488884c9fc8b5dba8561effab0c2a0 \
|
||||
--hash=sha256:ecc5949310d9dae4de64648a4ce529f86df1f232ce23dcfefe737c24d21dfbe9
|
||||
@ -268,6 +268,9 @@ pandas==2.2.3 ; python_version >= "3.10" and python_version < "4.0" \
|
||||
--hash=sha256:f00d1345d84d8c86a63e476bb4955e46458b304b9575dcf71102b5c705320015 \
|
||||
--hash=sha256:f3a255b2c19987fbbe62a9dfd6cff7ff2aa9ccab3fc75218fd4b7530f01efa24 \
|
||||
--hash=sha256:fffb8ae78d8af97f849404f21411c95062db1496aeb3e56f146f0355c9989319
|
||||
pika==1.3.2 ; python_version >= "3.10" and python_version < "4.0" \
|
||||
--hash=sha256:0779a7c1fafd805672796085560d290213a465e4f6f76a6fb19e378d8041a14f \
|
||||
--hash=sha256:b2a327ddddf8570b4965b3576ac77091b850262d34ce8c1d8cb4e4146aa4145f
|
||||
pluggy==1.5.0 ; python_version >= "3.10" and python_version < "4.0" \
|
||||
--hash=sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1 \
|
||||
--hash=sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669
|
||||
|
24
src/app.py
24
src/app.py
@ -3,7 +3,7 @@ import os
|
||||
import src.model.trainer as model_trainer
|
||||
import src.web.server as web_server
|
||||
import src.model.updater as model_updater
|
||||
|
||||
from src.transport.rabbitmq import RabbitMQ
|
||||
|
||||
parser = argparse.ArgumentParser(prog='app.py')
|
||||
parser.add_argument('-i', '--init', action=argparse.BooleanOptionalAction, help='Initializing, must be run beforehand, --dataset is required')
|
||||
@ -21,6 +21,11 @@ _models_dir = os.getenv("MODELS_DIR")
|
||||
_fucking_dir = os.getenv("FUCKING_DIR")
|
||||
_web_api_url = os.getenv("WEB_API_URL")
|
||||
_token = os.getenv("TOKEN")
|
||||
# RabbitMQ connection settings, read from the environment once at import time.
_rabbitmq_host = os.getenv("RABBITMQ_HOST")
# Fall back to the standard AMQP port when the variable is unset or empty
# (e.g. an ENV populated from a build arg that was never passed); a bare
# int(os.getenv(...)) would raise at import time, even in modes that never
# talk to RabbitMQ.
_rabbitmq_port = int(os.getenv("RABBITMQ_PORT") or 5672)
_rabbitmq_user = os.getenv("RABBITMQ_USER")
_rabbitmq_pass = os.getenv("RABBITMQ_PASS")
_rabbitmq_queue = os.getenv("RABBITMQ_QUEUE")
|
||||
|
||||
|
||||
def start():
|
||||
@ -28,11 +33,24 @@ def start():
|
||||
assert args.dataset is not None, "Dataset is required, run --help"
|
||||
model_trainer.train(args.dataset, fucking_path=_fucking_dir, backup_path=_models_dir)
|
||||
elif args.decision_maker:
|
||||
rabbitmq = RabbitMQ(_rabbitmq_host, _rabbitmq_port, _rabbitmq_user, _rabbitmq_pass)
|
||||
model_trainer.apply_latest(fucking_path=_fucking_dir, backup_path=_models_dir)
|
||||
web_server.start(port=_port, token=_token, fucking_path=_fucking_dir, backup_path=_models_dir)
|
||||
web_server.start(port=_port,
|
||||
token=_token,
|
||||
fucking_path=_fucking_dir,
|
||||
backup_path=_models_dir,
|
||||
rabbitmq=rabbitmq,
|
||||
queue=_rabbitmq_queue)
|
||||
elif args.model_updater:
|
||||
rabbitmq = RabbitMQ(_rabbitmq_host, _rabbitmq_port, _rabbitmq_user, _rabbitmq_pass)
|
||||
assert args.dataset is not None, "Dataset is required, run --help"
|
||||
model_updater.start(fucking_path=_fucking_dir, models_dir=_models_dir, dataset_path=args.dataset, web_api_url=_web_api_url, token=_token)
|
||||
model_updater.start(fucking_path=_fucking_dir,
|
||||
models_dir=_models_dir,
|
||||
dataset_path=args.dataset,
|
||||
web_api_url=_web_api_url,
|
||||
token=_token,
|
||||
rabbitmq=rabbitmq,
|
||||
queue=_rabbitmq_queue)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -1,3 +1,4 @@
|
||||
import csv
|
||||
import datetime as dt
|
||||
from src.l.logger import logger
|
||||
import src.model.trainer as trainer
|
||||
@ -5,13 +6,24 @@ from scheduler import Scheduler
|
||||
from os import listdir
|
||||
import time
|
||||
import requests
|
||||
from src.transport.rabbitmq import RabbitMQ
|
||||
from src.transport.train_dto import TrainDto
|
||||
|
||||
|
||||
def _does_file_exist_in_dir(path):
|
||||
return len(listdir(path)) > 0
|
||||
|
||||
def start(fucking_path: str, models_dir: str, dataset_path: str, web_api_url: str, token: str) -> None:
|
||||
def _listen_to_trainings(csv_file: str, rabbitmq: RabbitMQ, queue: str) -> None:
    """Append every training sample received on *queue* to *csv_file*.

    Each message body is parsed with ``TrainDto.from_json`` and written as one
    CSV row ``[is_spam, text]``. Consumption is started through
    ``rabbitmq.consume`` with ``auto_ack=True``.

    NOTE(review): ``RabbitMQ.consume`` presumably blocks for as long as the
    consumer runs -- confirm that callers do not expect this to return.

    :param csv_file: path of the dataset CSV that rows are appended to.
    :param rabbitmq: connected transport used to consume messages.
    :param queue: name of the queue to consume from.
    """
    def _callback(ch, method, properties, body):
        try:
            dto = TrainDto.from_json(body)
        except Exception as err:
            # The message is already auto-acked, so re-raising would only kill
            # the consumer; log the bad payload and keep listening instead.
            logger.warning(f"Skipping malformed training message: {err}")
            return
        # newline="" is required by the csv module so that the writer controls
        # line endings itself (important for quoted fields containing newlines).
        with open(csv_file, "a", newline="") as f:
            writer = csv.writer(f, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
            writer.writerow([dto.is_spam, dto.text])

    rabbitmq.consume(queue_name=queue, callback=_callback, auto_ack=True)
|
||||
|
||||
def start(fucking_path: str, models_dir: str, dataset_path: str, web_api_url: str, token: str, rabbitmq: RabbitMQ, queue: str) -> None:
|
||||
logger.info("Starting...")
|
||||
_listen_to_trainings(csv_file=dataset_path, rabbitmq=rabbitmq, queue=queue)
|
||||
|
||||
def _restart_web_api() -> None:
|
||||
headers = { 'Authorization': f"Bearer {token}" }
|
||||
|
48
src/transport/rabbitmq.py
Normal file
48
src/transport/rabbitmq.py
Normal file
@ -0,0 +1,48 @@
|
||||
from typing import Optional
|
||||
|
||||
import pika
|
||||
import os
|
||||
|
||||
from pika.adapters.blocking_connection import BlockingChannel
|
||||
|
||||
|
||||
class RabbitMQ:
    """Small wrapper around one blocking pika connection and one channel.

    The connection is opened immediately by the constructor via ``connect()``.
    ``consume``, ``publish`` and ``consumer_ack`` all raise an ``Exception``
    with the message "Connection is not established." when no channel is set.
    """

    def __init__(self, host: str, port: int, user: str, passwd: str):
        """Store the connection settings and connect right away.

        :param host: broker host name.
        :param port: broker port.
        :param user: user name for plain credentials.
        :param passwd: password for plain credentials.
        """
        self.host = host
        self.port = port
        self.user = user
        self.password = passwd
        self.connection = None
        self.channel: Optional[BlockingChannel] = None
        self.connect()

    def connect(self):
        """Open a blocking connection with plain credentials and create a channel."""
        credentials = pika.PlainCredentials(self.user, self.password)
        parameters = pika.ConnectionParameters(host=self.host, port=self.port, credentials=credentials)
        self.connection = pika.BlockingConnection(parameters)
        self.channel = self.connection.channel()

    def consumer_ack(self, delivery_tag: int = 1, multiple: bool = True):
        """Acknowledge delivered messages on the channel.

        The defaults reproduce the previous hard-coded call
        (``delivery_tag=1, multiple=True``); pass the tag of the message being
        handled (``method.delivery_tag`` in a consumer callback) to
        acknowledge that message instead.

        :param delivery_tag: delivery tag forwarded to ``basic_ack``.
        :param multiple: ``multiple`` flag forwarded to ``basic_ack``.
        :raises Exception: if no channel has been established.
        """
        if not self.channel:
            raise Exception("Connection is not established.")
        self.channel.basic_ack(delivery_tag=delivery_tag, multiple=multiple)

    def close(self):
        """Close the connection if one exists and is not already closed."""
        if self.connection and not self.connection.is_closed:
            self.connection.close()

    def consume(self, queue_name, callback, auto_ack: bool = True):
        """Register *callback* on *queue_name* and start consuming.

        :param queue_name: queue to consume from.
        :param callback: passed to ``basic_consume`` as ``on_message_callback``.
        :param auto_ack: forwarded to ``basic_consume``.
        :raises Exception: if no channel has been established.
        """
        if not self.channel:
            raise Exception("Connection is not established.")
        self.channel.basic_consume(queue=queue_name, on_message_callback=callback, auto_ack=auto_ack)
        self.channel.start_consuming()

    def publish(self, queue_name, message):
        """Declare *queue_name* as durable and publish *message* to it.

        The message goes through the default exchange with the queue name as
        routing key.

        :param queue_name: target queue.
        :param message: message body.
        :raises Exception: if no channel has been established.
        """
        if not self.channel:
            raise Exception("Connection is not established.")
        self.channel.queue_declare(queue=queue_name, durable=True)
        self.channel.basic_publish(exchange='',
                                   routing_key=queue_name,
                                   body=message,
                                   properties=pika.BasicProperties(
                                       delivery_mode=2,  # make message persistent
                                   ))
        print(f"Sent message to queue {queue_name}: {message}")
|
23
src/transport/train_dto.py
Normal file
23
src/transport/train_dto.py
Normal file
@ -0,0 +1,23 @@
|
||||
import json
|
||||
from typing import TypeVar, Type
|
||||
|
||||
|
||||
T = TypeVar('T', bound='TrainDto')


class TrainDto:
    """Training sample carried as JSON: a spam flag plus the message text."""

    def __init__(self, is_spam: bool, text: str):
        """
        :param is_spam: whether the text is labelled as spam.
        :param text: the message text.
        """
        self.is_spam = is_spam
        self.text = text

    def to_json(self) -> str:
        """Serialize to a JSON object with the keys ``is_spam`` and ``text``."""
        return json.dumps({
            'is_spam': self.is_spam,
            'text': self.text
        })

    @classmethod
    def from_json(cls: Type[T], s: str) -> T:
        """Build an instance of *cls* from a JSON document.

        :param s: JSON text; it is handed straight to ``json.loads``, which
            also accepts ``bytes``/``bytearray``.
        :returns: a new instance of the class the method was called on.
        :raises ValueError: with message "Wrong format" if the document is not
            a JSON object containing both ``is_spam`` and ``text``; invalid
            JSON raises ``json.JSONDecodeError`` (itself a ``ValueError``).
        """
        payload = json.loads(s)
        # Valid JSON need not be an object: a list, number or string would
        # otherwise break the key checks or the lookups below.
        if not isinstance(payload, dict) or 'is_spam' not in payload or 'text' not in payload:
            raise ValueError("Wrong format")
        # Use cls rather than TrainDto so subclasses get their own type back,
        # as the Type[T] -> T signature promises.
        return cls(is_spam=payload['is_spam'], text=payload['text'])
|
@ -4,7 +4,8 @@ import os
|
||||
import tornado
|
||||
import src.model.trainer as model_trainer
|
||||
from src.l.logger import logger
|
||||
|
||||
from src.transport.rabbitmq import RabbitMQ
|
||||
from src.transport.train_dto import TrainDto
|
||||
|
||||
_spam_detector = None
|
||||
|
||||
@ -12,7 +13,7 @@ _spam_detector = None
|
||||
def _json(data) -> str:
|
||||
return json.dumps(data)
|
||||
|
||||
def start(port: int, token: str, fucking_path: str, backup_path: str) -> None:
|
||||
def start(port: int, token: str, fucking_path: str, backup_path: str, rabbitmq: RabbitMQ, queue: str) -> None:
|
||||
global _spam_detector
|
||||
logger.info("Starting...")
|
||||
|
||||
@ -28,11 +29,22 @@ def start(port: int, token: str, fucking_path: str, backup_path: str) -> None:
|
||||
def post(self):
|
||||
body = json.loads(self.request.body)
|
||||
if not 'text' in body:
|
||||
self.write_error(400, body=_json({"error": "text is not specified"}))
|
||||
self.set_status(400)
|
||||
self.write_error(400, body=_json({ "error": "text is not specified" }))
|
||||
else:
|
||||
r = json.dumps({"is_spam": _spam_detector.is_spam(body['text'])})
|
||||
r = _json({ "is_spam": _spam_detector.is_spam(body['text']) })
|
||||
self.write(r)
|
||||
|
||||
class AdminTrainHandler(tornado.web.RequestHandler):
    """Accept a labelled sample over HTTP and publish it to the training queue."""

    @tornado.gen.coroutine
    def post(self):
        """Publish the posted ``{"is_spam": ..., "text": ...}`` sample.

        Responds with status 400 and a ``wrong format`` JSON body when either
        key is missing; otherwise the sample is serialized as a ``TrainDto``
        and published to ``queue`` through ``rabbitmq``.
        """
        req = json.loads(self.request.body)
        # Validate the keys that are actually read below. The check used to
        # look for 'decision' while the code reads 'is_spam', so a request
        # that passed validation could still fail with a KeyError.
        if 'is_spam' not in req or 'text' not in req:
            self.set_status(400)
            self.write(_json({ 'status': 'fail', 'message': 'wrong format' }))
        else:
            # NOTE(review): bool() treats any non-empty string (even "false")
            # as True -- confirm clients send a JSON boolean here.
            rabbitmq.publish(queue, TrainDto(is_spam=bool(req['is_spam']), text=req['text']).to_json())
|
||||
|
||||
class AdminRestartHandler(tornado.web.RequestHandler):
|
||||
@tornado.gen.coroutine
|
||||
def post(self):
|
||||
@ -46,15 +58,14 @@ def start(port: int, token: str, fucking_path: str, backup_path: str) -> None:
|
||||
_spam_detector = _create_spam_detector()
|
||||
else:
|
||||
self.set_status(403)
|
||||
self.write({'status': 'fail', 'message': 'Invalid authentication token'})
|
||||
self.finish()
|
||||
return
|
||||
self.write(_json({ 'status': 'fail', 'message': 'Invalid authentication token' }))
|
||||
|
||||
async def start_web_server():
|
||||
logger.info(f"Starting web server on port {port}")
|
||||
app = tornado.web.Application(
|
||||
[
|
||||
(r"/check-spam", CheckSpamHandler),
|
||||
(r"/admin/train", AdminTrainHandler),
|
||||
(r"/admin/restart", AdminRestartHandler)
|
||||
],
|
||||
template_path=os.path.join(os.path.dirname(__file__), "templates"),
|
||||
|
Loading…
x
Reference in New Issue
Block a user