]> git.immae.eu Git - perso/Immae/Projets/Cryptomonnaies/Cryptoportfolio/Trader.git/blobdiff - main.py
Refactor databases access
[perso/Immae/Projets/Cryptomonnaies/Cryptoportfolio/Trader.git] / main.py
diff --git a/main.py b/main.py
index a4612075e53cae929e4b5906b61d8a630345ac85..ee25182f24d54bd3326bd976d579c07c0ee896df 100644 (file)
--- a/main.py
+++ b/main.py
@@ -1,5 +1,5 @@
 import configargparse
-import psycopg2
+import dbs
 import os
 import sys
 
@@ -63,15 +63,12 @@ def get_user_market(config_path, user_id, debug=False):
     if debug:
         args.append("--debug")
     args = parse_args(args)
-    pg_config, redis_config = parse_config(args)
-    market_id, market_config, user_id = list(fetch_markets(pg_config, str(user_id)))[0]
-    return market.Market.from_config(market_config, args,
-            pg_config=pg_config, market_id=market_id,
-            user_id=user_id)
+    parse_config(args)
+    market_id, market_config, user_id = list(fetch_markets(str(user_id)))[0]
+    return market.Market.from_config(market_config, args, user_id=user_id)
 
-def fetch_markets(pg_config, user):
-    connection = psycopg2.connect(**pg_config)
-    cursor = connection.cursor()
+def fetch_markets(user):
+    cursor = dbs.psql.cursor()
 
     if user is None:
         cursor.execute("SELECT id,config,user_id FROM market_configs")
@@ -82,30 +79,11 @@ def fetch_markets(pg_config, user):
         yield row
 
 def parse_config(args):
-    pg_config = {
-            "host": args.db_host,
-            "port": args.db_port,
-            "user": args.db_user,
-            "password": args.db_password,
-            "database": args.db_database,
-            }
-    del(args.db_host)
-    del(args.db_port)
-    del(args.db_user)
-    del(args.db_password)
-    del(args.db_database)
-
-    redis_config = {
-            "host": args.redis_host,
-            "port": args.redis_port,
-            "db": args.redis_database,
-            }
-    if redis_config["host"].startswith("/"):
-        redis_config["unix_socket_path"] = redis_config.pop("host")
-        del(redis_config["port"])
-    del(args.redis_host)
-    del(args.redis_port)
-    del(args.redis_database)
+    if args.db_host is not None:
+        dbs.connect_psql(args)
+
+    if args.redis_host is not None:
+        dbs.connect_redis(args)
 
     report_path = args.report_path
 
@@ -113,8 +91,6 @@ def parse_config(args):
             os.path.exists(report_path):
         os.makedirs(report_path)
 
-    return pg_config, redis_config
-
 def parse_args(argv):
     parser = configargparse.ArgumentParser(
             description="Run the trade bot.")
@@ -176,11 +152,10 @@ def parse_args(argv):
         parsed.action = ["sell_all"]
     return parsed
 
-def process(market_config, market_id, user_id, args, pg_config, redis_config):
+def process(market_config, market_id, user_id, args):
     try:
         market.Market\
                 .from_config(market_config, args, market_id=market_id,
-                        pg_config=pg_config, redis_config=redis_config,
                         user_id=user_id)\
                 .process(args.action, before=args.before, after=args.after)
     except Exception as e:
@@ -189,7 +164,7 @@ def process(market_config, market_id, user_id, args, pg_config, redis_config):
 def main(argv):
     args = parse_args(argv)
 
-    pg_config, redis_config = parse_config(args)
+    parse_config(args)
 
     market.Portfolio.report.set_verbose(not args.quiet)
 
@@ -205,8 +180,8 @@ def main(argv):
     else:
         process_ = process
 
-    for market_id, market_config, user_id in fetch_markets(pg_config, args.user):
-        process_(market_config, market_id, user_id, args, pg_config, redis_config)
+    for market_id, market_config, user_id in fetch_markets(args.user):
+        process_(market_config, market_id, user_id, args)
 
     if args.parallel:
         for thread in threads: