]> git.immae.eu Git - perso/Immae/Projets/Cryptomonnaies/Cryptoportfolio/Trader.git/commitdiff
Refactor config parsing
authorIsmaël Bouya <ismael.bouya@normalesup.org>
Thu, 5 Apr 2018 07:56:51 +0000 (09:56 +0200)
committerIsmaël Bouya <ismael.bouya@normalesup.org>
Thu, 5 Apr 2018 07:56:51 +0000 (09:56 +0200)
config.ini
main.py
market.py
requirements.txt
store.py
test.py

index 50cbd1b74523d42c202c0709751ff278f1f04208..4d261c47f5717694f4cce51b7944a173f7e707c4 100644 (file)
@@ -1,9 +1,9 @@
 [postgresql]
-host = localhost
-port = 5432
-user = cryptoportfolio
-password = cryptoportfolio
-database = cryptoportfolio
+db-host = localhost
+db-port = 5432
+db-user = cryptoportfolio
+db-password = cryptoportfolio
+db-database = cryptoportfolio
 
 [app]
-report_path = reports
+report-path = reports
diff --git a/main.py b/main.py
index 446219247cc2f8c9211032d1c03a9f1d96986a40..b68d5408a800ced65e9c29242fabb169e5d481ba 100644 (file)
--- a/main.py
+++ b/main.py
@@ -1,6 +1,5 @@
 from datetime import datetime
-import argparse
-import configparser
+import configargparse
 import psycopg2
 import os
 import sys
@@ -80,31 +79,35 @@ def fetch_markets(pg_config, user):
     for row in cursor:
         yield row
 
-def parse_config(config_file):
-    config = configparser.ConfigParser()
-    config.read(config_file)
+def parse_config(args):
+    pg_config = {
+            "host": args.db_host,
+            "port": args.db_port,
+            "user": args.db_user,
+            "password": args.db_password,
+            "database": args.db_database,
+            }
+    del(args.db_host)
+    del(args.db_port)
+    del(args.db_user)
+    del(args.db_password)
+    del(args.db_database)
 
-    if "postgresql" not in config:
-        print("no configuration for postgresql in config file")
-        sys.exit(1)
+    report_path = args.report_path
 
-    if "app" in config and "report_path" in config["app"]:
-        report_path = config["app"]["report_path"]
+    if report_path is not None and not \
+            os.path.exists(report_path):
+        os.makedirs(report_path)
 
-        if not os.path.exists(report_path):
-            os.makedirs(report_path)
-    else:
-        report_path = None
-
-    return [config["postgresql"], report_path]
+    return pg_config
 
 def parse_args(argv):
-    parser = argparse.ArgumentParser(
-            description="Run the trade bot")
+    parser = configargparse.ArgumentParser(
+            description="Run the trade bot.")
 
     parser.add_argument("-c", "--config",
             default="config.ini",
-            required=False,
+            required=False, is_config_file=True,
             help="Config file to load (default: config.ini)")
     parser.add_argument("--before",
             default=False, action='store_const', const=True,
@@ -125,21 +128,32 @@ def parse_args(argv):
             help="Do a different action than trading (add several times to chain)")
     parser.add_argument("--parallel", action='store_true', default=True, dest="parallel")
     parser.add_argument("--no-parallel", action='store_false', dest="parallel")
-
-    args = parser.parse_args(argv)
-
-    if not os.path.exists(args.config):
-        print("no config file found, exiting")
-        sys.exit(1)
-
-    return args
-
-def process(market_config, market_id, user_id, args, report_path, pg_config):
+    parser.add_argument("--report-db", action='store_true', default=True, dest="report_db",
+            help="Store report to database (default)")
+    parser.add_argument("--no-report-db", action='store_false', dest="report_db",
+            help="Don't store report to database")
+    parser.add_argument("--report-path", required=False,
+            help="Where to store the reports (default: absent, don't store)")
+    parser.add_argument("--no-report-path", action='store_const', dest='report_path', const=None,
+            help="Don't store the report to file (default)")
+    parser.add_argument("--db-host", default="localhost",
+            help="Host access to database (default: localhost)")
+    parser.add_argument("--db-port", default=5432,
+            help="Port access to database (default: 5432)")
+    parser.add_argument("--db-user", default="cryptoportfolio",
+            help="User access to database (default: cryptoportfolio)")
+    parser.add_argument("--db-password", default="cryptoportfolio",
+            help="Password access to database (default: cryptoportfolio)")
+    parser.add_argument("--db-database", default="cryptoportfolio",
+            help="Database access to database (default: cryptoportfolio)")
+
+    return parser.parse_args(argv)
+
+def process(market_config, market_id, user_id, args, pg_config):
     try:
         market.Market\
