working web server and updater
This commit is contained in:
parent
3ad7a72515
commit
c6ab8791ca
@ -1,5 +1,6 @@
|
||||
# https://stackoverflow.com/a/57374374/2798461
|
||||
|
||||
#######################################################################################################################
|
||||
FROM python:3.10-alpine AS dist
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
|
||||
@ -10,6 +11,7 @@ COPY requirements.txt /app
|
||||
COPY src /app/src
|
||||
|
||||
|
||||
#######################################################################################################################
|
||||
FROM dist AS installation
|
||||
|
||||
RUN apk update
|
||||
@ -18,12 +20,14 @@ RUN pip install --no-index --find-links /app/lib nvidia-nccl-cu12
|
||||
RUN apk add curl bash gcc g++ cmake make musl-dev
|
||||
|
||||
|
||||
#######################################################################################################################
|
||||
FROM installation AS building
|
||||
|
||||
RUN pip install -r requirements.txt
|
||||
RUN apk del gcc g++ cmake make musl-dev
|
||||
|
||||
|
||||
#######################################################################################################################
|
||||
FROM building AS runner
|
||||
|
||||
RUN apk add libgomp libstdc++
|
||||
@ -35,12 +39,14 @@ ENV PYTHONPATH=/app
|
||||
|
||||
ARG DATASET
|
||||
ARG PORT
|
||||
ARG WEB_API_URL
|
||||
|
||||
ENV PORT=${PORT} \
|
||||
DATASET=${DATASET} \
|
||||
WORKING_DIR=/app/nltk_data \
|
||||
FUCKING_DIR=/usr/local/lib/python3.10/site-packages/spam_detector_ai/models \
|
||||
MODELS_DIR=/app/models
|
||||
MODELS_DIR=/app/models \
|
||||
WEB_API_URL=${WEB_API_URL}
|
||||
|
||||
RUN python3 -m src.preparer
|
||||
|
||||
|
@ -1,3 +1,8 @@
|
||||
networks:
|
||||
spam-detector-internal:
|
||||
driver: bridge
|
||||
|
||||
|
||||
services:
|
||||
|
||||
decision-maker:
|
||||
@ -5,22 +10,31 @@ services:
|
||||
container_name: spam-detector-decision-maker
|
||||
environment:
|
||||
- PORT=${PORT:-8080}
|
||||
- TOKEN=token12345
|
||||
entrypoint: [ "python", "-m", "src.app", "-m" ]
|
||||
volumes:
|
||||
- type: bind
|
||||
source: $DATASET
|
||||
target: /app/dataset.csv
|
||||
- ${MODELS_DIR}:/app/models
|
||||
restart: always
|
||||
ports:
|
||||
- "8080:${PORT:-8080}"
|
||||
networks:
|
||||
- spam-detector-internal
|
||||
|
||||
model-updater:
|
||||
build: ./
|
||||
container_name: spam-detector-model-updater
|
||||
environment:
|
||||
- NONE=1
|
||||
# entrypoint: [ "python", "-m", "src.app", "-u", "-d", "/app/dataset.csv" ]
|
||||
entrypoint: [ "bash", "-c", "while true; do sleep 1; done" ]
|
||||
- WEB_API_URL=http://spam-detector-decision-maker:${PORT:-8080}
|
||||
- TOKEN=token12345
|
||||
entrypoint: [ "python", "-m", "src.app", "-u", "-d", "/app/dataset.csv" ]
|
||||
# entrypoint: [ "bash", "-c", "while true; do sleep 1; done" ]
|
||||
volumes:
|
||||
- type: bind
|
||||
source: $DATASET
|
||||
target: /app/dataset.csv
|
||||
- ${MODELS_DIR}:/app/models/
|
||||
networks:
|
||||
- spam-detector-internal
|
||||
|
@ -19,6 +19,9 @@ assert (args.init is not None
|
||||
_port = 8080 if os.getenv('PORT') is None else os.getenv('PORT')
|
||||
_models_dir = os.getenv("MODELS_DIR")
|
||||
_fucking_dir = os.getenv("FUCKING_DIR")
|
||||
_web_api_url = os.getenv("WEB_API_URL")
|
||||
_token = os.getenv("TOKEN")
|
||||
|
||||
|
||||
def start():
|
||||
if args.init:
|
||||
@ -26,10 +29,10 @@ def start():
|
||||
model_trainer.train(args.dataset, fucking_path=_fucking_dir, backup_path=_models_dir)
|
||||
elif args.decision_maker:
|
||||
model_trainer.apply_latest(fucking_path=_fucking_dir, backup_path=_models_dir)
|
||||
web_server.start(port=_port)
|
||||
web_server.start(port=_port, token=_token, fucking_path=_fucking_dir, backup_path=_models_dir)
|
||||
elif args.model_updater:
|
||||
assert args.dataset is not None, "Dataset is required, run --help"
|
||||
model_updater.start(fucking_path=_fucking_dir, models_dir=_models_dir, dataset_path=args.dataset)
|
||||
model_updater.start(fucking_path=_fucking_dir, models_dir=_models_dir, dataset_path=args.dataset, web_api_url=_web_api_url, token=_token)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -4,23 +4,36 @@ import src.model.trainer as trainer
|
||||
from scheduler import Scheduler
|
||||
from os import listdir
|
||||
import time
|
||||
import requests
|
||||
|
||||
|
||||
def _does_file_exist_in_dir(path):
|
||||
return len(listdir(path)) > 0
|
||||
|
||||
def start(fucking_path: str, models_dir: str, dataset_path: str) -> None:
|
||||
def start(fucking_path: str, models_dir: str, dataset_path: str, web_api_url: str, token: str) -> None:
|
||||
logger.info("Starting...")
|
||||
|
||||
def train() -> None:
|
||||
trainer.train(dataset_path, fucking_path=fucking_path, backup_path=models_dir)
|
||||
def _restart_web_api() -> None:
|
||||
headers = { 'Authorization': f"Bearer {token}" }
|
||||
response = requests.post(
|
||||
f"{web_api_url}/admin/restart",
|
||||
json={},
|
||||
headers=headers,
|
||||
timeout=3
|
||||
)
|
||||
if response.status_code > 399:
|
||||
logger.warn(f"Unable to restart Web API server: {response.status_code}, {response.text}")
|
||||
|
||||
def _train() -> None:
|
||||
trainer.train(dataset_path=dataset_path, fucking_path=fucking_path, backup_path=models_dir)
|
||||
_restart_web_api()
|
||||
|
||||
tz_moscow = dt.timezone(dt.timedelta(hours=3))
|
||||
scheduler = Scheduler(tzinfo=dt.timezone.utc)
|
||||
if not _does_file_exist_in_dir(models_dir):
|
||||
logger.info("Will be updated in 5 seconds...")
|
||||
scheduler.once(dt.timedelta(seconds=5), train)
|
||||
scheduler.daily(dt.time(hour=3, minute=0, tzinfo=tz_moscow), train)
|
||||
scheduler.once(dt.timedelta(seconds=5), _train)
|
||||
scheduler.daily(dt.time(hour=3, minute=0, tzinfo=tz_moscow), _train)
|
||||
print(scheduler)
|
||||
while True:
|
||||
scheduler.exec_jobs()
|
||||
|
@ -2,24 +2,30 @@ import asyncio
|
||||
import json
|
||||
import os
|
||||
import tornado
|
||||
from spam_detector_ai.prediction.predict import VotingSpamDetector
|
||||
import src.model.trainer as model_trainer
|
||||
from src.l.logger import logger
|
||||
|
||||
|
||||
_spam_detector = VotingSpamDetector()
|
||||
_spam_detector = None
|
||||
|
||||
|
||||
def _json(data) -> str:
|
||||
return json.dumps(data)
|
||||
|
||||
def start(port: int) -> None:
|
||||
def start(port: int, token: str, fucking_path: str, backup_path: str) -> None:
|
||||
global _spam_detector
|
||||
logger.info("Starting...")
|
||||
|
||||
class CheckSpamHandler(tornado.web.RequestHandler):
|
||||
def set_default_headers(self):
|
||||
self.set_header("Access-Control-Allow-Origin", "*")
|
||||
def _create_spam_detector():
|
||||
model_trainer.apply_latest(fucking_path=fucking_path, backup_path=backup_path)
|
||||
from spam_detector_ai.prediction.predict import VotingSpamDetector
|
||||
return VotingSpamDetector()
|
||||
|
||||
def get(self):
|
||||
_spam_detector = _create_spam_detector()
|
||||
|
||||
class CheckSpamHandler(tornado.web.RequestHandler):
|
||||
@tornado.gen.coroutine
|
||||
def post(self):
|
||||
body = json.loads(self.request.body)
|
||||
if not 'text' in body:
|
||||
self.write_error(400, body=_json({"error": "text is not specified"}))
|
||||
@ -27,11 +33,29 @@ def start(port: int) -> None:
|
||||
r = json.dumps({"is_spam": _spam_detector.is_spam(body['text'])})
|
||||
self.write(r)
|
||||
|
||||
class AdminRestartHandler(tornado.web.RequestHandler):
|
||||
@tornado.gen.coroutine
|
||||
def post(self):
|
||||
global _spam_detector
|
||||
auth_header = self.request.headers.get('Authorization')
|
||||
if auth_header:
|
||||
auth_token = auth_header.split(" ")[1]
|
||||
else:
|
||||
auth_token = ''
|
||||
if auth_token == token:
|
||||
_spam_detector = _create_spam_detector()
|
||||
else:
|
||||
self.set_status(403)
|
||||
self.write({'status': 'fail', 'message': 'Invalid authentication token'})
|
||||
self.finish()
|
||||
return
|
||||
|
||||
async def start_web_server():
|
||||
logger.info(f"Starting web server on port {port}")
|
||||
app = tornado.web.Application(
|
||||
[
|
||||
(r"/check-spam", CheckSpamHandler),
|
||||
(r"/admin/restart", AdminRestartHandler)
|
||||
],
|
||||
template_path=os.path.join(os.path.dirname(__file__), "templates"),
|
||||
static_path=os.path.join(os.path.dirname(__file__), "static"),
|
||||
|
Loading…
Reference in New Issue
Block a user