working web server and updater

bvn13 2024-11-01 17:06:38 +03:00
parent 3ad7a72515
commit c6ab8791ca
6 changed files with 5653 additions and 18 deletions

View File

@@ -1,5 +1,6 @@
 # https://stackoverflow.com/a/57374374/2798461
+#######################################################################################################################
 FROM python:3.10-alpine AS dist
 ENV PYTHONUNBUFFERED=1
@@ -10,6 +11,7 @@ COPY requirements.txt /app
 COPY src /app/src
+#######################################################################################################################
 FROM dist AS installation
 RUN apk update
@@ -18,12 +20,14 @@ RUN pip install --no-index --find-links /app/lib nvidia-nccl-cu12
 RUN apk add curl bash gcc g++ cmake make musl-dev
+#######################################################################################################################
 FROM installation AS building
 RUN pip install -r requirements.txt
 RUN apk del gcc g++ cmake make musl-dev
+#######################################################################################################################
 FROM building AS runner
 RUN apk add libgomp libstdc++
@@ -35,12 +39,14 @@ ENV PYTHONPATH=/app
 ARG DATASET
 ARG PORT
+ARG WEB_API_URL
 ENV PORT=${PORT} \
     DATASET=${DATASET} \
     WORKING_DIR=/app/nltk_data \
     FUCKING_DIR=/usr/local/lib/python3.10/site-packages/spam_detector_ai/models \
-    MODELS_DIR=/app/models
+    MODELS_DIR=/app/models \
+    WEB_API_URL=${WEB_API_URL}
 RUN python3 -m src.preparer

View File

@@ -1,3 +1,8 @@
+networks:
+  spam-detector-internal:
+    driver: bridge
 services:
   decision-maker:
@@ -5,22 +10,31 @@ services:
     container_name: spam-detector-decision-maker
     environment:
       - PORT=${PORT:-8080}
+      - TOKEN=token12345
     entrypoint: [ "python", "-m", "src.app", "-m" ]
     volumes:
       - type: bind
        source: $DATASET
        target: /app/dataset.csv
      - ${MODELS_DIR}:/app/models
+    restart: always
+    ports:
+      - "8080:${PORT:-8080}"
+    networks:
+      - spam-detector-internal
  model-updater:
    build: ./
    container_name: spam-detector-model-updater
    environment:
-      - NONE=1
-    # entrypoint: [ "python", "-m", "src.app", "-u", "-d", "/app/dataset.csv" ]
-    entrypoint: [ "bash", "-c", "while true; do sleep 1; done" ]
+      - WEB_API_URL=http://spam-detector-decision-maker:${PORT:-8080}
+      - TOKEN=token12345
+    entrypoint: [ "python", "-m", "src.app", "-u", "-d", "/app/dataset.csv" ]
+    # entrypoint: [ "bash", "-c", "while true; do sleep 1; done" ]
    volumes:
      - type: bind
        source: $DATASET
        target: /app/dataset.csv
      - ${MODELS_DIR}:/app/models/
+    networks:
+      - spam-detector-internal

spam.csv (normal file): 5575 lines added

File diff suppressed because it is too large

View File

@@ -19,6 +19,9 @@ assert (args.init is not None
 _port = 8080 if os.getenv('PORT') is None else os.getenv('PORT')
 _models_dir = os.getenv("MODELS_DIR")
 _fucking_dir = os.getenv("FUCKING_DIR")
+_web_api_url = os.getenv("WEB_API_URL")
+_token = os.getenv("TOKEN")
 def start():
     if args.init:
@@ -26,10 +29,10 @@ def start():
         model_trainer.train(args.dataset, fucking_path=_fucking_dir, backup_path=_models_dir)
     elif args.decision_maker:
         model_trainer.apply_latest(fucking_path=_fucking_dir, backup_path=_models_dir)
-        web_server.start(port=_port)
+        web_server.start(port=_port, token=_token, fucking_path=_fucking_dir, backup_path=_models_dir)
     elif args.model_updater:
         assert args.dataset is not None, "Dataset is required, run --help"
-        model_updater.start(fucking_path=_fucking_dir, models_dir=_models_dir, dataset_path=args.dataset)
+        model_updater.start(fucking_path=_fucking_dir, models_dir=_models_dir, dataset_path=args.dataset, web_api_url=_web_api_url, token=_token)
 if __name__ == '__main__':
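
For context, here is a hypothetical reconstruction of the (unchanged) argument parsing that the compose entrypoints rely on. The flag-to-branch mapping is inferred from this diff and the entrypoints above; the long option names and defaults are assumptions, not part of the commit:

# Assumed sketch of src/app.py's existing argparse setup (not in this diff).
# "-m" selects the decision-maker (web API) branch, "-u" the model-updater
# branch, "-d" passes the dataset path; default=None keeps the
# `args.init is not None` style of assertion meaningful.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--init', action='store_true', default=None)
parser.add_argument('-m', '--decision-maker', dest='decision_maker', action='store_true', default=None)
parser.add_argument('-u', '--model-updater', dest='model_updater', action='store_true', default=None)
parser.add_argument('-d', '--dataset')
args = parser.parse_args()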

View File

@@ -4,23 +4,36 @@ import src.model.trainer as trainer
 from scheduler import Scheduler
 from os import listdir
 import time
+import requests
 def _does_file_exist_in_dir(path):
     return len(listdir(path)) > 0
-def start(fucking_path: str, models_dir: str, dataset_path: str) -> None:
+def start(fucking_path: str, models_dir: str, dataset_path: str, web_api_url: str, token: str) -> None:
     logger.info("Starting...")
-    def train() -> None:
-        trainer.train(dataset_path, fucking_path=fucking_path, backup_path=models_dir)
+    def _restart_web_api() -> None:
+        headers = { 'Authorization': f"Bearer {token}" }
+        response = requests.post(
+            f"{web_api_url}/admin/restart",
+            json={},
+            headers=headers,
+            timeout=3
+        )
+        if response.status_code > 399:
+            logger.warn(f"Unable to restart Web API server: {response.status_code}, {response.text}")
+    def _train() -> None:
+        trainer.train(dataset_path=dataset_path, fucking_path=fucking_path, backup_path=models_dir)
+        _restart_web_api()
     tz_moscow = dt.timezone(dt.timedelta(hours=3))
     scheduler = Scheduler(tzinfo=dt.timezone.utc)
     if not _does_file_exist_in_dir(models_dir):
         logger.info("Will be updated in 5 seconds...")
-        scheduler.once(dt.timedelta(seconds=5), train)
-    scheduler.daily(dt.time(hour=3, minute=0, tzinfo=tz_moscow), train)
+        scheduler.once(dt.timedelta(seconds=5), _train)
+    scheduler.daily(dt.time(hour=3, minute=0, tzinfo=tz_moscow), _train)
     print(scheduler)
     while True:
         scheduler.exec_jobs()

View File

@@ -2,24 +2,30 @@ import asyncio
 import json
 import os
 import tornado
-from spam_detector_ai.prediction.predict import VotingSpamDetector
+import src.model.trainer as model_trainer
 from src.l.logger import logger
-_spam_detector = VotingSpamDetector()
+_spam_detector = None
 def _json(data) -> str:
     return json.dumps(data)
-def start(port: int) -> None:
+def start(port: int, token: str, fucking_path: str, backup_path: str) -> None:
+    global _spam_detector
     logger.info("Starting...")
-    class CheckSpamHandler(tornado.web.RequestHandler):
-        def set_default_headers(self):
-            self.set_header("Access-Control-Allow-Origin", "*")
-        def get(self):
+    def _create_spam_detector():
+        model_trainer.apply_latest(fucking_path=fucking_path, backup_path=backup_path)
+        from spam_detector_ai.prediction.predict import VotingSpamDetector
+        return VotingSpamDetector()
+    _spam_detector = _create_spam_detector()
+    class CheckSpamHandler(tornado.web.RequestHandler):
+        @tornado.gen.coroutine
+        def post(self):
             body = json.loads(self.request.body)
             if not 'text' in body:
                 self.write_error(400, body=_json({"error": "text is not specified"}))
@@ -27,11 +33,29 @@ def start(port: int) -> None:
             r = json.dumps({"is_spam": _spam_detector.is_spam(body['text'])})
             self.write(r)
+    class AdminRestartHandler(tornado.web.RequestHandler):
+        @tornado.gen.coroutine
+        def post(self):
+            global _spam_detector
+            auth_header = self.request.headers.get('Authorization')
+            if auth_header:
+                auth_token = auth_header.split(" ")[1]
+            else:
+                auth_token = ''
+            if auth_token == token:
+                _spam_detector = _create_spam_detector()
+            else:
+                self.set_status(403)
+                self.write({'status': 'fail', 'message': 'Invalid authentication token'})
+                self.finish()
+                return
     async def start_web_server():
         logger.info(f"Starting web server on port {port}")
         app = tornado.web.Application(
             [
                 (r"/check-spam", CheckSpamHandler),
+                (r"/admin/restart", AdminRestartHandler)
             ],
             template_path=os.path.join(os.path.dirname(__file__), "templates"),
             static_path=os.path.join(os.path.dirname(__file__), "static"),
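
A minimal client sketch (not part of the commit) for exercising the two endpoints, assuming the compose defaults above: the decision-maker published on host port 8080 and TOKEN left at its sample value token12345. The second request mirrors the call the model-updater makes over the internal network after each retrain.

# Client sketch against the new Tornado handlers; URL and token are the
# docker-compose defaults shown above, adjust for a real deployment.
import requests

BASE_URL = "http://localhost:8080"   # host port published by the decision-maker service
TOKEN = "token12345"                 # must match the TOKEN env of both services

# POST /check-spam expects a JSON body with a "text" field.
r = requests.post(f"{BASE_URL}/check-spam",
                  json={"text": "Win a free prize, click the link now!"},
                  timeout=3)
print(r.json())                      # e.g. {'is_spam': True}

# POST /admin/restart reloads the latest trained model; Bearer token required.
r = requests.post(f"{BASE_URL}/admin/restart",
                  json={},
                  headers={"Authorization": f"Bearer {TOKEN}"},
                  timeout=3)
print(r.status_code)                 # 403 with a JSON error body if the token does not match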