-                .from_config(market_config, args,
-                        pg_config=pg_config, market_id=market_id,
-                        user_id=user_id, report_path=report_path)\
+                .from_config(market_config, args, market_id=market_id,
+                        pg_config=pg_config, user_id=user_id)\
                 .process(args.action, before=args.before, after=args.after)
     except Exception as e:
         print("{}: {}".format(e.__class__.__name__, e))
@@ -147,7 +161,7 @@ def process(market_config, market_id, user_id, args, report_path, pg_config):
 def main(argv):
     args = parse_args(argv)
 
-    pg_config, report_path = parse_config(args.config)
+    pg_config = parse_config(args)
 
     if args.parallel:
         import threading
@@ -159,7 +173,7 @@ def main(argv):
         process_ = process
 
     for market_id, market_config, user_id in fetch_markets(pg_config, args.user):
-        process_(market_config, market_id, user_id, args, report_path, pg_config)
+        process_(market_config, market_id, user_id, args, pg_config)
 
 if __name__ == '__main__': # pragma: no cover
     main(sys.argv[1:])
index 10d1ad8936b3f671ca2ca7e9a1623a856dce29db..e16641c476865bc3977ceaa8f30ccce5296925ab 100644 (file)
--- a/market.py
+++ b/market.py
@@ -25,11 +25,10 @@ class Market:
         self.balances = BalanceStore(self)
         self.processor = Processor(self)
 
-        for key in ["user_id", "market_id", "report_path", "pg_config"]:
+        for key in ["user_id", "market_id", "pg_config"]:
             setattr(self, key, kwargs.get(key, None))
 
-        self.report.log_market(self.args, self.user_id, self.market_id,
-                self.report_path, self.debug)
+        self.report.log_market(self.args, self.user_id, self.market_id)
 
     @classmethod
     def from_config(cls, config, args, **kwargs):
@@ -42,14 +41,14 @@ class Market:
     def store_report(self):
         self.report.merge(Portfolio.report)
         date = datetime.now()
-        if self.report_path is not None:
+        if self.args.report_path is not None:
             self.store_file_report(date)
-        if self.pg_config is not None:
+        if self.pg_config is not None and self.args.report_db:
             self.store_database_report(date)
 
     def store_file_report(self, date):
         try:
-            report_file = "{}/{}_{}".format(self.report_path, date.isoformat(), self.user_id)
+            report_file = "{}/{}_{}".format(self.args.report_path, date.isoformat(), self.user_id)
             with open(report_file + ".json", "w") as f:
                 f.write(self.report.to_json())
             with open(report_file + ".log", "w") as f:
index 1bc76ec67f0b9289c791cab50ece5625d277e90c..2451c805cc14e46a8d1b7f40226ca6ee931bcb56 100644 (file)
@@ -5,3 +5,4 @@ requests_mock==1.4.0
 psycopg2==2.7.4
 retry==0.9.2
 cachetools==2.0.1
+configargparse==0.12.0
index 3f3718f4f8c0bff4e253a0a090e5820bd8065726..67e8a8fad7f9ce3698095914351eb4602fe7564d 100644 (file)
--- a/store.py
+++ b/store.py
@@ -222,15 +222,13 @@ class ReportStore:
             "action": action,
             })
 
-    def log_market(self, args, user_id, market_id, report_path, debug):
+    def log_market(self, args, user_id, market_id):
         self.add_log({
             "type": "market",
             "commit": "$Format:%H$",
             "args": vars(args),
             "user_id": user_id,
             "market_id": market_id,
-            "report_path": report_path,
-            "debug": debug,
             })
 
 class BalanceStore:
