communication over rabbitmq
This commit is contained in:
parent
c6ab8791ca
commit
6c97e0c3f6
12
Dockerfile
12
Dockerfile
@ -40,13 +40,23 @@ ENV PYTHONPATH=/app
|
|||||||
ARG DATASET
|
ARG DATASET
|
||||||
ARG PORT
|
ARG PORT
|
||||||
ARG WEB_API_URL
|
ARG WEB_API_URL
|
||||||
|
ARG RABBITMQ_HOST
|
||||||
|
ARG RABBITMQ_PORT
|
||||||
|
ARG RABBITMQ_USER
|
||||||
|
ARG RABBITMQ_PASS
|
||||||
|
ARG RABBITMQ_QUEUE
|
||||||
|
|
||||||
ENV PORT=${PORT} \
|
ENV PORT=${PORT} \
|
||||||
DATASET=${DATASET} \
|
DATASET=${DATASET} \
|
||||||
WORKING_DIR=/app/nltk_data \
|
WORKING_DIR=/app/nltk_data \
|
||||||
FUCKING_DIR=/usr/local/lib/python3.10/site-packages/spam_detector_ai/models \
|
FUCKING_DIR=/usr/local/lib/python3.10/site-packages/spam_detector_ai/models \
|
||||||
MODELS_DIR=/app/models \
|
MODELS_DIR=/app/models \
|
||||||
WEB_API_URL=${WEB_API_URL}
|
WEB_API_URL=${WEB_API_URL} \
|
||||||
|
RABBITMQ_HOST=${RABBITMQ_HOST} \
|
||||||
|
RABBITMQ_PORT=${RABBITMQ_PORT} \
|
||||||
|
RABBITMQ_USER=${RABBITMQ_USER} \
|
||||||
|
RABBITMQ_PASS=${RABBITMQ_PASS} \
|
||||||
|
RABBITMQ_QUEUE=${RABBITMQ_QUEUE}
|
||||||
|
|
||||||
RUN python3 -m src.preparer
|
RUN python3 -m src.preparer
|
||||||
|
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
networks:
|
networks:
|
||||||
spam-detector-internal:
|
spam-detector-internal:
|
||||||
driver: bridge
|
driver: bridge
|
||||||
|
rabbitmq:
|
||||||
|
external: true
|
||||||
|
|
||||||
|
|
||||||
services:
|
services:
|
||||||
@ -11,6 +13,11 @@ services:
|
|||||||
environment:
|
environment:
|
||||||
- PORT=${PORT:-8080}
|
- PORT=${PORT:-8080}
|
||||||
- TOKEN=token12345
|
- TOKEN=token12345
|
||||||
|
- RABBITMQ_HOST=${RABBITMQ_HOST}
|
||||||
|
- RABBITMQ_PORT=${RABBITMQ_PORT:-5672}
|
||||||
|
- RABBITMQ_USER=${RABBITMQ_USER}
|
||||||
|
- RABBITMQ_PASS=${RABBITMQ_PASS}
|
||||||
|
- RABBITMQ_QUEUE=${RABBITMQ_QUEUE}
|
||||||
entrypoint: [ "python", "-m", "src.app", "-m" ]
|
entrypoint: [ "python", "-m", "src.app", "-m" ]
|
||||||
volumes:
|
volumes:
|
||||||
- type: bind
|
- type: bind
|
||||||
@ -22,6 +29,7 @@ services:
|
|||||||
- "8080:${PORT:-8080}"
|
- "8080:${PORT:-8080}"
|
||||||
networks:
|
networks:
|
||||||
- spam-detector-internal
|
- spam-detector-internal
|
||||||
|
- rabbitmq
|
||||||
|
|
||||||
model-updater:
|
model-updater:
|
||||||
build: ./
|
build: ./
|
||||||
@ -29,6 +37,11 @@ services:
|
|||||||
environment:
|
environment:
|
||||||
- WEB_API_URL=http://spam-detector-decision-maker:${PORT:-8080}
|
- WEB_API_URL=http://spam-detector-decision-maker:${PORT:-8080}
|
||||||
- TOKEN=token12345
|
- TOKEN=token12345
|
||||||
|
- RABBITMQ_HOST=${RABBITMQ_HOST}
|
||||||
|
- RABBITMQ_PORT=${RABBITMQ_PORT:-5672}
|
||||||
|
- RABBITMQ_USER=${RABBITMQ_USER}
|
||||||
|
- RABBITMQ_PASS=${RABBITMQ_PASS}
|
||||||
|
- RABBITMQ_QUEUE=${RABBITMQ_QUEUE}
|
||||||
entrypoint: [ "python", "-m", "src.app", "-u", "-d", "/app/dataset.csv" ]
|
entrypoint: [ "python", "-m", "src.app", "-u", "-d", "/app/dataset.csv" ]
|
||||||
# entrypoint: [ "bash", "-c", "while true; do sleep 1; done" ]
|
# entrypoint: [ "bash", "-c", "while true; do sleep 1; done" ]
|
||||||
volumes:
|
volumes:
|
||||||
@ -38,3 +51,4 @@ services:
|
|||||||
- ${MODELS_DIR}:/app/models/
|
- ${MODELS_DIR}:/app/models/
|
||||||
networks:
|
networks:
|
||||||
- spam-detector-internal
|
- spam-detector-internal
|
||||||
|
- rabbitmq
|
||||||
|
@ -24,6 +24,7 @@ tornado = "^6.4.1"
|
|||||||
asyncio = "^3.4.3"
|
asyncio = "^3.4.3"
|
||||||
exceptiongroup = "^1.0.0rc8"
|
exceptiongroup = "^1.0.0rc8"
|
||||||
scheduler = "^0.8.7"
|
scheduler = "^0.8.7"
|
||||||
|
pika = "^1.3.2"
|
||||||
|
|
||||||
[tool.poetry.scripts]
|
[tool.poetry.scripts]
|
||||||
app = "src.app:start"
|
app = "src.app:start"
|
||||||
|
@ -135,9 +135,9 @@ h2==3.2.0 ; python_version >= "3.10" and python_version < "4.0" \
|
|||||||
hpack==3.0.0 ; python_version >= "3.10" and python_version < "4.0" \
|
hpack==3.0.0 ; python_version >= "3.10" and python_version < "4.0" \
|
||||||
--hash=sha256:0edd79eda27a53ba5be2dfabf3b15780928a0dff6eb0c60a3d6767720e970c89 \
|
--hash=sha256:0edd79eda27a53ba5be2dfabf3b15780928a0dff6eb0c60a3d6767720e970c89 \
|
||||||
--hash=sha256:8eec9c1f4bfae3408a3f30500261f7e6a65912dc138526ea054f9ad98892e9d2
|
--hash=sha256:8eec9c1f4bfae3408a3f30500261f7e6a65912dc138526ea054f9ad98892e9d2
|
||||||
hstspreload==2024.10.1 ; python_version >= "3.10" and python_version < "4.0" \
|
hstspreload==2024.11.1 ; python_version >= "3.10" and python_version < "4.0" \
|
||||||
--hash=sha256:2859a6b52253743ddddad468d8c9570ba650170ca49ac416336826915ee409b8 \
|
--hash=sha256:1dc00fd6517284ec32ca0e0955bd5de9d1b1475c2ad196cb9e2933dc05a51d6e \
|
||||||
--hash=sha256:3ab481036cbdff095cb411dafe33ee7924492319cf6ddaf4e776a159537541b3
|
--hash=sha256:e0b18112e122e1cac8ca59c8972079f7c688912205f8c81f5ba7cb6c66e05dda
|
||||||
httpcore==0.9.1 ; python_version >= "3.10" and python_version < "4.0" \
|
httpcore==0.9.1 ; python_version >= "3.10" and python_version < "4.0" \
|
||||||
--hash=sha256:9850fe97a166a794d7e920590d5ec49a05488884c9fc8b5dba8561effab0c2a0 \
|
--hash=sha256:9850fe97a166a794d7e920590d5ec49a05488884c9fc8b5dba8561effab0c2a0 \
|
||||||
--hash=sha256:ecc5949310d9dae4de64648a4ce529f86df1f232ce23dcfefe737c24d21dfbe9
|
--hash=sha256:ecc5949310d9dae4de64648a4ce529f86df1f232ce23dcfefe737c24d21dfbe9
|
||||||
@ -268,6 +268,9 @@ pandas==2.2.3 ; python_version >= "3.10" and python_version < "4.0" \
|
|||||||
--hash=sha256:f00d1345d84d8c86a63e476bb4955e46458b304b9575dcf71102b5c705320015 \
|
--hash=sha256:f00d1345d84d8c86a63e476bb4955e46458b304b9575dcf71102b5c705320015 \
|
||||||
--hash=sha256:f3a255b2c19987fbbe62a9dfd6cff7ff2aa9ccab3fc75218fd4b7530f01efa24 \
|
--hash=sha256:f3a255b2c19987fbbe62a9dfd6cff7ff2aa9ccab3fc75218fd4b7530f01efa24 \
|
||||||
--hash=sha256:fffb8ae78d8af97f849404f21411c95062db1496aeb3e56f146f0355c9989319
|
--hash=sha256:fffb8ae78d8af97f849404f21411c95062db1496aeb3e56f146f0355c9989319
|
||||||
|
pika==1.3.2 ; python_version >= "3.10" and python_version < "4.0" \
|
||||||
|
--hash=sha256:0779a7c1fafd805672796085560d290213a465e4f6f76a6fb19e378d8041a14f \
|
||||||
|
--hash=sha256:b2a327ddddf8570b4965b3576ac77091b850262d34ce8c1d8cb4e4146aa4145f
|
||||||
pluggy==1.5.0 ; python_version >= "3.10" and python_version < "4.0" \
|
pluggy==1.5.0 ; python_version >= "3.10" and python_version < "4.0" \
|
||||||
--hash=sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1 \
|
--hash=sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1 \
|
||||||
--hash=sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669
|
--hash=sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669
|
||||||
|
24
src/app.py
24
src/app.py
@ -3,7 +3,7 @@ import os
|
|||||||
import src.model.trainer as model_trainer
|
import src.model.trainer as model_trainer
|
||||||
import src.web.server as web_server
|
import src.web.server as web_server
|
||||||
import src.model.updater as model_updater
|
import src.model.updater as model_updater
|
||||||
|
from src.transport.rabbitmq import RabbitMQ
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(prog='app.py')
|
parser = argparse.ArgumentParser(prog='app.py')
|
||||||
parser.add_argument('-i', '--init', action=argparse.BooleanOptionalAction, help='Initializing, must be run beforehand, --dataset is required')
|
parser.add_argument('-i', '--init', action=argparse.BooleanOptionalAction, help='Initializing, must be run beforehand, --dataset is required')
|
||||||
@ -21,6 +21,11 @@ _models_dir = os.getenv("MODELS_DIR")
|
|||||||
_fucking_dir = os.getenv("FUCKING_DIR")
|
_fucking_dir = os.getenv("FUCKING_DIR")
|
||||||
_web_api_url = os.getenv("WEB_API_URL")
|
_web_api_url = os.getenv("WEB_API_URL")
|
||||||
_token = os.getenv("TOKEN")
|
_token = os.getenv("TOKEN")
|
||||||
|
_rabbitmq_host = os.getenv("RABBITMQ_HOST")
|
||||||
|
_rabbitmq_port = int(os.getenv("RABBITMQ_PORT"))
|
||||||
|
_rabbitmq_user = os.getenv("RABBITMQ_USER")
|
||||||
|
_rabbitmq_pass = os.getenv("RABBITMQ_PASS")
|
||||||
|
_rabbitmq_queue = os.getenv("RABBITMQ_QUEUE")
|
||||||
|
|
||||||
|
|
||||||
def start():
|
def start():
|
||||||
@ -28,11 +33,24 @@ def start():
|
|||||||
assert args.dataset is not None, "Dataset is required, run --help"
|
assert args.dataset is not None, "Dataset is required, run --help"
|
||||||
model_trainer.train(args.dataset, fucking_path=_fucking_dir, backup_path=_models_dir)
|
model_trainer.train(args.dataset, fucking_path=_fucking_dir, backup_path=_models_dir)
|
||||||
elif args.decision_maker:
|
elif args.decision_maker:
|
||||||
|
rabbitmq = RabbitMQ(_rabbitmq_host, _rabbitmq_port, _rabbitmq_user, _rabbitmq_pass)
|
||||||
model_trainer.apply_latest(fucking_path=_fucking_dir, backup_path=_models_dir)
|
model_trainer.apply_latest(fucking_path=_fucking_dir, backup_path=_models_dir)
|
||||||
web_server.start(port=_port, token=_token, fucking_path=_fucking_dir, backup_path=_models_dir)
|
web_server.start(port=_port,
|
||||||
|
token=_token,
|
||||||
|
fucking_path=_fucking_dir,
|
||||||
|
backup_path=_models_dir,
|
||||||
|
rabbitmq=rabbitmq,
|
||||||
|
queue=_rabbitmq_queue)
|
||||||
elif args.model_updater:
|
elif args.model_updater:
|
||||||
|
rabbitmq = RabbitMQ(_rabbitmq_host, _rabbitmq_port, _rabbitmq_user, _rabbitmq_pass)
|
||||||
assert args.dataset is not None, "Dataset is required, run --help"
|
assert args.dataset is not None, "Dataset is required, run --help"
|
||||||
model_updater.start(fucking_path=_fucking_dir, models_dir=_models_dir, dataset_path=args.dataset, web_api_url=_web_api_url, token=_token)
|
model_updater.start(fucking_path=_fucking_dir,
|
||||||
|
models_dir=_models_dir,
|
||||||
|
dataset_path=args.dataset,
|
||||||
|
web_api_url=_web_api_url,
|
||||||
|
token=_token,
|
||||||
|
rabbitmq=rabbitmq,
|
||||||
|
queue=_rabbitmq_queue)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import csv
|
||||||
import datetime as dt
|
import datetime as dt
|
||||||
from src.l.logger import logger
|
from src.l.logger import logger
|
||||||
import src.model.trainer as trainer
|
import src.model.trainer as trainer
|
||||||
@ -5,13 +6,24 @@ from scheduler import Scheduler
|
|||||||
from os import listdir
|
from os import listdir
|
||||||
import time
|
import time
|
||||||
import requests
|
import requests
|
||||||
|
from src.transport.rabbitmq import RabbitMQ
|
||||||
|
from src.transport.train_dto import TrainDto
|
||||||
|
|
||||||
|
|
||||||
def _does_file_exist_in_dir(path):
|
def _does_file_exist_in_dir(path):
|
||||||
return len(listdir(path)) > 0
|
return len(listdir(path)) > 0
|
||||||
|
|
||||||
def start(fucking_path: str, models_dir: str, dataset_path: str, web_api_url: str, token: str) -> None:
|
def _listen_to_trainings(csv_file: str, rabbitmq: RabbitMQ, queue: str) -> None:
|
||||||
|
def _callback(ch, method, properties, body):
|
||||||
|
dto = TrainDto.from_json(body)
|
||||||
|
with open(csv_file, "a") as f:
|
||||||
|
writer = csv.writer(f, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
|
||||||
|
writer.writerow([dto.is_spam, dto.text])
|
||||||
|
rabbitmq.consume(queue_name=queue, callback=_callback, auto_ack=True)
|
||||||
|
|
||||||
|
def start(fucking_path: str, models_dir: str, dataset_path: str, web_api_url: str, token: str, rabbitmq: RabbitMQ, queue: str) -> None:
|
||||||
logger.info("Starting...")
|
logger.info("Starting...")
|
||||||
|
_listen_to_trainings(csv_file=dataset_path, rabbitmq=rabbitmq, queue=queue)
|
||||||
|
|
||||||
def _restart_web_api() -> None:
|
def _restart_web_api() -> None:
|
||||||
headers = { 'Authorization': f"Bearer {token}" }
|
headers = { 'Authorization': f"Bearer {token}" }
|
||||||
|
48
src/transport/rabbitmq.py
Normal file
48
src/transport/rabbitmq.py
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import pika
|
||||||
|
import os
|
||||||
|
|
||||||
|
from pika.adapters.blocking_connection import BlockingChannel
|
||||||
|
|
||||||
|
|
||||||
|
class RabbitMQ:
    """Thin wrapper around a blocking pika connection/channel pair.

    Connects eagerly in __init__ so misconfiguration fails at construction
    time; exposes publish/consume helpers for durable work queues.
    """

    def __init__(self, host: str, port: int, user: str, passwd: str):
        self.host = host
        self.port = port
        self.user = user
        self.password = passwd
        self.connection = None
        self.channel: Optional[BlockingChannel] = None
        self.connect()

    def connect(self):
        """Open a blocking connection and channel using plain credentials."""
        credentials = pika.PlainCredentials(self.user, self.password)
        parameters = pika.ConnectionParameters(host=self.host, port=self.port, credentials=credentials)
        self.connection = pika.BlockingConnection(parameters)
        self.channel = self.connection.channel()

    def consumer_ack(self, delivery_tag: int = 1, multiple: bool = True):
        """Acknowledge messages up to ``delivery_tag``.

        The defaults preserve the previous hard-coded behavior (tag 1,
        multiple=True); real consumers should pass the tag from the
        callback's ``method.delivery_tag`` so the message actually being
        handled is the one acknowledged.

        Raises RuntimeError if no channel has been established.
        """
        if not self.channel:
            raise RuntimeError("Connection is not established.")
        self.channel.basic_ack(delivery_tag=delivery_tag, multiple=multiple)

    def close(self):
        """Close the underlying connection if it is open."""
        if self.connection and not self.connection.is_closed:
            self.connection.close()

    def consume(self, queue_name, callback, auto_ack: bool = True):
        """Block forever, delivering messages from ``queue_name`` to ``callback``.

        Raises RuntimeError if no channel has been established.
        """
        if not self.channel:
            raise RuntimeError("Connection is not established.")
        # Declare the queue (idempotent, durable to match publish) so that
        # consuming before the first publish does not close the channel
        # with a missing-queue error.
        self.channel.queue_declare(queue=queue_name, durable=True)
        self.channel.basic_consume(queue=queue_name, on_message_callback=callback, auto_ack=auto_ack)
        self.channel.start_consuming()

    def publish(self, queue_name, message):
        """Publish ``message`` to ``queue_name`` on the default exchange.

        The queue is declared durable and the message persistent so work
        survives a broker restart. Raises RuntimeError if no channel has
        been established.
        """
        if not self.channel:
            raise RuntimeError("Connection is not established.")
        self.channel.queue_declare(queue=queue_name, durable=True)
        self.channel.basic_publish(exchange='',
                                   routing_key=queue_name,
                                   body=message,
                                   properties=pika.BasicProperties(
                                       delivery_mode=2,  # make message persistent
                                   ))
        print(f"Sent message to queue {queue_name}: {message}")
|
23
src/transport/train_dto.py
Normal file
23
src/transport/train_dto.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
import json
|
||||||
|
from typing import TypeVar, Type
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar('T', bound='TrainDto')
|
||||||
|
|
||||||
|
class TrainDto:
    """Serializable payload carrying one training sample (spam label + text)."""

    def __init__(self, is_spam: bool, text: str):
        # is_spam: ground-truth label; text: raw message body.
        self.is_spam = is_spam
        self.text = text

    def to_json(self) -> str:
        """Serialize this DTO to a JSON string."""
        return json.dumps({
            'is_spam': self.is_spam,
            'text': self.text
        })

    @classmethod
    def from_json(cls, s: str) -> "TrainDto":
        """Deserialize a DTO from a JSON string.

        Raises ValueError (which json.JSONDecodeError subclasses) when the
        payload is not valid JSON or is missing a required key.
        """
        j = json.loads(s)
        if 'is_spam' not in j or 'text' not in j:
            # Both fields are mandatory — reject partial payloads.
            raise ValueError("Wrong format")
        # Use cls() so subclasses deserialize to their own type.
        return cls(is_spam=j['is_spam'], text=j['text'])
@ -4,7 +4,8 @@ import os
|
|||||||
import tornado
|
import tornado
|
||||||
import src.model.trainer as model_trainer
|
import src.model.trainer as model_trainer
|
||||||
from src.l.logger import logger
|
from src.l.logger import logger
|
||||||
|
from src.transport.rabbitmq import RabbitMQ
|
||||||
|
from src.transport.train_dto import TrainDto
|
||||||
|
|
||||||
_spam_detector = None
|
_spam_detector = None
|
||||||
|
|
||||||
@ -12,7 +13,7 @@ _spam_detector = None
|
|||||||
def _json(data) -> str:
|
def _json(data) -> str:
|
||||||
return json.dumps(data)
|
return json.dumps(data)
|
||||||
|
|
||||||
def start(port: int, token: str, fucking_path: str, backup_path: str) -> None:
|
def start(port: int, token: str, fucking_path: str, backup_path: str, rabbitmq: RabbitMQ, queue: str) -> None:
|
||||||
global _spam_detector
|
global _spam_detector
|
||||||
logger.info("Starting...")
|
logger.info("Starting...")
|
||||||
|
|
||||||
@ -28,11 +29,22 @@ def start(port: int, token: str, fucking_path: str, backup_path: str) -> None:
|
|||||||
def post(self):
|
def post(self):
|
||||||
body = json.loads(self.request.body)
|
body = json.loads(self.request.body)
|
||||||
if not 'text' in body:
|
if not 'text' in body:
|
||||||
self.write_error(400, body=_json({"error": "text is not specified"}))
|
self.set_status(400)
|
||||||
|
self.write_error(400, body=_json({ "error": "text is not specified" }))
|
||||||
else:
|
else:
|
||||||
r = json.dumps({"is_spam": _spam_detector.is_spam(body['text'])})
|
r = _json({ "is_spam": _spam_detector.is_spam(body['text']) })
|
||||||
self.write(r)
|
self.write(r)
|
||||||
|
|
||||||
|
class AdminTrainHandler(tornado.web.RequestHandler):
|
||||||
|
@tornado.gen.coroutine
|
||||||
|
def post(self):
|
||||||
|
req = json.loads(self.request.body)
|
||||||
|
if not 'decision' in req or not 'text' in req:
|
||||||
|
self.set_status(400)
|
||||||
|
self.write(_json({ 'status': 'fail', 'message': 'wrong format' }))
|
||||||
|
else:
|
||||||
|
rabbitmq.publish(queue, TrainDto(is_spam=bool(req['is_spam']), text=req['text']).to_json())
|
||||||
|
|
||||||
class AdminRestartHandler(tornado.web.RequestHandler):
|
class AdminRestartHandler(tornado.web.RequestHandler):
|
||||||
@tornado.gen.coroutine
|
@tornado.gen.coroutine
|
||||||
def post(self):
|
def post(self):
|
||||||
@ -46,15 +58,14 @@ def start(port: int, token: str, fucking_path: str, backup_path: str) -> None:
|
|||||||
_spam_detector = _create_spam_detector()
|
_spam_detector = _create_spam_detector()
|
||||||
else:
|
else:
|
||||||
self.set_status(403)
|
self.set_status(403)
|
||||||
self.write({'status': 'fail', 'message': 'Invalid authentication token'})
|
self.write(_json({ 'status': 'fail', 'message': 'Invalid authentication token' }))
|
||||||
self.finish()
|
|
||||||
return
|
|
||||||
|
|
||||||
async def start_web_server():
|
async def start_web_server():
|
||||||
logger.info(f"Starting web server on port {port}")
|
logger.info(f"Starting web server on port {port}")
|
||||||
app = tornado.web.Application(
|
app = tornado.web.Application(
|
||||||
[
|
[
|
||||||
(r"/check-spam", CheckSpamHandler),
|
(r"/check-spam", CheckSpamHandler),
|
||||||
|
(r"/admin/train", AdminTrainHandler),
|
||||||
(r"/admin/restart", AdminRestartHandler)
|
(r"/admin/restart", AdminRestartHandler)
|
||||||
],
|
],
|
||||||
template_path=os.path.join(os.path.dirname(__file__), "templates"),
|
template_path=os.path.join(os.path.dirname(__file__), "templates"),
|
||||||
|
Loading…
Reference in New Issue
Block a user