working web server and updater
This commit is contained in:
parent
3ad7a72515
commit
c6ab8791ca
@ -1,5 +1,6 @@
|
|||||||
# https://stackoverflow.com/a/57374374/2798461
|
# https://stackoverflow.com/a/57374374/2798461
|
||||||
|
|
||||||
|
#######################################################################################################################
|
||||||
FROM python:3.10-alpine AS dist
|
FROM python:3.10-alpine AS dist
|
||||||
ENV PYTHONUNBUFFERED=1
|
ENV PYTHONUNBUFFERED=1
|
||||||
|
|
||||||
@ -10,6 +11,7 @@ COPY requirements.txt /app
|
|||||||
COPY src /app/src
|
COPY src /app/src
|
||||||
|
|
||||||
|
|
||||||
|
#######################################################################################################################
|
||||||
FROM dist AS installation
|
FROM dist AS installation
|
||||||
|
|
||||||
RUN apk update
|
RUN apk update
|
||||||
@ -18,12 +20,14 @@ RUN pip install --no-index --find-links /app/lib nvidia-nccl-cu12
|
|||||||
RUN apk add curl bash gcc g++ cmake make musl-dev
|
RUN apk add curl bash gcc g++ cmake make musl-dev
|
||||||
|
|
||||||
|
|
||||||
|
#######################################################################################################################
|
||||||
FROM installation AS building
|
FROM installation AS building
|
||||||
|
|
||||||
RUN pip install -r requirements.txt
|
RUN pip install -r requirements.txt
|
||||||
RUN apk del gcc g++ cmake make musl-dev
|
RUN apk del gcc g++ cmake make musl-dev
|
||||||
|
|
||||||
|
|
||||||
|
#######################################################################################################################
|
||||||
FROM building AS runner
|
FROM building AS runner
|
||||||
|
|
||||||
RUN apk add libgomp libstdc++
|
RUN apk add libgomp libstdc++
|
||||||
@ -35,12 +39,14 @@ ENV PYTHONPATH=/app
|
|||||||
|
|
||||||
ARG DATASET
|
ARG DATASET
|
||||||
ARG PORT
|
ARG PORT
|
||||||
|
ARG WEB_API_URL
|
||||||
|
|
||||||
ENV PORT=${PORT} \
|
ENV PORT=${PORT} \
|
||||||
DATASET=${DATASET} \
|
DATASET=${DATASET} \
|
||||||
WORKING_DIR=/app/nltk_data \
|
WORKING_DIR=/app/nltk_data \
|
||||||
FUCKING_DIR=/usr/local/lib/python3.10/site-packages/spam_detector_ai/models \
|
FUCKING_DIR=/usr/local/lib/python3.10/site-packages/spam_detector_ai/models \
|
||||||
MODELS_DIR=/app/models
|
MODELS_DIR=/app/models \
|
||||||
|
WEB_API_URL=${WEB_API_URL}
|
||||||
|
|
||||||
RUN python3 -m src.preparer
|
RUN python3 -m src.preparer
|
||||||
|
|
||||||
|
@ -1,3 +1,8 @@
|
|||||||
|
networks:
|
||||||
|
spam-detector-internal:
|
||||||
|
driver: bridge
|
||||||
|
|
||||||
|
|
||||||
services:
|
services:
|
||||||
|
|
||||||
decision-maker:
|
decision-maker:
|
||||||
@ -5,22 +10,31 @@ services:
|
|||||||
container_name: spam-detector-decision-maker
|
container_name: spam-detector-decision-maker
|
||||||
environment:
|
environment:
|
||||||
- PORT=${PORT:-8080}
|
- PORT=${PORT:-8080}
|
||||||
|
- TOKEN=token12345
|
||||||
entrypoint: [ "python", "-m", "src.app", "-m" ]
|
entrypoint: [ "python", "-m", "src.app", "-m" ]
|
||||||
volumes:
|
volumes:
|
||||||
- type: bind
|
- type: bind
|
||||||
source: $DATASET
|
source: $DATASET
|
||||||
target: /app/dataset.csv
|
target: /app/dataset.csv
|
||||||
- ${MODELS_DIR}:/app/models
|
- ${MODELS_DIR}:/app/models
|
||||||
|
restart: always
|
||||||
|
ports:
|
||||||
|
- "8080:${PORT:-8080}"
|
||||||
|
networks:
|
||||||
|
- spam-detector-internal
|
||||||
|
|
||||||
model-updater:
|
model-updater:
|
||||||
build: ./
|
build: ./
|
||||||
container_name: spam-detector-model-updater
|
container_name: spam-detector-model-updater
|
||||||
environment:
|
environment:
|
||||||
- NONE=1
|
- WEB_API_URL=http://spam-detector-decision-maker:${PORT:-8080}
|
||||||
# entrypoint: [ "python", "-m", "src.app", "-u", "-d", "/app/dataset.csv" ]
|
- TOKEN=token12345
|
||||||
entrypoint: [ "bash", "-c", "while true; do sleep 1; done" ]
|
entrypoint: [ "python", "-m", "src.app", "-u", "-d", "/app/dataset.csv" ]
|
||||||
|
# entrypoint: [ "bash", "-c", "while true; do sleep 1; done" ]
|
||||||
volumes:
|
volumes:
|
||||||
- type: bind
|
- type: bind
|
||||||
source: $DATASET
|
source: $DATASET
|
||||||
target: /app/dataset.csv
|
target: /app/dataset.csv
|
||||||
- ${MODELS_DIR}:/app/models/
|
- ${MODELS_DIR}:/app/models/
|
||||||
|
networks:
|
||||||
|
- spam-detector-internal
|
||||||
|
@ -19,6 +19,9 @@ assert (args.init is not None
|
|||||||
_port = 8080 if os.getenv('PORT') is None else os.getenv('PORT')
|
_port = 8080 if os.getenv('PORT') is None else os.getenv('PORT')
|
||||||
_models_dir = os.getenv("MODELS_DIR")
|
_models_dir = os.getenv("MODELS_DIR")
|
||||||
_fucking_dir = os.getenv("FUCKING_DIR")
|
_fucking_dir = os.getenv("FUCKING_DIR")
|
||||||
|
_web_api_url = os.getenv("WEB_API_URL")
|
||||||
|
_token = os.getenv("TOKEN")
|
||||||
|
|
||||||
|
|
||||||
def start():
|
def start():
|
||||||
if args.init:
|
if args.init:
|
||||||
@ -26,10 +29,10 @@ def start():
|
|||||||
model_trainer.train(args.dataset, fucking_path=_fucking_dir, backup_path=_models_dir)
|
model_trainer.train(args.dataset, fucking_path=_fucking_dir, backup_path=_models_dir)
|
||||||
elif args.decision_maker:
|
elif args.decision_maker:
|
||||||
model_trainer.apply_latest(fucking_path=_fucking_dir, backup_path=_models_dir)
|
model_trainer.apply_latest(fucking_path=_fucking_dir, backup_path=_models_dir)
|
||||||
web_server.start(port=_port)
|
web_server.start(port=_port, token=_token, fucking_path=_fucking_dir, backup_path=_models_dir)
|
||||||
elif args.model_updater:
|
elif args.model_updater:
|
||||||
assert args.dataset is not None, "Dataset is required, run --help"
|
assert args.dataset is not None, "Dataset is required, run --help"
|
||||||
model_updater.start(fucking_path=_fucking_dir, models_dir=_models_dir, dataset_path=args.dataset)
|
model_updater.start(fucking_path=_fucking_dir, models_dir=_models_dir, dataset_path=args.dataset, web_api_url=_web_api_url, token=_token)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -4,23 +4,36 @@ import src.model.trainer as trainer
|
|||||||
from scheduler import Scheduler
|
from scheduler import Scheduler
|
||||||
from os import listdir
|
from os import listdir
|
||||||
import time
|
import time
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
def _does_file_exist_in_dir(path):
|
def _does_file_exist_in_dir(path):
|
||||||
return len(listdir(path)) > 0
|
return len(listdir(path)) > 0
|
||||||
|
|
||||||
def start(fucking_path: str, models_dir: str, dataset_path: str) -> None:
|
def start(fucking_path: str, models_dir: str, dataset_path: str, web_api_url: str, token: str) -> None:
|
||||||
logger.info("Starting...")
|
logger.info("Starting...")
|
||||||
|
|
||||||
def train() -> None:
|
def _restart_web_api() -> None:
|
||||||
trainer.train(dataset_path, fucking_path=fucking_path, backup_path=models_dir)
|
headers = { 'Authorization': f"Bearer {token}" }
|
||||||
|
response = requests.post(
|
||||||
|
f"{web_api_url}/admin/restart",
|
||||||
|
json={},
|
||||||
|
headers=headers,
|
||||||
|
timeout=3
|
||||||
|
)
|
||||||
|
if response.status_code > 399:
|
||||||
|
logger.warn(f"Unable to restart Web API server: {response.status_code}, {response.text}")
|
||||||
|
|
||||||
|
def _train() -> None:
|
||||||
|
trainer.train(dataset_path=dataset_path, fucking_path=fucking_path, backup_path=models_dir)
|
||||||
|
_restart_web_api()
|
||||||
|
|
||||||
tz_moscow = dt.timezone(dt.timedelta(hours=3))
|
tz_moscow = dt.timezone(dt.timedelta(hours=3))
|
||||||
scheduler = Scheduler(tzinfo=dt.timezone.utc)
|
scheduler = Scheduler(tzinfo=dt.timezone.utc)
|
||||||
if not _does_file_exist_in_dir(models_dir):
|
if not _does_file_exist_in_dir(models_dir):
|
||||||
logger.info("Will be updated in 5 seconds...")
|
logger.info("Will be updated in 5 seconds...")
|
||||||
scheduler.once(dt.timedelta(seconds=5), train)
|
scheduler.once(dt.timedelta(seconds=5), _train)
|
||||||
scheduler.daily(dt.time(hour=3, minute=0, tzinfo=tz_moscow), train)
|
scheduler.daily(dt.time(hour=3, minute=0, tzinfo=tz_moscow), _train)
|
||||||
print(scheduler)
|
print(scheduler)
|
||||||
while True:
|
while True:
|
||||||
scheduler.exec_jobs()
|
scheduler.exec_jobs()
|
||||||
|
@ -2,24 +2,30 @@ import asyncio
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import tornado
|
import tornado
|
||||||
from spam_detector_ai.prediction.predict import VotingSpamDetector
|
import src.model.trainer as model_trainer
|
||||||
from src.l.logger import logger
|
from src.l.logger import logger
|
||||||
|
|
||||||
|
|
||||||
_spam_detector = VotingSpamDetector()
|
_spam_detector = None
|
||||||
|
|
||||||
|
|
||||||
def _json(data) -> str:
|
def _json(data) -> str:
|
||||||
return json.dumps(data)
|
return json.dumps(data)
|
||||||
|
|
||||||
def start(port: int) -> None:
|
def start(port: int, token: str, fucking_path: str, backup_path: str) -> None:
|
||||||
|
global _spam_detector
|
||||||
logger.info("Starting...")
|
logger.info("Starting...")
|
||||||
|
|
||||||
class CheckSpamHandler(tornado.web.RequestHandler):
|
def _create_spam_detector():
|
||||||
def set_default_headers(self):
|
model_trainer.apply_latest(fucking_path=fucking_path, backup_path=backup_path)
|
||||||
self.set_header("Access-Control-Allow-Origin", "*")
|
from spam_detector_ai.prediction.predict import VotingSpamDetector
|
||||||
|
return VotingSpamDetector()
|
||||||
|
|
||||||
def get(self):
|
_spam_detector = _create_spam_detector()
|
||||||
|
|
||||||
|
class CheckSpamHandler(tornado.web.RequestHandler):
|
||||||
|
@tornado.gen.coroutine
|
||||||
|
def post(self):
|
||||||
body = json.loads(self.request.body)
|
body = json.loads(self.request.body)
|
||||||
if not 'text' in body:
|
if not 'text' in body:
|
||||||
self.write_error(400, body=_json({"error": "text is not specified"}))
|
self.write_error(400, body=_json({"error": "text is not specified"}))
|
||||||
@ -27,11 +33,29 @@ def start(port: int) -> None:
|
|||||||
r = json.dumps({"is_spam": _spam_detector.is_spam(body['text'])})
|
r = json.dumps({"is_spam": _spam_detector.is_spam(body['text'])})
|
||||||
self.write(r)
|
self.write(r)
|
||||||
|
|
||||||
|
class AdminRestartHandler(tornado.web.RequestHandler):
|
||||||
|
@tornado.gen.coroutine
|
||||||
|
def post(self):
|
||||||
|
global _spam_detector
|
||||||
|
auth_header = self.request.headers.get('Authorization')
|
||||||
|
if auth_header:
|
||||||
|
auth_token = auth_header.split(" ")[1]
|
||||||
|
else:
|
||||||
|
auth_token = ''
|
||||||
|
if auth_token == token:
|
||||||
|
_spam_detector = _create_spam_detector()
|
||||||
|
else:
|
||||||
|
self.set_status(403)
|
||||||
|
self.write({'status': 'fail', 'message': 'Invalid authentication token'})
|
||||||
|
self.finish()
|
||||||
|
return
|
||||||
|
|
||||||
async def start_web_server():
|
async def start_web_server():
|
||||||
logger.info(f"Starting web server on port {port}")
|
logger.info(f"Starting web server on port {port}")
|
||||||
app = tornado.web.Application(
|
app = tornado.web.Application(
|
||||||
[
|
[
|
||||||
(r"/check-spam", CheckSpamHandler),
|
(r"/check-spam", CheckSpamHandler),
|
||||||
|
(r"/admin/restart", AdminRestartHandler)
|
||||||
],
|
],
|
||||||
template_path=os.path.join(os.path.dirname(__file__), "templates"),
|
template_path=os.path.join(os.path.dirname(__file__), "templates"),
|
||||||
static_path=os.path.join(os.path.dirname(__file__), "static"),
|
static_path=os.path.join(os.path.dirname(__file__), "static"),
|
||||||
|
Loading…
Reference in New Issue
Block a user