diff --git a/test.py b/test.py
index bf679bfc8c2879f507bdfee64b57c09705619845..854e27b1089bf6f985778242a550eca9340a0c25 100644 (file)
--- a/test.py
+++ b/test.py
@@ -23,8 +23,9 @@ for test_type in limits:
 class WebMockTestCase(unittest.TestCase):
     import time
 
-    def market_args(self, debug=False, quiet=False):
-        return type('Args', (object,), { "debug": debug, "quiet": quiet })()
+    def market_args(self, debug=False, quiet=False, report_path=None, **kwargs):
+        return main.configargparse.Namespace(report_path=report_path,
+                debug=debug, quiet=quiet, **kwargs)
 
     def setUp(self):
         super().setUp()
@@ -1632,7 +1633,8 @@ class MarketTest(WebMockTestCase):
 
     def test_store_file_report(self):
         file_open = mock.mock_open()
-        m = market.Market(self.ccxt, self.market_args(), report_path="present", user_id=1)
+        m = market.Market(self.ccxt,
+                self.market_args(report_path="present"), user_id=1)
         with self.subTest(file="present"),\
                 mock.patch("market.open", file_open),\
                 mock.patch.object(m, "report") as report,\
@@ -1649,7 +1651,7 @@ class MarketTest(WebMockTestCase):
             file_open().write.assert_any_call("Foo\nBar")
             m.report.to_json.assert_called_once_with()
 
-        m = market.Market(self.ccxt, self.market_args(), report_path="error", user_id=1)
+        m = market.Market(self.ccxt, self.market_args(report_path="error"), user_id=1)
         with self.subTest(file="error"),\
                 mock.patch("market.open") as file_open,\
                 mock.patch.object(m, "report") as report,\
@@ -1697,7 +1699,7 @@ class MarketTest(WebMockTestCase):
             self.assertEqual(stdout_mock.getvalue(), "impossible to store report to database: Exception; Bouh\n")
 
     def test_store_report(self):
-        m = market.Market(self.ccxt, self.market_args(), user_id=1)
+        m = market.Market(self.ccxt, self.market_args(report_db=False), user_id=1)
         with self.subTest(file=None, pg_config=None),\
                 mock.patch.object(m, "report") as report,\
                 mock.patch.object(m, "store_database_report") as db_report,\
@@ -1709,7 +1711,7 @@ class MarketTest(WebMockTestCase):
             db_report.assert_not_called()
 
         report.reset_mock()
-        m = market.Market(self.ccxt, self.market_args(), report_path="present", user_id=1)
+        m = market.Market(self.ccxt, self.market_args(report_db=False, report_path="present"), user_id=1)
         with self.subTest(file="present", pg_config=None),\
                 mock.patch.object(m, "report") as report,\
                 mock.patch.object(m, "store_file_report") as file_report,\
@@ -1725,7 +1727,23 @@ class MarketTest(WebMockTestCase):
             db_report.assert_not_called()
 
         report.reset_mock()
-        m = market.Market(self.ccxt, self.market_args(), pg_config="present", user_id=1)
+        m = market.Market(self.ccxt, self.market_args(report_db=True, report_path="present"), user_id=1)
+        with self.subTest(file="present", pg_config=None, report_db=True),\
+                mock.patch.object(m, "report") as report,\
+                mock.patch.object(m, "store_file_report") as file_report,\
+                mock.patch.object(m, "store_database_report") as db_report,\
+                mock.patch.object(market, "datetime") as time_mock:
+
+            time_mock.now.return_value = datetime.datetime(2018, 2, 25)
+
+            m.store_report()
+
+            report.merge.assert_called_with(store.Portfolio.report)
+            file_report.assert_called_once_with(datetime.datetime(2018, 2, 25))
+            db_report.assert_not_called()
+
+        report.reset_mock()
+        m = market.Market(self.ccxt, self.market_args(report_db=True), pg_config="present", user_id=1)
         with self.subTest(file=None, pg_config="present"),\
                 mock.patch.object(m, "report") as report,\
                 mock.patch.object(m, "store_file_report") as file_report,\
@@ -1741,8 +1759,8 @@ class MarketTest(WebMockTestCase):
             db_report.assert_called_once_with(datetime.datetime(2018, 2, 25))
 
         report.reset_mock()
