Dockerfile for faster startup, model-initer in compose, fixes and updates

commit 6d8302ec1f (parent 3a1b9cc4c8)
.gitignore (vendored), 2 additions

@@ -4,3 +4,5 @@ env/**
 .idea/**
 models
 models/**
+samples
+samples/**
Dockerfile

@@ -8,7 +8,6 @@ WORKDIR /app
 
 COPY lib /app/lib
 COPY requirements.txt /app
-COPY src /app/src
 
 
 #######################################################################################################################
@@ -35,6 +34,8 @@ RUN apk add libgomp libstdc++
 RUN mkdir /app/nltk_data
 RUN ln -s /app/nltk_data /root/nltk_data
 
+COPY src /app/src
+
 ENV PYTHONPATH=/app
 
 ARG DATASET
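Copying src after the dependency layers means an edit to the application code only invalidates the final COPY layer (and anything after it); the cached apk and dependency layers above it are reused. A minimal rebuild sketch under that assumption, using the decision-maker service name from the compose file:

```bash
# Rebuild after a source change: dependency layers come from cache,
# only the trailing "COPY src /app/src" layer (and later steps) rebuild.
docker-compose build decision-maker
```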
docker-compose.yaml

@@ -10,6 +10,8 @@ services:
   decision-maker:
     build: ./
     container_name: spam-detector-decision-maker
+    depends_on:
+      - model-initer
     environment:
       - TOKEN=token12345
       - RABBITMQ_HOST=${RABBITMQ_HOST}
@@ -30,6 +32,24 @@ services:
       - spam-detector-internal
       - rabbitmq
 
+  model-initer:
+    build: ./
+    container_name: spam-detector-model-initer
+    environment:
+      - TOKEN=token12345
+      - RABBITMQ_HOST=${RABBITMQ_HOST}
+      - RABBITMQ_PORT=${RABBITMQ_PORT:-5672}
+      - RABBITMQ_USER=${RABBITMQ_USER}
+      - RABBITMQ_PASS=${RABBITMQ_PASS}
+      - RABBITMQ_QUEUE=${RABBITMQ_QUEUE}
+    entrypoint: [ "python", "-m", "src.app", "-i", "-d", "/app/dataset.csv" ]
+    volumes:
+      - type: bind
+        source: $DATASET
+        target: /app/dataset.csv
+      - ${MODELS_DIR}:/app/models
+    restart: no
+
   model-updater:
     build: ./
     container_name: spam-detector-model-updater
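With the new model-initer service and the depends_on edge, the stack can be brought up in one command once the variables referenced in the compose file are set. A hedged usage sketch; the paths and credentials below are placeholders, not values from the repository:

```bash
# Hypothetical example values for the variables the compose file reads.
export DATASET=./data/spam.csv
export MODELS_DIR=./models
export RABBITMQ_HOST=rabbitmq RABBITMQ_USER=rmuser RABBITMQ_PASS=changeme RABBITMQ_QUEUE=train
docker compose up -d --build
```

Note that a plain depends_on only orders container start-up; it does not make decision-maker wait for the initializer to finish training the models.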
poetry.lock (generated, new file), 1134 lines: diff suppressed because it is too large
rabbitmq/.gitignore (vendored, new file), 2 additions

@@ -0,0 +1,2 @@
+data
+data/**
rabbitmq/docker-compose.yaml (new file), 20 additions

@@ -0,0 +1,20 @@
+networks:
+  rabbitmq:
+    external: true
+
+services:
+  rabbitmq:
+    image: rabbitmq:3.13.6-management
+    hostname: rabbitmq
+    restart: always
+    environment:
+      - RABBITMQ_DEFAULT_USER=rmuser
+      - RABBITMQ_DEFAULT_PASS=${RABBITMQ_PASS}
+      - RABBITMQ_SERVER_ADDITIONAL_ERL_ARGS=-rabbit log_levels [{connection,error},{default,error}] disk_free_limit 2147483648
+    volumes:
+      - ${CONTAINER}:/var/lib/rabbitmq
+    ports:
+      - 15672:15672
+      - 5672:5672
+    networks:
+      - rabbitmq
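The rabbitmq network is declared external, so it must exist before this compose file is started. A minimal sketch, assuming the gitignored data directory is used as the bind mount and with a placeholder password:

```bash
# Create the external network once, then start the broker.
docker network create rabbitmq
CONTAINER=./data RABBITMQ_PASS=changeme docker-compose up -d
```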
rabbitmq/rebuild.sh (new executable file), 5 additions

@@ -0,0 +1,5 @@
+#!/bin/bash
+
+docker-compose up -d --force-recreate --no-deps --build
+
+
@@ -31,6 +31,8 @@ _rabbitmq_queue = os.getenv("RABBITMQ_QUEUE")
 def start():
     if args.init:
         assert args.dataset is not None, "Dataset is required, run --help"
+        dataset_size = os.path.getsize(args.dataset)
+        print(f"Dataset size, bytes: {dataset_size}")
         model_trainer.train(args.dataset, fucking_path=_fucking_dir, backup_path=_models_dir)
     elif args.decision_maker:
         rabbitmq = RabbitMQ(_rabbitmq_host, _rabbitmq_port, _rabbitmq_user, _rabbitmq_pass)
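The init path now reports the dataset size before training. The same invocation the model-initer service uses can also be run directly; the dataset path below is a placeholder:

```bash
# -i selects the init/training mode, -d points at the CSV dataset.
python -m src.app -i -d ./dataset.csv
```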
src/l/logger.py

@@ -1,5 +1,18 @@
 import logging
+import os
 
-logging.basicConfig(format="%(asctime)s | %(name)s | %(levelname)s | %(message)s")
-logger = logging.getLogger(__package__)
-logger.setLevel(logging.INFO)
+LEVEL = os.getenv('LOG_LEVEL')
+if LEVEL is None:
+    LEVEL = logging.INFO
+
+#logging.basicConfig(format="%(asctime)s | %(name)s | %(levelname)s | %(message)s")
+logging.basicConfig(
+    format='%(asctime)s,%(msecs)d | %(levelname)-8s | %(name)s:%(funcName)s:%(lineno)d | %(message)s',
+    datefmt='%Y-%m-%d:%H:%M:%S',
+    level=LEVEL
+)
+
+def logger(pkg: str, level: int = logging.INFO):
+    lgr = logging.getLogger(pkg)
+    lgr.setLevel(level)
+    return lgr
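The log level is now read from the LOG_LEVEL environment variable (falling back to INFO), and each module obtains a named logger via the new logger() factory, as the diffs below show. A hedged example of overriding the level for a local run; the invocation mirrors the compose entrypoint:

```bash
# LOG_LEVEL feeds logging.basicConfig(level=...); standard level names are accepted.
LOG_LEVEL=DEBUG python -m src.app -i -d ./dataset.csv
```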
@@ -4,12 +4,14 @@ import shutil
 from os.path import isfile, join
 from src.l.logger import logger
 
+log = logger(__name__)
+
 
 def _does_file_exist_in_dir(path):
     return any(isfile(join(path, i)) for i in os.listdir(path))
 
 
 def apply_latest(fucking_path: str, backup_path: str) -> None:
-    logger.info("Applying models...")
+    log.info("Applying models...")
     if _does_file_exist_in_dir(backup_path):
         files = os.listdir(fucking_path)
         for file in files:
@@ -24,15 +26,15 @@ def train(dataset_path: str, fucking_path: str, backup_path: str) -> None:
     from spam_detector_ai.training.train_models import ModelTrainer
 
     def _train_model(classifier_type, model_filename, vectoriser_filename, X_train, y_train):
-        logger.info(f'Training {classifier_type}')
-        trainer_ = ModelTrainer(data=None, classifier_type=classifier_type, logger=logger)
+        log.info(f'Training {classifier_type}')
+        trainer_ = ModelTrainer(data=None, classifier_type=classifier_type, logger=log)
         trainer_.train(X_train, y_train)
         trainer_.save_model(model_filename, vectoriser_filename)
 
-    logger.info(f"Starting to train using data_path={dataset_path}")
+    log.info(f"Starting to train using data_path={dataset_path}")
     # Load and preprocess data once
     # data_path = os.path.join(project_root, 'spam.csv')
-    initial_trainer = ModelTrainer(data_path=dataset_path, logger=logger)
+    initial_trainer = ModelTrainer(data_path=dataset_path, logger=log)
     processed_data = initial_trainer.preprocess_data_()
 
     # Split the data once
@@ -49,15 +51,15 @@ def train(dataset_path: str, fucking_path: str, backup_path: str) -> None:
     ]
 
     # Train each model with the pre-split data
-    logger.info(f"Train each model with the pre-split data\n")
+    log.info(f"Train each model with the pre-split data\n")
     for ct, mf, vf in configurations:
         _train_model(ct, mf, vf, X__train, y__train)
 
-    logger.info("Backing up...")
+    log.info("Backing up...")
     files = os.listdir(backup_path)
     for file in files:
         shutil.rmtree(os.path.join(backup_path, file))
     shutil.copytree(fucking_path, backup_path,
                     copy_function=shutil.copy2,
                     dirs_exist_ok=True)
-    logger.info("Backing up - done")
+    log.info("Backing up - done")
@@ -10,24 +10,26 @@ from src.transport.rabbitmq import RabbitMQ
 from src.transport.train_dto import TrainDto
 import threading
 
+log = logger(__name__)
+
 
 def _does_file_exist_in_dir(path):
     return len(listdir(path)) > 0
 
 def _listen_to_trainings(csv_file: str, rabbitmq: RabbitMQ, queue: str) -> None:
     def _callback(ch, method, properties, body):
-        logger.info(f"Message consumed: {body}")
+        log.info(f"Message consumed: {body}")
         dto = TrainDto.from_json(body)
-        logger.info(f"Message read as DTO: {body}")
+        log.info(f"Message read as DTO: {body}")
         with open(csv_file, "a") as f:
-            logger.info(f"Writing to dataset: {body}")
+            log.info(f"Writing to dataset: {body}")
             writer = csv.writer(f, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
             writer.writerow([dto.spam_interpretation(), dto.text])
-            logger.info(f"Message is done: {body}")
+            log.info(f"Message is done: {body}")
     rabbitmq.consume(queue_name=queue, callback=_callback, auto_ack=True)
 
 def _start_listening_to_trainings(csv_file: str, rabbitmq: RabbitMQ, queue: str) -> threading.Thread:
-    logger.info("Starting listening to trainings")
+    log.info("Starting listening to trainings")
     t = threading.Thread(target=_listen_to_trainings, args=(csv_file, rabbitmq, queue), daemon=True)
     t.start()
     return t
@@ -42,7 +44,7 @@ def _start_scheduling(fucking_path: str, models_dir: str, dataset_path: str, web_api_url: str, token: str) -> None:
         timeout=3
     )
     if response.status_code > 399:
-        logger.warn(f"Unable to restart Web API server: {response.status_code}, {response.text}")
+        log.warn(f"Unable to restart Web API server: {response.status_code}, {response.text}")
 
     def _train() -> None:
         trainer.train(dataset_path=dataset_path, fucking_path=fucking_path, backup_path=models_dir)
@@ -51,7 +53,7 @@ def _start_scheduling(fucking_path: str, models_dir: str, dataset_path: str, web_api_url: str, token: str) -> None:
     tz_moscow = dt.timezone(dt.timedelta(hours=3))
     scheduler = Scheduler(tzinfo=dt.timezone.utc)
     if not _does_file_exist_in_dir(models_dir):
-        logger.info("Will be updated in 5 seconds...")
+        log.info("Will be updated in 5 seconds...")
         scheduler.once(dt.timedelta(seconds=5), _train)
     scheduler.daily(dt.time(hour=3, minute=0, tzinfo=tz_moscow), _train)
     print(scheduler)
@@ -60,13 +62,13 @@ def _start_scheduling(fucking_path: str, models_dir: str, dataset_path: str, web_api_url: str, token: str) -> None:
         time.sleep(1)
 
 def _start_scheduling_in_thread(fucking_path: str, models_dir: str, dataset_path: str, web_api_url: str, token: str) -> threading.Thread:
-    logger.info("Starting scheduling in thread")
+    log.info("Starting scheduling in thread")
     t = threading.Thread(target=_start_scheduling, args=(fucking_path, models_dir, dataset_path, web_api_url, token), daemon=True)
     t.start()
     return t
 
 def start(fucking_path: str, models_dir: str, dataset_path: str, web_api_url: str, token: str, rabbitmq: RabbitMQ, queue: str) -> None:
-    logger.info("Starting...")
+    log.info("Starting...")
     t1 = _start_listening_to_trainings(csv_file=dataset_path, rabbitmq=rabbitmq, queue=queue)
     t2 = _start_scheduling_in_thread(fucking_path=fucking_path, models_dir=models_dir, dataset_path=dataset_path, web_api_url=web_api_url, token=token)
     t1.join()
src/transport/rabbitmq.py

@@ -3,6 +3,8 @@ import pika
 from pika.adapters.blocking_connection import BlockingChannel
 from src.l.logger import logger
 
+log = logger(__name__)
+
 
 class RabbitMQ:
     def __init__(self, host: str, port: int, user: str, passwd: str):
@@ -47,4 +49,4 @@ class RabbitMQ:
            ),
            mandatory=True
        )
-        logger.info(f"Sent message to queue {queue_name}: {message}")
+        log.info(f"Sent message to queue {queue_name}: {message}")
@@ -8,6 +8,7 @@ from src.transport.rabbitmq import RabbitMQ
 from src.transport.train_dto import TrainDto
 
 _spam_detector = None
+log = logger(__name__)
 
 
 def _json(data) -> str:
@@ -15,7 +16,7 @@ def _json(data) -> str:
 
 def start(port: int, token: str, fucking_path: str, backup_path: str, rabbitmq: RabbitMQ, queue: str) -> None:
     global _spam_detector
-    logger.info("Starting...")
+    log.info("Starting...")
 
     def _create_spam_detector():
         model_trainer.apply_latest(fucking_path=fucking_path, backup_path=backup_path)
@@ -27,7 +28,9 @@ def start(port: int, token: str, fucking_path: str, backup_path: str, rabbitmq: RabbitMQ, queue: str) -> None:
     class CheckSpamHandler(tornado.web.RequestHandler):
         @tornado.gen.coroutine
         def post(self):
+            log.info(f"CheckSpamHandler: Received {self.request.body}")
            body = json.loads(self.request.body)
+            self.set_header("Content-Type", "application/json")
            if not 'text' in body:
                self.set_status(400)
                self.write_error(400, body=_json({ "error": "text is not specified" }))
@@ -39,6 +42,8 @@ def start(port: int, token: str, fucking_path: str, backup_path: str, rabbitmq: RabbitMQ, queue: str) -> None:
         @tornado.gen.coroutine
         def post(self):
             req = json.loads(self.request.body)
+            log.info(f"AdminTrainHandler: Received {self.request.body}")
+            self.set_header("Content-Type", "application/json")
             if not 'is_spam' in req or not 'text' in req:
                 self.set_status(400)
                 self.write(_json({ 'status': 'fail', 'message': 'wrong format' }))
@@ -49,11 +54,13 @@ def start(port: int, token: str, fucking_path: str, backup_path: str, rabbitmq: RabbitMQ, queue: str) -> None:
         @tornado.gen.coroutine
         def post(self):
             global _spam_detector
+            log.info(f"AdminRestartHandler: Received {self.request.body}")
             auth_header = self.request.headers.get('Authorization')
             if auth_header:
                 auth_token = auth_header.split(" ")[1]
             else:
                 auth_token = ''
+            self.set_header("Content-Type", "application/json")
             if auth_token == token:
                 _spam_detector = _create_spam_detector()
             else:
@@ -61,7 +68,7 @@ def start(port: int, token: str, fucking_path: str, backup_path: str, rabbitmq: RabbitMQ, queue: str) -> None:
             self.write(_json({ 'status': 'fail', 'message': 'Invalid authentication token' }))
 
     async def start_web_server():
-        logger.info(f"Starting web server on port {port}")
+        log.info(f"Starting web server on port {port}")
         app = tornado.web.Application(
             [
                 (r"/check-spam", CheckSpamHandler),