Dockerfile for faster startup, model-initer in compose, fixes and updates
This commit is contained in:
parent 3a1b9cc4c8
commit 6d8302ec1f
.gitignore (vendored, 2 changes)
@@ -4,3 +4,5 @@ env/**
 .idea/**
 models
+models/**
 samples
+samples/**
Dockerfile
@@ -8,7 +8,6 @@ WORKDIR /app

 COPY lib /app/lib
 COPY requirements.txt /app
-COPY src /app/src


 #######################################################################################################################
@@ -35,6 +34,8 @@ RUN apk add libgomp libstdc++
 RUN mkdir /app/nltk_data
 RUN ln -s /app/nltk_data /root/nltk_data

+COPY src /app/src
+
 ENV PYTHONPATH=/app

 ARG DATASET
docker-compose.yaml
@@ -10,6 +10,8 @@ services:
   decision-maker:
     build: ./
     container_name: spam-detector-decision-maker
+    depends_on:
+      - model-initer
     environment:
       - TOKEN=token12345
       - RABBITMQ_HOST=${RABBITMQ_HOST}
@@ -30,6 +32,24 @@ services:
       - spam-detector-internal
       - rabbitmq

+  model-initer:
+    build: ./
+    container_name: spam-detector-model-initer
+    environment:
+      - TOKEN=token12345
+      - RABBITMQ_HOST=${RABBITMQ_HOST}
+      - RABBITMQ_PORT=${RABBITMQ_PORT:-5672}
+      - RABBITMQ_USER=${RABBITMQ_USER}
+      - RABBITMQ_PASS=${RABBITMQ_PASS}
+      - RABBITMQ_QUEUE=${RABBITMQ_QUEUE}
+    entrypoint: [ "python", "-m", "src.app", "-i", "-d", "/app/dataset.csv" ]
+    volumes:
+      - type: bind
+        source: $DATASET
+        target: /app/dataset.csv
+      - ${MODELS_DIR}:/app/models
+    restart: no
+
   model-updater:
     build: ./
     container_name: spam-detector-model-updater
poetry.lock (generated, new file): 1134 additions, file diff suppressed because it is too large
rabbitmq/.gitignore (vendored, new file, 2 lines)
@@ -0,0 +1,2 @@
+data
+data/**
rabbitmq/docker-compose.yaml (new file, 20 lines)
@@ -0,0 +1,20 @@
+networks:
+  rabbitmq:
+    external: true
+
+services:
+  rabbitmq:
+    image: rabbitmq:3.13.6-management
+    hostname: rabbitmq
+    restart: always
+    environment:
+      - RABBITMQ_DEFAULT_USER=rmuser
+      - RABBITMQ_DEFAULT_PASS=${RABBITMQ_PASS}
+      - RABBITMQ_SERVER_ADDITIONAL_ERL_ARGS=-rabbit log_levels [{connection,error},{default,error}] disk_free_limit 2147483648
+    volumes:
+      - ${CONTAINER}:/var/lib/rabbitmq
+    ports:
+      - 15672:15672
+      - 5672:5672
+    networks:
+      - rabbitmq
rabbitmq/rebuild.sh (new executable file, 5 lines)
@@ -0,0 +1,5 @@
+#!/bin/bash
+
+docker-compose up -d --force-recreate --no-deps --build
+
+
@@ -31,6 +31,8 @@ _rabbitmq_queue = os.getenv("RABBITMQ_QUEUE")
 def start():
     if args.init:
         assert args.dataset is not None, "Dataset is required, run --help"
+        dataset_size = os.path.getsize(args.dataset)
+        print(f"Dataset size, bytes: {dataset_size}")
         model_trainer.train(args.dataset, fucking_path=_fucking_dir, backup_path=_models_dir)
     elif args.decision_maker:
         rabbitmq = RabbitMQ(_rabbitmq_host, _rabbitmq_port, _rabbitmq_user, _rabbitmq_pass)
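Note: the model-initer service defined earlier invokes python -m src.app with -i and -d /app/dataset.csv, and start() above reads args.init, args.dataset and args.decision_maker. A minimal argparse sketch consistent with those names follows; the short flags come from the compose entrypoint, while the long names and help text are assumptions rather than the repository's actual parser.

import argparse

# Hypothetical reconstruction of the parser implied by the compose entrypoint
# and by args.init / args.dataset / args.decision_maker in start().
parser = argparse.ArgumentParser(description="spam-detector entry point (sketch)")
parser.add_argument("-i", "--init", action="store_true",
                    help="train the models once from a dataset, then exit")
parser.add_argument("-d", "--dataset", help="path to the training dataset CSV")
parser.add_argument("--decision-maker", dest="decision_maker", action="store_true",
                    help="consume messages from RabbitMQ and classify them")
args = parser.parse_args()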
src/l/logger.py
@@ -1,5 +1,18 @@
 import logging
+import os

-logging.basicConfig(format="%(asctime)s | %(name)s | %(levelname)s | %(message)s")
-logger = logging.getLogger(__package__)
-logger.setLevel(logging.INFO)
+LEVEL = os.getenv('LOG_LEVEL')
+if LEVEL is None:
+    LEVEL = logging.INFO
+
+#logging.basicConfig(format="%(asctime)s | %(name)s | %(levelname)s | %(message)s")
+logging.basicConfig(
+    format='%(asctime)s,%(msecs)d | %(levelname)-8s | %(name)s:%(funcName)s:%(lineno)d | %(message)s',
+    datefmt='%Y-%m-%d:%H:%M:%S',
+    level=LEVEL
+)
+
+def logger(pkg: str, level: int = logging.INFO):
+    lgr = logging.getLogger(pkg)
+    lgr.setLevel(level)
+    return lgr
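The rewritten logging module swaps the single package-level logger for a logger(pkg, level) factory plus a LOG_LEVEL environment override, which is what lets every other file in this commit replace logger.info(...) with log.info(...). A short usage sketch of that factory as it appears in the rest of the diff (the debug logger name is illustrative):

import logging

from src.l.logger import logger

log = logger(__name__)                                  # level defaults to logging.INFO
log.info("module initialised")

verbose = logger(__name__ + ".debug", logging.DEBUG)    # explicit level
verbose.debug("extra detail")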
@@ -4,12 +4,14 @@ import shutil
 from os.path import isfile, join
 from src.l.logger import logger

+log = logger(__name__)
+
 def _does_file_exist_in_dir(path):
     return any(isfile(join(path, i)) for i in os.listdir(path))


 def apply_latest(fucking_path: str, backup_path: str) -> None:
-    logger.info("Applying models...")
+    log.info("Applying models...")
     if _does_file_exist_in_dir(backup_path):
         files = os.listdir(fucking_path)
         for file in files:
@@ -24,15 +26,15 @@ def train(dataset_path: str, fucking_path: str, backup_path: str) -> None:
     from spam_detector_ai.training.train_models import ModelTrainer

     def _train_model(classifier_type, model_filename, vectoriser_filename, X_train, y_train):
-        logger.info(f'Training {classifier_type}')
-        trainer_ = ModelTrainer(data=None, classifier_type=classifier_type, logger=logger)
+        log.info(f'Training {classifier_type}')
+        trainer_ = ModelTrainer(data=None, classifier_type=classifier_type, logger=log)
         trainer_.train(X_train, y_train)
         trainer_.save_model(model_filename, vectoriser_filename)

-    logger.info(f"Starting to train using data_path={dataset_path}")
+    log.info(f"Starting to train using data_path={dataset_path}")
     # Load and preprocess data once
     # data_path = os.path.join(project_root, 'spam.csv')
-    initial_trainer = ModelTrainer(data_path=dataset_path, logger=logger)
+    initial_trainer = ModelTrainer(data_path=dataset_path, logger=log)
     processed_data = initial_trainer.preprocess_data_()

     # Split the data once
@@ -49,15 +51,15 @@ def train(dataset_path: str, fucking_path: str, backup_path: str) -> None:
     ]

     # Train each model with the pre-split data
-    logger.info(f"Train each model with the pre-split data\n")
+    log.info(f"Train each model with the pre-split data\n")
     for ct, mf, vf in configurations:
         _train_model(ct, mf, vf, X__train, y__train)

-    logger.info("Backing up...")
+    log.info("Backing up...")
     files = os.listdir(backup_path)
     for file in files:
         shutil.rmtree(os.path.join(backup_path, file))
     shutil.copytree(fucking_path, backup_path,
                     copy_function=shutil.copy2,
                     dirs_exist_ok=True)
-    logger.info("Backing up - done")
+    log.info("Backing up - done")
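The first hunk above is cut off inside apply_latest. Based on the helper it defines and on the copytree call in the backup step, the restore direction (copying backed-up models from the mounted backup_path into the spam_detector_ai model directory, here called fucking_path) plausibly looks like the following sketch; this is an assumption, not the repository's exact code.

import os
import shutil
from os.path import isfile, join


def _does_file_exist_in_dir(path):
    return any(isfile(join(path, i)) for i in os.listdir(path))


def apply_latest(fucking_path: str, backup_path: str) -> None:
    # If a backup exists, wipe the library's bundled models and copy the
    # backed-up ones over them; otherwise keep the defaults shipped with the package.
    if not _does_file_exist_in_dir(backup_path):
        return
    for name in os.listdir(fucking_path):
        target = os.path.join(fucking_path, name)
        if os.path.isdir(target):
            shutil.rmtree(target)
        else:
            os.remove(target)
    shutil.copytree(backup_path, fucking_path,
                    copy_function=shutil.copy2, dirs_exist_ok=True)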
@@ -10,24 +10,26 @@ from src.transport.rabbitmq import RabbitMQ
 from src.transport.train_dto import TrainDto
 import threading

+log = logger(__name__)
+
 def _does_file_exist_in_dir(path):
     return len(listdir(path)) > 0


 def _listen_to_trainings(csv_file: str, rabbitmq: RabbitMQ, queue: str) -> None:
     def _callback(ch, method, properties, body):
-        logger.info(f"Message consumed: {body}")
+        log.info(f"Message consumed: {body}")
         dto = TrainDto.from_json(body)
-        logger.info(f"Message read as DTO: {body}")
+        log.info(f"Message read as DTO: {body}")
         with open(csv_file, "a") as f:
-            logger.info(f"Writing to dataset: {body}")
+            log.info(f"Writing to dataset: {body}")
            writer = csv.writer(f, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
             writer.writerow([dto.spam_interpretation(), dto.text])
-        logger.info(f"Message is done: {body}")
+        log.info(f"Message is done: {body}")
     rabbitmq.consume(queue_name=queue, callback=_callback, auto_ack=True)

 def _start_listening_to_trainings(csv_file: str, rabbitmq: RabbitMQ, queue: str) -> threading.Thread:
-    logger.info("Starting listening to trainings")
+    log.info("Starting listening to trainings")
     t = threading.Thread(target=_listen_to_trainings, args=(csv_file, rabbitmq, queue), daemon=True)
     t.start()
     return t
@@ -42,7 +44,7 @@ def _start_scheduling(fucking_path: str, models_dir: str, dataset_path: str, web
             timeout=3
         )
         if response.status_code > 399:
-            logger.warn(f"Unable to restart Web API server: {response.status_code}, {response.text}")
+            log.warn(f"Unable to restart Web API server: {response.status_code}, {response.text}")

     def _train() -> None:
         trainer.train(dataset_path=dataset_path, fucking_path=fucking_path, backup_path=models_dir)
@@ -51,7 +53,7 @@ def _start_scheduling(fucking_path: str, models_dir: str, dataset_path: str, web
     tz_moscow = dt.timezone(dt.timedelta(hours=3))
     scheduler = Scheduler(tzinfo=dt.timezone.utc)
     if not _does_file_exist_in_dir(models_dir):
-        logger.info("Will be updated in 5 seconds...")
+        log.info("Will be updated in 5 seconds...")
         scheduler.once(dt.timedelta(seconds=5), _train)
     scheduler.daily(dt.time(hour=3, minute=0, tzinfo=tz_moscow), _train)
     print(scheduler)
@@ -60,13 +62,13 @@ def _start_scheduling(fucking_path: str, models_dir: str, dataset_path: str, web
         time.sleep(1)

 def _start_scheduling_in_thread(fucking_path: str, models_dir: str, dataset_path: str, web_api_url: str, token: str) -> threading.Thread:
-    logger.info("Starting scheduling in thread")
+    log.info("Starting scheduling in thread")
     t = threading.Thread(target=_start_scheduling, args=(fucking_path, models_dir, dataset_path, web_api_url, token), daemon=True)
     t.start()
     return t

 def start(fucking_path: str, models_dir: str, dataset_path: str, web_api_url: str, token: str, rabbitmq: RabbitMQ, queue: str) -> None:
-    logger.info("Starting...")
+    log.info("Starting...")
     t1 = _start_listening_to_trainings(csv_file=dataset_path, rabbitmq=rabbitmq, queue=queue)
     t2 = _start_scheduling_in_thread(fucking_path=fucking_path, models_dir=models_dir, dataset_path=dataset_path, web_api_url=web_api_url, token=token)
     t1.join()
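The updater ends up running two daemon threads: one consumes training messages from RabbitMQ and appends them to the dataset CSV, the other schedules retraining, once after five seconds when no models exist yet and then daily at 03:00 Moscow time, followed by a restart request to the web API. A condensed sketch of the scheduling pattern visible above; the import and the exec_jobs() polling call are assumptions about the scheduler package, the rest mirrors the diff.

import datetime as dt
import threading
import time

from scheduler import Scheduler   # assumed: the PyPI "scheduler" package matching the API in the diff


def run_schedule(train) -> None:
    tz_moscow = dt.timezone(dt.timedelta(hours=3))
    schedule = Scheduler(tzinfo=dt.timezone.utc)
    schedule.once(dt.timedelta(seconds=5), train)                        # initial training shortly after start
    schedule.daily(dt.time(hour=3, minute=0, tzinfo=tz_moscow), train)   # nightly retrain
    while True:
        schedule.exec_jobs()   # assumed polling call: run whatever is due
        time.sleep(1)


# As in the diff, the loop lives in a daemon thread so start() can join the consumer thread instead.
t = threading.Thread(target=run_schedule, args=(lambda: print("retraining..."),), daemon=True)
t.start()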
src/transport/rabbitmq.py
@@ -3,6 +3,8 @@ import pika
 from pika.adapters.blocking_connection import BlockingChannel
 from src.l.logger import logger

+log = logger(__name__)
+

 class RabbitMQ:
     def __init__(self, host: str, port: int, user: str, passwd: str):
@@ -47,4 +49,4 @@ class RabbitMQ:
             ),
             mandatory=True
         )
-        logger.info(f"Sent message to queue {queue_name}: {message}")
+        log.info(f"Sent message to queue {queue_name}: {message}")
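For reference, a minimal pika-based wrapper with the same surface as the RabbitMQ class used throughout this commit: a blocking consume(queue_name, callback, auto_ack) and a publish that sets mandatory=True. This is a sketch inferred from the calls visible in the diff, not the repository's implementation; the queue_declare calls and durability settings are assumptions.

import pika


class RabbitMQSketch:
    # Hypothetical stand-in for src.transport.rabbitmq.RabbitMQ.

    def __init__(self, host: str, port: int, user: str, passwd: str):
        credentials = pika.PlainCredentials(user, passwd)
        self._connection = pika.BlockingConnection(
            pika.ConnectionParameters(host=host, port=port, credentials=credentials))
        self._channel = self._connection.channel()

    def consume(self, queue_name: str, callback, auto_ack: bool = True) -> None:
        # Blocks the calling thread, which is why the updater runs it in a daemon thread.
        self._channel.queue_declare(queue=queue_name, durable=True)
        self._channel.basic_consume(queue=queue_name,
                                    on_message_callback=callback,   # callback(ch, method, properties, body)
                                    auto_ack=auto_ack)
        self._channel.start_consuming()

    def publish(self, queue_name: str, message: str) -> None:
        self._channel.queue_declare(queue=queue_name, durable=True)
        self._channel.basic_publish(exchange="",
                                    routing_key=queue_name,
                                    body=message,
                                    properties=pika.BasicProperties(delivery_mode=2),  # persistent message
                                    mandatory=True)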
@@ -8,6 +8,7 @@ from src.transport.rabbitmq import RabbitMQ
 from src.transport.train_dto import TrainDto

 _spam_detector = None
+log = logger(__name__)


 def _json(data) -> str:
@@ -15,7 +16,7 @@ def _json(data) -> str:

 def start(port: int, token: str, fucking_path: str, backup_path: str, rabbitmq: RabbitMQ, queue: str) -> None:
     global _spam_detector
-    logger.info("Starting...")
+    log.info("Starting...")

     def _create_spam_detector():
         model_trainer.apply_latest(fucking_path=fucking_path, backup_path=backup_path)
@@ -27,7 +28,9 @@ def start(port: int, token: str, fucking_path: str, backup_path: str, rabbitmq:
     class CheckSpamHandler(tornado.web.RequestHandler):
         @tornado.gen.coroutine
         def post(self):
+            log.info(f"CheckSpamHandler: Received {self.request.body}")
             body = json.loads(self.request.body)
+            self.set_header("Content-Type", "application/json")
             if not 'text' in body:
                 self.set_status(400)
                 self.write_error(400, body=_json({ "error": "text is not specified" }))
@@ -39,6 +42,8 @@ def start(port: int, token: str, fucking_path: str, backup_path: str, rabbitmq:
         @tornado.gen.coroutine
         def post(self):
             req = json.loads(self.request.body)
+            log.info(f"AdminTrainHandler: Received {self.request.body}")
+            self.set_header("Content-Type", "application/json")
             if not 'is_spam' in req or not 'text' in req:
                 self.set_status(400)
                 self.write(_json({ 'status': 'fail', 'message': 'wrong format' }))
@@ -49,11 +54,13 @@ def start(port: int, token: str, fucking_path: str, backup_path: str, rabbitmq:
         @tornado.gen.coroutine
         def post(self):
             global _spam_detector
+            log.info(f"AdminRestartHandler: Received {self.request.body}")
             auth_header = self.request.headers.get('Authorization')
             if auth_header:
                 auth_token = auth_header.split(" ")[1]
             else:
                 auth_token = ''
+            self.set_header("Content-Type", "application/json")
             if auth_token == token:
                 _spam_detector = _create_spam_detector()
             else:
@@ -61,7 +68,7 @@ def start(port: int, token: str, fucking_path: str, backup_path: str, rabbitmq:
                 self.write(_json({ 'status': 'fail', 'message': 'Invalid authentication token' }))

     async def start_web_server():
-        logger.info(f"Starting web server on port {port}")
+        log.info(f"Starting web server on port {port}")
         app = tornado.web.Application(
             [
                 (r"/check-spam", CheckSpamHandler),
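The handlers above log the raw request body and set a JSON content type; AdminRestartHandler additionally compares a Bearer-style Authorization header against the shared token before reloading the models. A condensed, self-contained sketch of that token-check pattern; the route, port and response bodies are illustrative, and the real handler goes on to rebuild the spam detector.

import json

import tornado.ioloop
import tornado.web

TOKEN = "token12345"   # same value the compose file passes via the TOKEN env variable


class AdminHandlerSketch(tornado.web.RequestHandler):
    def post(self):
        auth_header = self.request.headers.get("Authorization")
        # Expect "Bearer <token>"; treat anything else as an empty token.
        parts = (auth_header or "").split(" ")
        auth_token = parts[1] if len(parts) > 1 else ""
        self.set_header("Content-Type", "application/json")
        if auth_token == TOKEN:
            self.write(json.dumps({"status": "ok"}))
        else:
            self.set_status(401)
            self.write(json.dumps({"status": "fail", "message": "Invalid authentication token"}))


def make_app() -> tornado.web.Application:
    return tornado.web.Application([(r"/admin/restart", AdminHandlerSketch)])   # route is illustrative


if __name__ == "__main__":
    make_app().listen(8888)   # illustrative port
    tornado.ioloop.IOLoop.current().start()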