-        m = market.Market(self.ccxt, self.market_args(),
-                pg_config="pg_config", report_path="present", user_id=1)
+        m = market.Market(self.ccxt, self.market_args(report_db=True, report_path="present"),
+                pg_config="pg_config", user_id=1)
         with self.subTest(file="present", pg_config="present"),\
                 mock.patch.object(m, "report") as report,\
                 mock.patch.object(m, "store_file_report") as file_report,\
@@ -4383,20 +4401,14 @@ class ReportStoreTest(WebMockTestCase):
     @mock.patch.object(market.ReportStore, "add_log")
     def test_log_market(self, add_log):
         report_store = market.ReportStore(self.m)
-        class Args:
-            def __init__(self):
-                self.debug = True
-                self.quiet = False
 
-        report_store.log_market(Args(), 4, 1, "report", True)
+        report_store.log_market(self.market_args(debug=True, quiet=False), 4, 1)
         add_log.assert_called_once_with({
             "type": "market",
             "commit": "$Format:%H$",
-            "args": { "debug": True, "quiet": False },
+            "args": { "report_path": None, "debug": True, "quiet": False },
             "user_id": 4,
             "market_id": 1,
-            "report_path": "report",
-            "debug": True
             })
 
     @mock.patch.object(market.ReportStore, "print_log")
@@ -4603,16 +4615,16 @@ class MainTest(WebMockTestCase):
             args_mock.after = "after"
             self.assertEqual("", stdout_mock.getvalue())
 
-            main.process("config", 3, 1, args_mock, "report_path", "pg_config")
+            main.process("config", 3, 1, args_mock, "pg_config")
 
             market_mock.from_config.assert_has_calls([
-                mock.call("config", args_mock, pg_config="pg_config", market_id=3, user_id=1, report_path="report_path"),
+                mock.call("config", args_mock, pg_config="pg_config", market_id=3, user_id=1),
                 mock.call().process("action", before="before", after="after"),
                 ])
 
             with self.subTest(exception=True):
                 market_mock.from_config.side_effect = Exception("boo")
-                main.process(3, "config", 1, "report_path", args_mock, "pg_config")
+                main.process(3, "config", 1, args_mock, "pg_config")
                 self.assertEqual("Exception: boo\n", stdout_mock.getvalue())
 
     def test_main(self):
@@ -4624,24 +4636,23 @@ class MainTest(WebMockTestCase):
 
                 args_mock = mock.Mock()
                 args_mock.parallel = False
-                args_mock.config = "config"
                 args_mock.user = "user"
                 parse_args.return_value = args_mock
 
-                parse_config.return_value = ["pg_config", "report_path"]
+                parse_config.return_value = "pg_config"
 
                 fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]]
 
                 main.main(["Foo", "Bar"])
 
                 parse_args.assert_called_with(["Foo", "Bar"])
-                parse_config.assert_called_with("config")
+                parse_config.assert_called_with(args_mock)
                 fetch_markets.assert_called_with("pg_config", "user")
 
                 self.assertEqual(2, process.call_count)
                 process.assert_has_calls([
-                    mock.call("config1", 3, 1, args_mock, "report_path", "pg_config"),
-                    mock.call("config2", 1, 2, args_mock, "report_path", "pg_config"),
+                    mock.call("config1", 3, 1, args_mock, "pg_config"),
+                    mock.call("config2", 1, 2, args_mock, "pg_config"),
                     ])
         with self.subTest(parallel=True):
             with mock.patch("main.parse_args") as parse_args,\
@@ -4652,79 +4663,66 @@ class MainTest(WebMockTestCase):
 
                 args_mock = mock.Mock()
                 args_mock.parallel = True
-                args_mock.config = "config"
                 args_mock.user = "user"
                 parse_args.return_value = args_mock
 
-                parse_config.return_value = ["pg_config", "report_path"]
+                parse_config.return_value = "pg_config"
 
                 fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]]
 
                 main.main(["Foo", "Bar"])
 
                 parse_args.assert_called_with(["Foo", "Bar"])
