53 lines
1.9 KiB
Python
53 lines
1.9 KiB
Python
import csv
|
|
import datetime as dt
|
|
from src.l.logger import logger
|
|
import src.model.trainer as trainer
|
|
from scheduler import Scheduler
|
|
from os import listdir
|
|
import time
|
|
import requests
|
|
from src.transport.rabbitmq import RabbitMQ
|
|
from src.transport.train_dto import TrainDto
|
|
|
|
|
|
def _does_file_exist_in_dir(path):
    """Return True when the directory at *path* contains at least one entry.

    Every directory entry counts, including sub-directories, because the
    check is based on ``os.listdir``.

    A directory that does not exist is reported as empty (``False``) rather
    than raising ``FileNotFoundError``, so callers can treat "missing" and
    "empty" the same way.
    """
    try:
        return bool(listdir(path))
    except FileNotFoundError:
        # Nothing has been written there yet, so there is nothing to find.
        return False
|
|
|
|
def _listen_to_trainings(csv_file: str, rabbitmq: RabbitMQ, queue: str) -> None:
    """Consume *queue* and append each received training sample to *csv_file*.

    Args:
        csv_file: Path of the CSV file; it is opened in append mode once per
            received message.
        rabbitmq: Transport object whose ``consume`` method registers the
            message callback.
        queue: Name of the queue to consume from.
    """

    def _callback(ch, method, properties, body):
        # ch, method and properties belong to the callback signature but are
        # not used; only the message body is decoded.
        dto = TrainDto.from_json(body)
        # The csv module requires newline="" for files handed to csv.writer;
        # without it, platforms that translate "\n" emit "\r\r\n" row endings.
        with open(csv_file, "a", newline="") as f:
            writer = csv.writer(f, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
            writer.writerow([dto.is_spam, dto.text])

    # NOTE(review): with auto_ack=True a message is presumably acknowledged
    # before the callback finishes, so a failure while writing would drop the
    # sample - confirm against the RabbitMQ wrapper.
    rabbitmq.consume(queue_name=queue, callback=_callback, auto_ack=True)
|
|
|
|
def start(fucking_path: str, models_dir: str, dataset_path: str, web_api_url: str, token: str, rabbitmq: RabbitMQ, queue: str) -> None:
    """Collect training samples and retrain the model on a schedule.

    The function registers the queue listener, schedules training, and then
    drives the scheduler in an endless loop, so it does not return.

    Args:
        fucking_path: Forwarded unchanged to ``trainer.train`` under the same
            keyword.
        models_dir: Directory checked for existing entries at start-up and
            passed to ``trainer.train`` as ``backup_path``.
        dataset_path: CSV file that receives queued samples and is passed to
            ``trainer.train`` as ``dataset_path``.
        web_api_url: Base URL of the Web API; ``/admin/restart`` is appended.
        token: Bearer token sent in the ``Authorization`` header.
        rabbitmq: Transport used to consume training samples.
        queue: Name of the queue holding training samples.
    """
    logger.info("Starting...")
    # NOTE(review): if rabbitmq.consume() blocks the calling thread, nothing
    # below this call ever runs - confirm against the RabbitMQ wrapper.
    _listen_to_trainings(csv_file=dataset_path, rabbitmq=rabbitmq, queue=queue)

    def _restart_web_api() -> None:
        """Ask the Web API to restart; any failure is logged, never raised."""
        headers = { 'Authorization': f"Bearer {token}" }
        try:
            response = requests.post(
                f"{web_api_url}/admin/restart",
                json={},
                headers=headers,
                timeout=3
            )
        except requests.RequestException as exc:
            # An unreachable or slow Web API is handled like an HTTP error
            # below: log it and carry on, so a finished training run does not
            # end in an exception inside the scheduled job.
            logger.warn(f"Unable to restart Web API server: {exc}")
            return
        if response.status_code > 399:
            logger.warn(f"Unable to restart Web API server: {response.status_code}, {response.text}")

    def _train() -> None:
        """Run one training pass, then request a Web API restart."""
        trainer.train(dataset_path=dataset_path, fucking_path=fucking_path, backup_path=models_dir)
        _restart_web_api()

    # Fixed UTC+3 offset used for the daily trigger time.
    tz_moscow = dt.timezone(dt.timedelta(hours=3))
    scheduler = Scheduler(tzinfo=dt.timezone.utc)

    if not _does_file_exist_in_dir(models_dir):
        # The models directory has no entries, so schedule an early one-off
        # training run instead of waiting for the daily one.
        logger.info("Will be updated in 5 seconds...")
        scheduler.once(dt.timedelta(seconds=5), _train)

    # Recurring training every day at 03:00 in the UTC+3 zone.
    scheduler.daily(dt.time(hour=3, minute=0, tzinfo=tz_moscow), _train)
    print(scheduler)

    # Poll the scheduler once per second, forever.
    while True:
        scheduler.exec_jobs()
        time.sleep(1)
|