60 lines
2.6 KiB
Python
60 lines
2.6 KiB
Python
import argparse
|
|
import os
|
|
import src.model.trainer as model_trainer
|
|
import src.web.server as web_server
|
|
import src.model.updater as model_updater
|
|
from src.transport.rabbitmq import RabbitMQ
|
|
|
|
# Command-line interface. Each mode switch uses BooleanOptionalAction, so its
# value is None when omitted, True for --<flag> and False for --no-<flag>.
parser = argparse.ArgumentParser(prog='app.py')

parser.add_argument(
    '-i', '--init',
    action=argparse.BooleanOptionalAction,
    help='Initializing, must be run beforehand, --dataset is required',
)
parser.add_argument(
    '-m', '--decision-maker',
    action=argparse.BooleanOptionalAction,
    help='Start as Decision maker',
)
parser.add_argument(
    '-d', '--dataset',
    help='Path to CSV (ham/spam) dataset',
)
parser.add_argument(
    '-u', '--model-updater',
    action=argparse.BooleanOptionalAction,
    help='Start as Model updater',
)

args = parser.parse_args()
|
|
|
|
# Require at least one mode to be switched on. Truthiness (rather than
# "is not None") is checked because --no-<flag> sets the value to False,
# which is not a mode. parser.error() prints usage and exits with status 2,
# and unlike `assert` it is not stripped when running under `python -O`.
if not (args.init or args.decision_maker or args.model_updater):
    parser.error("No mode set. Run --help")
|
|
|
|
# Runtime configuration read from environment variables.

# HTTP port for the web server. Converted to int so the server receives the
# same type whether or not PORT is set (the default was already an int).
_port = int(os.getenv('PORT', '8080'))

_models_dir = os.getenv("MODELS_DIR")

_fucking_dir = os.getenv("FUCKING_DIR")

_web_api_url = os.getenv("WEB_API_URL")

_token = os.getenv("TOKEN")

_rabbitmq_host = os.getenv("RABBITMQ_HOST")

# Fall back to the standard AMQP port: without a default, int(None) raised
# TypeError at import time, breaking even modes that never use RabbitMQ.
_rabbitmq_port = int(os.getenv("RABBITMQ_PORT", "5672"))

_rabbitmq_user = os.getenv("RABBITMQ_USER")

_rabbitmq_pass = os.getenv("RABBITMQ_PASS")

_rabbitmq_queue = os.getenv("RABBITMQ_QUEUE")
|
|
|
|
|
|
def start():
    """Run the application in the mode selected on the command line.

    Dispatches on the module-level ``args``. Modes are checked in this order
    and only the first enabled one runs:

    * ``--init``: print the dataset file size, then train a model from
      ``--dataset`` with ``model_trainer.train``.
    * ``--decision-maker``: apply the latest model with
      ``model_trainer.apply_latest`` and start ``web_server`` on ``_port``.
    * ``--model-updater``: start ``model_updater`` with ``--dataset``, the
      web API URL and the token.

    The decision-maker and model-updater modes also build a ``RabbitMQ``
    instance from the ``RABBITMQ_*`` environment settings and pass it, along
    with the ``RABBITMQ_QUEUE`` name, to the service they start.

    Invalid invocations are reported with ``parser.error``, which prints
    usage and exits with status 2. These are a missing ``--dataset`` for a
    mode that needs it, or no mode enabled.
    """
    if args.init:
        # Explicit check instead of assert: asserts vanish under `python -O`,
        # which would turn this into a TypeError from os.path.getsize(None).
        if args.dataset is None:
            parser.error("Dataset is required, run --help")
        dataset_size = os.path.getsize(args.dataset)
        print(f"Dataset size, bytes: {dataset_size}")
        model_trainer.train(args.dataset, fucking_path=_fucking_dir, backup_path=_models_dir)

    elif args.decision_maker:
        rabbitmq = RabbitMQ(_rabbitmq_host, _rabbitmq_port, _rabbitmq_user, _rabbitmq_pass)
        model_trainer.apply_latest(fucking_path=_fucking_dir, backup_path=_models_dir)
        web_server.start(port=_port,
                         token=_token,
                         fucking_path=_fucking_dir,
                         backup_path=_models_dir,
                         rabbitmq=rabbitmq,
                         queue=_rabbitmq_queue)

    elif args.model_updater:
        # Validate arguments before constructing the RabbitMQ client, so a
        # usage error does not first create it (in case its constructor
        # connects to the broker).
        if args.dataset is None:
            parser.error("Dataset is required, run --help")
        rabbitmq = RabbitMQ(_rabbitmq_host, _rabbitmq_port, _rabbitmq_user, _rabbitmq_pass)
        model_updater.start(fucking_path=_fucking_dir,
                            models_dir=_models_dir,
                            dataset_path=args.dataset,
                            web_api_url=_web_api_url,
                            token=_token,
                            rabbitmq=rabbitmq,
                            queue=_rabbitmq_queue)

    else:
        # Reached when only --no-<mode> flags were given: those set the mode
        # attributes to False, so nothing would otherwise run.
        parser.error("No mode set. Run --help")
|
|
|
|
|
|
# Script entry point: start the selected mode only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    start()
|