spam-detector/src/model/updater.py
2024-11-01 23:35:45 +03:00

53 lines
1.9 KiB
Python

import csv
import datetime as dt
from src.l.logger import logger
import src.model.trainer as trainer
from scheduler import Scheduler
from os import listdir
import time
import requests
from src.transport.rabbitmq import RabbitMQ
from src.transport.train_dto import TrainDto
def _does_file_exist_in_dir(path):
return len(listdir(path)) > 0
def _listen_to_trainings(csv_file: str, rabbitmq: RabbitMQ, queue: str) -> None:
    """Subscribe to ``queue`` and append every received sample to ``csv_file``.

    Each message body is decoded with ``TrainDto.from_json`` and appended as a
    single CSV row ``[is_spam, text]``.

    Args:
        csv_file: path of the dataset CSV that rows are appended to.
        rabbitmq: transport whose ``consume`` delivers the messages.
        queue: name of the queue to consume from.

    NOTE(review): messages are consumed with ``auto_ack=True``. Presumably this
    means a message counts as handled even if decoding or the write fails.
    Confirm against the RabbitMQ wrapper.
    NOTE(review): whether ``rabbitmq.consume`` blocks or returns immediately is
    not visible here. ``start`` calls this before setting up its scheduler, so
    it relies on this call returning. TODO confirm.
    """
    def _callback(ch, method, properties, body):
        dto = TrainDto.from_json(body)
        # The csv module requires newline="" so that it controls line endings
        # itself. Without it, its "\r\n" terminator becomes "\r\r\n" on Windows.
        # UTF-8 is explicit so non-ASCII sample text does not depend on the
        # platform locale. Assumes the trainer reads the dataset as UTF-8 (confirm).
        with open(csv_file, "a", newline="", encoding="utf-8") as f:
            writer = csv.writer(f, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
            writer.writerow([dto.is_spam, dto.text])

    rabbitmq.consume(queue_name=queue, callback=_callback, auto_ack=True)
def start(fucking_path: str, models_dir: str, dataset_path: str, web_api_url: str, token: str, rabbitmq: RabbitMQ, queue: str) -> None:
    """Run the model updater: collect training samples and retrain on a schedule.

    Steps:
    - Subscribes to ``queue`` so that incoming samples are appended to
      ``dataset_path``.
    - Schedules a training run every day at 03:00 UTC+3.
    - If ``models_dir`` has no entries at startup, also schedules a one-off
      training run 5 seconds after start.
    - After each training run, sends ``POST {web_api_url}/admin/restart`` with
      a bearer token (presumably so the Web API loads the new model; confirm).

    This function never returns: it loops forever, executing due jobs once a
    second.

    NOTE(review): if ``rabbitmq.consume`` blocks (as a blocking consumer loop
    would), nothing below the listener call runs and no training is ever
    scheduled. Confirm that ``RabbitMQ.consume`` runs in the background.
    NOTE(review): if the consumer runs concurrently with training, the dataset
    may be appended to while ``trainer.train`` reads it. Verify this is safe.

    Args:
        fucking_path: forwarded to ``trainer.train`` under the same name.
            Presumably the output path of the trained model (TODO confirm).
        models_dir: directory checked for existing entries; also forwarded to
            ``trainer.train`` as ``backup_path``.
        dataset_path: CSV file that samples are appended to and trained from.
        web_api_url: base URL of the Web API restarted after training.
        token: bearer token sent to the Web API admin endpoint.
        rabbitmq: transport used to consume training samples.
        queue: name of the queue carrying training samples.
    """
    logger.info("Starting...")
    _listen_to_trainings(csv_file=dataset_path, rabbitmq=rabbitmq, queue=queue)

    def _restart_web_api() -> None:
        # Failures are logged, not raised. Training has already finished, and a
        # timeout or unreachable API must not propagate into the scheduler loop.
        headers = {'Authorization': f"Bearer {token}"}
        try:
            response = requests.post(
                f"{web_api_url}/admin/restart",
                json={},
                headers=headers,
                timeout=3,
            )
        except requests.RequestException as err:
            logger.warning(f"Unable to restart Web API server: {err}")
            return
        if response.status_code > 399:
            logger.warning(f"Unable to restart Web API server: {response.status_code}, {response.text}")

    def _train() -> None:
        # Retrain from the accumulated dataset. Only a successful run
        # (no exception) reaches the restart request.
        trainer.train(dataset_path=dataset_path, fucking_path=fucking_path, backup_path=models_dir)
        _restart_web_api()

    # Fixed UTC+3 offset. 03:00 at this offset is 00:00 UTC.
    tz_moscow = dt.timezone(dt.timedelta(hours=3))
    scheduler = Scheduler(tzinfo=dt.timezone.utc)
    if not _does_file_exist_in_dir(models_dir):
        # models_dir is empty: train shortly after startup instead of waiting
        # for the daily run.
        logger.info("Will be updated in 5 seconds...")
        scheduler.once(dt.timedelta(seconds=5), _train)
    scheduler.daily(dt.time(hour=3, minute=0, tzinfo=tz_moscow), _train)
    logger.info(f"\n{scheduler}")
    while True:
        scheduler.exec_jobs()
        time.sleep(1)