]> git.immae.eu Git - perso/Immae/Projets/Cryptomonnaies/Cryptoportfolio/Trader.git/blobdiff - main.py
Refactor config parsing
[perso/Immae/Projets/Cryptomonnaies/Cryptoportfolio/Trader.git] / main.py
diff --git a/main.py b/main.py
index 446219247cc2f8c9211032d1c03a9f1d96986a40..b68d5408a800ced65e9c29242fabb169e5d481ba 100644 (file)
--- a/main.py
+++ b/main.py
@@ -1,6 +1,5 @@
 from datetime import datetime
-import argparse
-import configparser
+import configargparse
 import psycopg2
 import os
 import sys
@@ -80,31 +79,35 @@ def fetch_markets(pg_config, user):
     for row in cursor:
         yield row
 
-def parse_config(config_file):
-    config = configparser.ConfigParser()
-    config.read(config_file)
+def parse_config(args):
+    pg_config = {
+            "host": args.db_host,
+            "port": args.db_port,
+            "user": args.db_user,
+            "password": args.db_password,
+            "database": args.db_database,
+            }
+    del(args.db_host)
+    del(args.db_port)
+    del(args.db_user)
+    del(args.db_password)
+    del(args.db_database)
 
-    if "postgresql" not in config:
-        print("no configuration for postgresql in config file")
-        sys.exit(1)
+    report_path = args.report_path
 
-    if "app" in config and "report_path" in config["app"]:
-        report_path = config["app"]["report_path"]
+    if report_path is not None and not \
+            os.path.exists(report_path):
+        os.makedirs(report_path)
 
-        if not os.path.exists(report_path):
-            os.makedirs(report_path)
-    else:
-        report_path = None
-
-    return [config["postgresql"], report_path]
+    return pg_config
 
 def parse_args(argv):
-    parser = argparse.ArgumentParser(
-            description="Run the trade bot")
+    parser = configargparse.ArgumentParser(
+            description="Run the trade bot.")
 
     parser.add_argument("-c", "--config",
             default="config.ini",
-            required=False,
+            required=False, is_config_file=True,
             help="Config file to load (default: config.ini)")
     parser.add_argument("--before",
             default=False, action='store_const', const=True,
@@ -125,21 +128,32 @@ def parse_args(argv):
             help="Do a different action than trading (add several times to chain)")
     parser.add_argument("--parallel", action='store_true', default=True, dest="parallel")
     parser.add_argument("--no-parallel", action='store_false', dest="parallel")
-
-    args = parser.parse_args(argv)
-
-    if not os.path.exists(args.config):
-        print("no config file found, exiting")
-        sys.exit(1)
-
-    return args
-
-def process(market_config, market_id, user_id, args, report_path, pg_config):
+    parser.add_argument("--report-db", action='store_true', default=True, dest="report_db",
+            help="Store report to database (default)")
+    parser.add_argument("--no-report-db", action='store_false', dest="report_db",
+            help="Don't store report to database")
+    parser.add_argument("--report-path", required=False,
+            help="Where to store the reports (default: absent, don't store)")
+    parser.add_argument("--no-report-path", action='store_const', dest='report_path', const=None,
+            help="Don't store the report to file (default)")
+    parser.add_argument("--db-host", default="localhost",
+            help="Host access to database (default: localhost)")
+    parser.add_argument("--db-port", default=5432,
+            help="Port access to database (default: 5432)")
+    parser.add_argument("--db-user", default="cryptoportfolio",
+            help="User access to database (default: cryptoportfolio)")
+    parser.add_argument("--db-password", default="cryptoportfolio",
+            help="Password access to database (default: cryptoportfolio)")
+    parser.add_argument("--db-database", default="cryptoportfolio",
+            help="Database access to database (default: cryptoportfolio)")
+
+    return parser.parse_args(argv)
+
+def process(market_config, market_id, user_id, args, pg_config):
     try:
         market.Market\
-                .from_config(market_config, args,
-                        pg_config=pg_config, market_id=market_id,
-                        user_id=user_id, report_path=report_path)\
+                .from_config(market_config, args, market_id=market_id,
+                        pg_config=pg_config, user_id=user_id)\
                 .process(args.action, before=args.before, after=args.after)
     except Exception as e:
         print("{}: {}".format(e.__class__.__name__, e))
@@ -147,7 +161,7 @@ def process(market_config, market_id, user_id, args, report_path, pg_config):
 def main(argv):
     args = parse_args(argv)
 
-    pg_config, report_path = parse_config(args.config)
+    pg_config = parse_config(args)
 
     if args.parallel:
         import threading
@@ -159,7 +173,7 @@ def main(argv):
         process_ = process
 
     for market_id, market_config, user_id in fetch_markets(pg_config, args.user):
-        process_(market_config, market_id, user_id, args, report_path, pg_config)
+        process_(market_config, market_id, user_id, args, pg_config)
 
 if __name__ == '__main__': # pragma: no cover
     main(sys.argv[1:])