import csv
import datetime as dt
import time
from os import listdir

import requests
from scheduler import Scheduler

from src.l.logger import logger
import src.model.trainer as trainer
from src.transport.rabbitmq import RabbitMQ
from src.transport.train_dto import TrainDto


def _does_file_exist_in_dir(path: str) -> bool:
    """Return True if the directory at *path* contains at least one entry.

    A directory that does not exist is reported as empty (False) instead of
    raising, so a first run without any stored model still triggers training.
    """
    try:
        return bool(listdir(path))
    except FileNotFoundError:
        return False


def _listen_to_trainings(csv_file: str, rabbitmq: RabbitMQ, queue: str) -> None:
    """Subscribe to *queue* and append each received sample to *csv_file*.

    Every message body is parsed with ``TrainDto.from_json`` and written as
    one CSV row of the form ``is_spam,text``.
    """

    def _callback(ch, method, properties, body):
        dto = TrainDto.from_json(body)
        # newline="" lets the csv module control line endings itself; without
        # it, platforms that translate "\n" would produce "\r\r\n" rows.
        with open(csv_file, "a", newline="") as f:
            writer = csv.writer(
                f,
                delimiter=',',
                quotechar='"',
                quoting=csv.QUOTE_MINIMAL,
            )
            writer.writerow([dto.is_spam, dto.text])

    rabbitmq.consume(queue_name=queue, callback=_callback, auto_ack=True)


def start(
    fucking_path: str,
    models_dir: str,
    dataset_path: str,
    web_api_url: str,
    token: str,
    rabbitmq: RabbitMQ,
    queue: str,
) -> None:
    """Collect training samples and retrain the model on a schedule.

    Registers the queue consumer, schedules a daily training run at 03:00
    (UTC+3) and, when *models_dir* holds no entries, an extra one-off run
    five seconds after start-up. Then executes due jobs in an endless loop,
    so this function does not return.

    Args:
        fucking_path: Forwarded to ``trainer.train`` under the same keyword.
        models_dir: Directory checked for existing entries; also forwarded to
            ``trainer.train`` as ``backup_path``.
        dataset_path: CSV file that receives incoming samples and is passed
            to ``trainer.train`` as the dataset.
        web_api_url: Base URL of the Web API; ``/admin/restart`` is appended.
        token: Bearer token sent with the restart request.
        rabbitmq: Transport used to consume training messages.
        queue: Name of the queue to consume from.
    """
    logger.info("Starting...")

    # NOTE(review): if RabbitMQ.consume blocks the calling thread, nothing
    # below this call ever runs and no training is scheduled -- confirm that
    # consume() returns immediately (e.g. consumes on a background thread).
    _listen_to_trainings(csv_file=dataset_path, rabbitmq=rabbitmq, queue=queue)

    def _restart_web_api() -> None:
        """Ask the Web API to restart; failures are logged, never raised."""
        headers = {
            'Authorization': f"Bearer {token}"
        }
        try:
            response = requests.post(
                f"{web_api_url}/admin/restart",
                json={},
                headers=headers,
                timeout=3,
            )
        except requests.RequestException as exc:
            # The restart is best-effort: an unreachable or slow server must
            # not turn an otherwise finished training run into a failed job.
            logger.warning(f"Unable to restart Web API server: {exc}")
            return

        if response.status_code > 399:
            logger.warning(f"Unable to restart Web API server: {response.status_code}, {response.text}")

    def _train() -> None:
        """Run one training pass, then ask the Web API to restart."""
        trainer.train(
            dataset_path=dataset_path,
            fucking_path=fucking_path,
            backup_path=models_dir,
        )
        _restart_web_api()

    tz_moscow = dt.timezone(dt.timedelta(hours=3))
    scheduler = Scheduler(tzinfo=dt.timezone.utc)

    if not _does_file_exist_in_dir(models_dir):
        logger.info("Will be updated in 5 seconds...")
        scheduler.once(dt.timedelta(seconds=5), _train)

    scheduler.daily(dt.time(hour=3, minute=0, tzinfo=tz_moscow), _train)

    print(scheduler)

    # Poll once per second and run whichever jobs have become due.
    while True:
        scheduler.exec_jobs()
        time.sleep(1)