]> git.immae.eu Git - perso/Immae/Projets/Cryptomonnaies/Cryptoportfolio/Trader.git/commitdiff
Cleanup market from_config
authorIsmaël Bouya <ismael.bouya@normalesup.org>
Fri, 23 Mar 2018 22:33:36 +0000 (23:33 +0100)
committerIsmaël Bouya <ismael.bouya@normalesup.org>
Sat, 24 Mar 2018 09:39:52 +0000 (10:39 +0100)
main.py
market.py
test.py

diff --git a/main.py b/main.py
index 3e9828952867dc7034fb588d865d713655d1a112..446219247cc2f8c9211032d1c03a9f1d96986a40 100644 (file)
--- a/main.py
+++ b/main.py
@@ -62,9 +62,11 @@ def make_order(market, value, currency, action="acquire",
 
 def get_user_market(config_path, user_id, debug=False):
     pg_config, report_path = parse_config(config_path)
-    market_config = list(fetch_markets(pg_config, str(user_id)))[0][1]
+    market_id, market_config, user_id = list(fetch_markets(pg_config, str(user_id)))[0]
     args = type('Args', (object,), { "debug": debug, "quiet": False })()
-    return market.Market.from_config(market_config, args, user_id=user_id, report_path=report_path)
+    return market.Market.from_config(market_config, args,
+            pg_config=pg_config, market_id=market_id,
+            user_id=user_id, report_path=report_path)
 
 def fetch_markets(pg_config, user):
     connection = psycopg2.connect(**pg_config)
@@ -132,7 +134,7 @@ def parse_args(argv):
 
     return args
 
-def process(market_id, market_config, user_id, report_path, args, pg_config):
+def process(market_config, market_id, user_id, args, report_path, pg_config):
     try:
         market.Market\
                 .from_config(market_config, args,
@@ -151,13 +153,13 @@ def main(argv):
         import threading
         market.Portfolio.start_worker()
 
-        for row in fetch_markets(pg_config, args.user):
-            threading.Thread(target=process, args=[
-                *row, report_path, args, pg_config
-                ]).start()
+        def process_(*args):
+            threading.Thread(target=process, args=args).start()
     else:
-        for row in fetch_markets(pg_config, args.user):
-            process(*row, report_path, args, pg_config)
+        process_ = process
+
+    for market_id, market_config, user_id in fetch_markets(pg_config, args.user):
+        process_(market_config, market_id, user_id, args, report_path, pg_config)
 
 if __name__ == '__main__': # pragma: no cover
     main(sys.argv[1:])
index 78ced1a209eea10c181dfd429b238ff7ca30c659..496ec45843319f2145f955e811ea49e1842e4761 100644 (file)
--- a/market.py
+++ b/market.py
@@ -14,9 +14,7 @@ class Market:
     trades = None
     balances = None
 
-    def __init__(self, ccxt_instance, args,
-            user_id=None, market_id=None,
-            report_path=None, pg_config=None):
+    def __init__(self, ccxt_instance, args, **kwargs):
         self.args = args
         self.debug = args.debug
         self.ccxt = ccxt_instance
@@ -26,14 +24,11 @@ class Market:
         self.balances = BalanceStore(self)
         self.processor = Processor(self)
 
-        self.user_id = user_id
-        self.market_id = market_id
-        self.report_path = report_path
-        self.pg_config = pg_config
+        for key in ["user_id", "market_id", "report_path", "pg_config"]:
+            setattr(self, key, kwargs.get(key, None))
 
     @classmethod
-    def from_config(cls, config, args,
-            user_id=None, market_id=None, report_path=None, pg_config=None):
+    def from_config(cls, config, args, **kwargs):
         config["apiKey"] = config.pop("key", None)
 
         ccxt_instance = ccxt.poloniexE(config)
@@ -50,9 +45,7 @@ class Market:
         ccxt_instance.session.request = request_wrap.__get__(ccxt_instance.session,
                 ccxt_instance.session.__class__)
 
-        return cls(ccxt_instance, args,
-                user_id=user_id, market_id=market_id,
-                pg_config=pg_config, report_path=report_path)
+        return cls(ccxt_instance, args, **kwargs)
 
     def store_report(self):
         self.report.merge(Portfolio.report)
diff --git a/test.py b/test.py
index 5b9c56c02effd0f7a983f78b7c9b161f614e38bd..637a3054c4d8449b3fb0963e4180e62b0b837567 100644 (file)
--- a/test.py
+++ b/test.py
@@ -3682,7 +3682,7 @@ class MainTest(WebMockTestCase):
             args_mock.after = "after"
             self.assertEqual("", stdout_mock.getvalue())
 
-            main.process(3, "config", 1, "report_path", args_mock, "pg_config")
+            main.process("config", 3, 1, args_mock, "report_path", "pg_config")
 
             market_mock.from_config.assert_has_calls([
                 mock.call("config", args_mock, pg_config="pg_config", market_id=3, user_id=1, report_path="report_path"),
@@ -3719,8 +3719,8 @@ class MainTest(WebMockTestCase):
 
                 self.assertEqual(2, process.call_count)
                 process.assert_has_calls([
-                    mock.call(3, "config1", 1, "report_path", args_mock, "pg_config"),
-                    mock.call(1, "config2", 2, "report_path", args_mock, "pg_config"),
+                    mock.call("config1", 3, 1, args_mock, "report_path", "pg_config"),
+                    mock.call("config2", 1, 2, args_mock, "report_path", "pg_config"),
                     ])
         with self.subTest(parallel=True):
             with mock.patch("main.parse_args") as parse_args,\
@@ -3749,9 +3749,9 @@ class MainTest(WebMockTestCase):
                 self.assertEqual(2, process.call_count)
                 process.assert_has_calls([
                     mock.call.__bool__(),
-                    mock.call(3, "config1", 1, "report_path", args_mock, "pg_config"),
+                    mock.call("config1", 3, 1, args_mock, "report_path", "pg_config"),
                     mock.call.__bool__(),
-                    mock.call(1, "config2", 2, "report_path", args_mock, "pg_config"),
+                    mock.call("config2", 1, 2, args_mock, "report_path", "pg_config"),
                     ])
 
     @mock.patch.object(main.sys, "exit")