-                parse_config.assert_called_with("config")
+                parse_config.assert_called_with(args_mock)
                 fetch_markets.assert_called_with("pg_config", "user")
 
                 start.assert_called_once_with()
                 self.assertEqual(2, process.call_count)
                 process.assert_has_calls([
                     mock.call.__bool__(),
-                    mock.call("config1", 3, 1, args_mock, "report_path", "pg_config"),
+                    mock.call("config1", 3, 1, args_mock, "pg_config"),
                     mock.call.__bool__(),
-                    mock.call("config2", 1, 2, args_mock, "report_path", "pg_config"),
+                    mock.call("config2", 1, 2, args_mock, "pg_config"),
                     ])
 
     @mock.patch.object(main.sys, "exit")
-    @mock.patch("main.configparser")
     @mock.patch("main.os")
-    def test_parse_config(self, os, configparser, exit):
-        with self.subTest(pg_config=True, report_path=None):
-            config_mock = mock.MagicMock()
-            configparser.ConfigParser.return_value = config_mock
-            def config(element):
-                return element == "postgresql"
-
-            config_mock.__contains__.side_effect = config
-            config_mock.__getitem__.return_value = "pg_config"
-
-            result = main.parse_config("configfile")
-
-            config_mock.read.assert_called_with("configfile")
-
-            self.assertEqual(["pg_config", None], result)
-
-        with self.subTest(pg_config=True, report_path="present"):
-            config_mock = mock.MagicMock()
-            configparser.ConfigParser.return_value = config_mock
+    def test_parse_config(self, os, exit):
+        with self.subTest(report_path=None):
+            args = main.configargparse.Namespace(**{
+                "db_host": "host",
+                "db_port": "port",
+                "db_user": "user",
+                "db_password": "password",
+                "db_database": "database",
+                "report_path": None,
+                })
 
-            config_mock.__contains__.return_value = True
-            config_mock.__getitem__.side_effect = [
-                    {"report_path": "report_path"},
-                    {"report_path": "report_path"},
-                    "pg_config",
-                    ]
+            result = main.parse_config(args)
+            self.assertEqual({ "host": "host", "port": "port", "user":
+                "user", "password": "password", "database": "database"
+                }, result)
+            with self.assertRaises(AttributeError):
+                args.db_password
+
+        with self.subTest(report_path="present"):
+            args = main.configargparse.Namespace(**{
+                "db_host": "host",
+                "db_port": "port",
+                "db_user": "user",
+                "db_password": "password",
+                "db_database": "database",
+                "report_path": "report_path",
+                })
 
             os.path.exists.return_value = False
-            result = main.parse_config("configfile")
 
-            config_mock.read.assert_called_with("configfile")
-            self.assertEqual(["pg_config", "report_path"], result)
+            result = main.parse_config(args)
+
             os.path.exists.assert_called_once_with("report_path")
             os.makedirs.assert_called_once_with("report_path")
 
-        with self.subTest(pg_config=False),\
-                mock.patch('sys.stdout', new_callable=StringIO) as stdout_mock:
-            config_mock = mock.MagicMock()
-            configparser.ConfigParser.return_value = config_mock
-            result = main.parse_config("configfile")
-
-            config_mock.read.assert_called_with("configfile")
-            exit.assert_called_once_with(1)
-            self.assertEqual("no configuration for postgresql in config file\n", stdout_mock.getvalue())
-
-    @mock.patch.object(main.sys, "exit")
-    def test_parse_args(self, exit):
+    def test_parse_args(self):
         with self.subTest(config="config.ini"):
             args = main.parse_args([])
             self.assertEqual("config.ini", args.config)
@@ -4737,13 +4735,10 @@ class MainTest(WebMockTestCase):
             self.assertTrue(args.after)
             self.assertTrue(args.debug)
 
-            exit.assert_not_called()
-
-        with self.subTest(config="inexistant"),\
-                mock.patch('sys.stdout', new_callable=StringIO) as stdout_mock:
+        with self.subTest(config="inexistant"), \
+                self.assertRaises(SystemExit), \
+                mock.patch('sys.stderr', new_callable=StringIO) as stdout_mock:
             args = main.parse_args(["--config", "foo.bar"])
-            exit.assert_called_once_with(1)
-            self.assertEqual("no config file found, exiting\n", stdout_mock.getvalue())
 
     @mock.patch.object(main, "psycopg2")
     def test_fetch_markets(self, psycopg2):