Dockerfile for faster startup, model-initer in compose, fixes and updates

bvn13 2025-01-20 23:15:23 +03:00
parent 3a1b9cc4c8
commit 6d8302ec1f
13 changed files with 1236 additions and 24 deletions

.gitignore vendored (2 changes)

@@ -4,3 +4,5 @@ env/**
 .idea/**
 models
 models/**
+samples
+samples/**

Dockerfile (COPY src is moved below the dependency layers, so source-only changes no longer invalidate the cached install steps)

@@ -8,7 +8,6 @@ WORKDIR /app
 COPY lib /app/lib
 COPY requirements.txt /app
-COPY src /app/src
 #######################################################################################################################
@@ -35,6 +34,8 @@ RUN apk add libgomp libstdc++
 RUN mkdir /app/nltk_data
 RUN ln -s /app/nltk_data /root/nltk_data
+COPY src /app/src
+
 ENV PYTHONPATH=/app
 ARG DATASET

docker-compose.yml

@@ -10,6 +10,8 @@ services:
   decision-maker:
     build: ./
     container_name: spam-detector-decision-maker
+    depends_on:
+      - model-initer
     environment:
       - TOKEN=token12345
       - RABBITMQ_HOST=${RABBITMQ_HOST}
@@ -30,6 +32,24 @@ services:
       - spam-detector-internal
       - rabbitmq
+  model-initer:
+    build: ./
+    container_name: spam-detector-model-initer
+    environment:
+      - TOKEN=token12345
+      - RABBITMQ_HOST=${RABBITMQ_HOST}
+      - RABBITMQ_PORT=${RABBITMQ_PORT:-5672}
+      - RABBITMQ_USER=${RABBITMQ_USER}
+      - RABBITMQ_PASS=${RABBITMQ_PASS}
+      - RABBITMQ_QUEUE=${RABBITMQ_QUEUE}
+    entrypoint: [ "python", "-m", "src.app", "-i", "-d", "/app/dataset.csv" ]
+    volumes:
+      - type: bind
+        source: $DATASET
+        target: /app/dataset.csv
+      - ${MODELS_DIR}:/app/models
+    restart: "no"   # quoted: bare `no` is a YAML boolean, which compose rejects for restart
   model-updater:
     build: ./
     container_name: spam-detector-model-updater
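
Note that short-form depends_on only delays the decision-maker's start until the model-initer container has been created and started; it does not wait for the one-shot training run to finish (compose would need `condition: service_completed_successfully` for that). A defensive guard the consuming service could add, sketched here as an assumption rather than anything in this commit:

import os
import time

def wait_for_models(path: str = "/app/models", timeout: float = 600.0) -> None:
    # Poll the shared models volume until the initer has written at least one file.
    deadline = time.time() + timeout
    while time.time() < deadline:
        if os.path.isdir(path) and os.listdir(path):
            return
        time.sleep(2)
    raise TimeoutError(f"no model files appeared in {path} within {timeout}s")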

poetry.lock generated Normal file (1134 changes; diff suppressed because the file is too large)

rabbitmq/.gitignore vendored Normal file (2 changes)

@@ -0,0 +1,2 @@
+data
+data/**

rabbitmq/docker-compose.yml Normal file (20 changes)

@@ -0,0 +1,20 @@
+networks:
+  rabbitmq:
+    external: true
+services:
+  rabbitmq:
+    image: rabbitmq:3.13.6-management
+    hostname: rabbitmq
+    restart: always
+    environment:
+      - RABBITMQ_DEFAULT_USER=rmuser
+      - RABBITMQ_DEFAULT_PASS=${RABBITMQ_PASS}
+      - RABBITMQ_SERVER_ADDITIONAL_ERL_ARGS=-rabbit log_levels [{connection,error},{default,error}] disk_free_limit 2147483648
+    volumes:
+      - ${CONTAINER}:/var/lib/rabbitmq
+    ports:
+      - 15672:15672
+      - 5672:5672
+    networks:
+      - rabbitmq
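
This exposes the AMQP port (5672) and the management UI (15672); `disk_free_limit 2147483648` sets the broker's free-disk threshold to 2 GiB. A quick connectivity smoke test against this broker, assuming `rmuser` and whatever RABBITMQ_PASS holds in your .env:

import pika

credentials = pika.PlainCredentials("rmuser", "change-me")  # password comes from RABBITMQ_PASS
params = pika.ConnectionParameters(host="localhost", port=5672, credentials=credentials)
with pika.BlockingConnection(params) as connection:
    channel = connection.channel()
    channel.queue_declare(queue="smoke-test", durable=True)  # hypothetical queue name
    print("broker reachable, queue declared")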

rabbitmq/rebuild.sh Executable file (5 changes)

@@ -0,0 +1,5 @@
+#!/bin/bash
+docker-compose up -d --force-recreate --no-deps --build

src/app.py

@@ -31,6 +31,8 @@ _rabbitmq_queue = os.getenv("RABBITMQ_QUEUE")
 def start():
     if args.init:
         assert args.dataset is not None, "Dataset is required, run --help"
+        dataset_size = os.path.getsize(args.dataset)
+        print(f"Dataset size, bytes: {dataset_size}")
         model_trainer.train(args.dataset, fucking_path=_fucking_dir, backup_path=_models_dir)
     elif args.decision_maker:
         rabbitmq = RabbitMQ(_rabbitmq_host, _rabbitmq_port, _rabbitmq_user, _rabbitmq_pass)
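
This `--init` branch is what the new model-initer service triggers with `python -m src.app -i -d /app/dataset.csv`. The flag definitions live outside this hunk; a wiring consistent with that entrypoint would look roughly like:

import argparse

# Hypothetical argparse setup matching the -i / -d flags used in compose;
# the real parser is defined elsewhere in src/app.py.
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--init", action="store_true", help="train models once and exit")
parser.add_argument("-d", "--dataset", help="path to the training CSV")
args = parser.parse_args()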

src/l/logger.py

@@ -1,5 +1,18 @@
 import logging
+import os

-logging.basicConfig(format="%(asctime)s | %(name)s | %(levelname)s | %(message)s")
-logger = logging.getLogger(__package__)
-logger.setLevel(logging.INFO)
+LEVEL = os.getenv('LOG_LEVEL')
+if LEVEL is None:
+    LEVEL = logging.INFO
+
+#logging.basicConfig(format="%(asctime)s | %(name)s | %(levelname)s | %(message)s")
+logging.basicConfig(
+    format='%(asctime)s,%(msecs)d | %(levelname)-8s | %(name)s:%(funcName)s:%(lineno)d | %(message)s',
+    datefmt='%Y-%m-%d:%H:%M:%S',
+    level=LEVEL
+)
+
+
+def logger(pkg: str, level: int = logging.INFO):
+    lgr = logging.getLogger(pkg)
+    lgr.setLevel(level)
+    return lgr
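
Call sites throughout this commit switch from the removed package-level `logger` object to this factory. LOG_LEVEL may hold a level name such as DEBUG, since `basicConfig` accepts both names and numbers:

from src.l.logger import logger

log = logger(__name__)  # per-module logger, default level INFO
log.info("starting")    # e.g. 2025-01-20:23:15:23,123 | INFO     | src.app:start:34 | starting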

(model trainer module; filename not shown)

@@ -4,12 +4,14 @@ import shutil
 from os.path import isfile, join
 from src.l.logger import logger

+log = logger(__name__)
+

 def _does_file_exist_in_dir(path):
     return any(isfile(join(path, i)) for i in os.listdir(path))


 def apply_latest(fucking_path: str, backup_path: str) -> None:
-    logger.info("Applying models...")
+    log.info("Applying models...")
     if _does_file_exist_in_dir(backup_path):
         files = os.listdir(fucking_path)
         for file in files:
@@ -24,15 +26,15 @@ def train(dataset_path: str, fucking_path: str, backup_path: str) -> None:
     from spam_detector_ai.training.train_models import ModelTrainer

     def _train_model(classifier_type, model_filename, vectoriser_filename, X_train, y_train):
-        logger.info(f'Training {classifier_type}')
-        trainer_ = ModelTrainer(data=None, classifier_type=classifier_type, logger=logger)
+        log.info(f'Training {classifier_type}')
+        trainer_ = ModelTrainer(data=None, classifier_type=classifier_type, logger=log)
         trainer_.train(X_train, y_train)
         trainer_.save_model(model_filename, vectoriser_filename)

-    logger.info(f"Starting to train using data_path={dataset_path}")
+    log.info(f"Starting to train using data_path={dataset_path}")
     # Load and preprocess data once
     # data_path = os.path.join(project_root, 'spam.csv')
-    initial_trainer = ModelTrainer(data_path=dataset_path, logger=logger)
+    initial_trainer = ModelTrainer(data_path=dataset_path, logger=log)
     processed_data = initial_trainer.preprocess_data_()

     # Split the data once
@@ -49,15 +51,15 @@ def train(dataset_path: str, fucking_path: str, backup_path: str) -> None:
     ]

     # Train each model with the pre-split data
-    logger.info(f"Train each model with the pre-split data\n")
+    log.info(f"Train each model with the pre-split data\n")
     for ct, mf, vf in configurations:
         _train_model(ct, mf, vf, X__train, y__train)

-    logger.info("Backing up...")
+    log.info("Backing up...")
     files = os.listdir(backup_path)
     for file in files:
         shutil.rmtree(os.path.join(backup_path, file))
     shutil.copytree(fucking_path, backup_path,
                     copy_function=shutil.copy2,
                     dirs_exist_ok=True)
-    logger.info("Backing up - done")
+    log.info("Backing up - done")

(model updater module; filename not shown)

@@ -10,24 +10,26 @@ from src.transport.rabbitmq import RabbitMQ
 from src.transport.train_dto import TrainDto
 import threading

+log = logger(__name__)
+

 def _does_file_exist_in_dir(path):
     return len(listdir(path)) > 0


 def _listen_to_trainings(csv_file: str, rabbitmq: RabbitMQ, queue: str) -> None:
     def _callback(ch, method, properties, body):
-        logger.info(f"Message consumed: {body}")
+        log.info(f"Message consumed: {body}")
         dto = TrainDto.from_json(body)
-        logger.info(f"Message read as DTO: {body}")
+        log.info(f"Message read as DTO: {body}")
         with open(csv_file, "a") as f:
-            logger.info(f"Writing to dataset: {body}")
+            log.info(f"Writing to dataset: {body}")
             writer = csv.writer(f, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
             writer.writerow([dto.spam_interpretation(), dto.text])
-        logger.info(f"Message is done: {body}")
+        log.info(f"Message is done: {body}")
     rabbitmq.consume(queue_name=queue, callback=_callback, auto_ack=True)


 def _start_listening_to_trainings(csv_file: str, rabbitmq: RabbitMQ, queue: str) -> threading.Thread:
-    logger.info("Starting listening to trainings")
+    log.info("Starting listening to trainings")
     t = threading.Thread(target=_listen_to_trainings, args=(csv_file, rabbitmq, queue), daemon=True)
     t.start()
     return t
@@ -42,7 +44,7 @@ def _start_scheduling(fucking_path: str, models_dir: str, dataset_path: str, web
             timeout=3
         )
         if response.status_code > 399:
-            logger.warn(f"Unable to restart Web API server: {response.status_code}, {response.text}")
+            log.warn(f"Unable to restart Web API server: {response.status_code}, {response.text}")

     def _train() -> None:
         trainer.train(dataset_path=dataset_path, fucking_path=fucking_path, backup_path=models_dir)
@@ -51,7 +53,7 @@ def _start_scheduling(fucking_path: str, models_dir: str, dataset_path: str, web
     tz_moscow = dt.timezone(dt.timedelta(hours=3))
     scheduler = Scheduler(tzinfo=dt.timezone.utc)
     if not _does_file_exist_in_dir(models_dir):
-        logger.info("Will be updated in 5 seconds...")
+        log.info("Will be updated in 5 seconds...")
         scheduler.once(dt.timedelta(seconds=5), _train)
     scheduler.daily(dt.time(hour=3, minute=0, tzinfo=tz_moscow), _train)
     print(scheduler)
@@ -60,13 +62,13 @@ def _start_scheduling(fucking_path: str, models_dir: str, dataset_path: str, web
         time.sleep(1)


 def _start_scheduling_in_thread(fucking_path: str, models_dir: str, dataset_path: str, web_api_url: str, token: str) -> threading.Thread:
-    logger.info("Starting scheduling in thread")
+    log.info("Starting scheduling in thread")
     t = threading.Thread(target=_start_scheduling, args=(fucking_path, models_dir, dataset_path, web_api_url, token), daemon=True)
     t.start()
     return t


 def start(fucking_path: str, models_dir: str, dataset_path: str, web_api_url: str, token: str, rabbitmq: RabbitMQ, queue: str) -> None:
-    logger.info("Starting...")
+    log.info("Starting...")
     t1 = _start_listening_to_trainings(csv_file=dataset_path, rabbitmq=rabbitmq, queue=queue)
     t2 = _start_scheduling_in_thread(fucking_path=fucking_path, models_dir=models_dir, dataset_path=dataset_path, web_api_url=web_api_url, token=token)
     t1.join()
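
The retraining schedule in `_start_scheduling` builds on the `scheduler` package: a one-shot job five seconds after start when no models exist yet, plus a nightly run at 03:00 Moscow time, driven by a polling loop. A standalone sketch of that pattern with an illustrative job body:

import datetime as dt
import time

from scheduler import Scheduler

def retrain():
    print("retraining models...")  # stand-in for trainer.train(...)

tz_moscow = dt.timezone(dt.timedelta(hours=3))
schedule = Scheduler(tzinfo=dt.timezone.utc)
schedule.once(dt.timedelta(seconds=5), retrain)                       # warm-up run
schedule.daily(dt.time(hour=3, minute=0, tzinfo=tz_moscow), retrain)  # nightly run
while True:
    schedule.exec_jobs()  # execute any jobs that are due
    time.sleep(1)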

src/transport/rabbitmq.py

@@ -3,6 +3,8 @@ import pika
 from pika.adapters.blocking_connection import BlockingChannel
 from src.l.logger import logger

+log = logger(__name__)
+

 class RabbitMQ:

     def __init__(self, host: str, port: int, user: str, passwd: str):
@@ -47,4 +49,4 @@ class RabbitMQ:
             ),
             mandatory=True
         )
-        logger.info(f"Sent message to queue {queue_name}: {message}")
+        log.info(f"Sent message to queue {queue_name}: {message}")

(web API module; filename not shown)

@@ -8,6 +8,7 @@ from src.transport.rabbitmq import RabbitMQ
 from src.transport.train_dto import TrainDto

 _spam_detector = None
+log = logger(__name__)


 def _json(data) -> str:
@@ -15,7 +16,7 @@ def _json(data) -> str:

 def start(port: int, token: str, fucking_path: str, backup_path: str, rabbitmq: RabbitMQ, queue: str) -> None:
     global _spam_detector
-    logger.info("Starting...")
+    log.info("Starting...")

     def _create_spam_detector():
         model_trainer.apply_latest(fucking_path=fucking_path, backup_path=backup_path)
@@ -27,7 +28,9 @@ def start(port: int, token: str, fucking_path: str, backup_path: str, rabbitmq:
     class CheckSpamHandler(tornado.web.RequestHandler):
         @tornado.gen.coroutine
         def post(self):
+            log.info(f"CheckSpamHandler: Received {self.request.body}")
             body = json.loads(self.request.body)
+            self.set_header("Content-Type", "application/json")
             if not 'text' in body:
                 self.set_status(400)
                 self.write_error(400, body=_json({ "error": "text is not specified" }))
@@ -39,6 +42,8 @@ def start(port: int, token: str, fucking_path: str, backup_path: str, rabbitmq:
         @tornado.gen.coroutine
         def post(self):
             req = json.loads(self.request.body)
+            log.info(f"AdminTrainHandler: Received {self.request.body}")
+            self.set_header("Content-Type", "application/json")
             if not 'is_spam' in req or not 'text' in req:
                 self.set_status(400)
                 self.write(_json({ 'status': 'fail', 'message': 'wrong format' }))
@@ -49,11 +54,13 @@ def start(port: int, token: str, fucking_path: str, backup_path: str, rabbitmq:
         @tornado.gen.coroutine
         def post(self):
             global _spam_detector
+            log.info(f"AdminRestartHandler: Received {self.request.body}")
             auth_header = self.request.headers.get('Authorization')
             if auth_header:
                 auth_token = auth_header.split(" ")[1]
             else:
                 auth_token = ''
+            self.set_header("Content-Type", "application/json")
             if auth_token == token:
                 _spam_detector = _create_spam_detector()
             else:
@@ -61,7 +68,7 @@ def start(port: int, token: str, fucking_path: str, backup_path: str, rabbitmq:
                 self.write(_json({ 'status': 'fail', 'message': 'Invalid authentication token' }))

     async def start_web_server():
-        logger.info(f"Starting web server on port {port}")
+        log.info(f"Starting web server on port {port}")
         app = tornado.web.Application(
             [
                 (r"/check-spam", CheckSpamHandler),