diff options
author | Ismaël Bouya <ismael.bouya@normalesup.org> | 2018-04-05 09:56:51 +0200 |
---|---|---|
committer | Ismaël Bouya <ismael.bouya@normalesup.org> | 2018-04-05 09:56:51 +0200 |
commit | a42d6cc8a49e82d851cde587fbc938b3b6364f63 (patch) | |
tree | e01688e2a4d72ffa0aaf6f278906cac9d716d0d1 | |
parent | 3b60291066e5442ce2980a6c40ea10542f24a910 (diff) | |
download | Trader-a42d6cc8a49e82d851cde587fbc938b3b6364f63.tar.gz Trader-a42d6cc8a49e82d851cde587fbc938b3b6364f63.tar.zst Trader-a42d6cc8a49e82d851cde587fbc938b3b6364f63.zip |
Refactor config parsing
-rw-r--r-- | config.ini | 12 | ||||
-rw-r--r-- | main.py | 82 | ||||
-rw-r--r-- | market.py | 11 | ||||
-rw-r--r-- | requirements.txt | 1 | ||||
-rw-r--r-- | store.py | 4 | ||||
-rw-r--r-- | test.py | 149 |
6 files changed, 133 insertions, 126 deletions
@@ -1,9 +1,9 @@ | |||
1 | [postgresql] | 1 | [postgresql] |
2 | host = localhost | 2 | db-host = localhost |
3 | port = 5432 | 3 | db-port = 5432 |
4 | user = cryptoportfolio | 4 | db-user = cryptoportfolio |
5 | password = cryptoportfolio | 5 | db-password = cryptoportfolio |
6 | database = cryptoportfolio | 6 | db-database = cryptoportfolio |
7 | 7 | ||
8 | [app] | 8 | [app] |
9 | report_path = reports | 9 | report-path = reports |
@@ -1,6 +1,5 @@ | |||
1 | from datetime import datetime | 1 | from datetime import datetime |
2 | import argparse | 2 | import configargparse |
3 | import configparser | ||
4 | import psycopg2 | 3 | import psycopg2 |
5 | import os | 4 | import os |
6 | import sys | 5 | import sys |
@@ -80,31 +79,35 @@ def fetch_markets(pg_config, user): | |||
80 | for row in cursor: | 79 | for row in cursor: |
81 | yield row | 80 | yield row |
82 | 81 | ||
83 | def parse_config(config_file): | 82 | def parse_config(args): |
84 | config = configparser.ConfigParser() | 83 | pg_config = { |
85 | config.read(config_file) | 84 | "host": args.db_host, |
85 | "port": args.db_port, | ||
86 | "user": args.db_user, | ||
87 | "password": args.db_password, | ||
88 | "database": args.db_database, | ||
89 | } | ||
90 | del(args.db_host) | ||
91 | del(args.db_port) | ||
92 | del(args.db_user) | ||
93 | del(args.db_password) | ||
94 | del(args.db_database) | ||
86 | 95 | ||
87 | if "postgresql" not in config: | 96 | report_path = args.report_path |
88 | print("no configuration for postgresql in config file") | ||
89 | sys.exit(1) | ||
90 | 97 | ||
91 | if "app" in config and "report_path" in config["app"]: | 98 | if report_path is not None and not \ |
92 | report_path = config["app"]["report_path"] | 99 | os.path.exists(report_path): |
100 | os.makedirs(report_path) | ||
93 | 101 | ||
94 | if not os.path.exists(report_path): | 102 | return pg_config |
95 | os.makedirs(report_path) | ||
96 | else: | ||
97 | report_path = None | ||
98 | |||
99 | return [config["postgresql"], report_path] | ||
100 | 103 | ||
101 | def parse_args(argv): | 104 | def parse_args(argv): |
102 | parser = argparse.ArgumentParser( | 105 | parser = configargparse.ArgumentParser( |
103 | description="Run the trade bot") | 106 | description="Run the trade bot.") |
104 | 107 | ||
105 | parser.add_argument("-c", "--config", | 108 | parser.add_argument("-c", "--config", |
106 | default="config.ini", | 109 | default="config.ini", |
107 | required=False, | 110 | required=False, is_config_file=True, |
108 | help="Config file to load (default: config.ini)") | 111 | help="Config file to load (default: config.ini)") |
109 | parser.add_argument("--before", | 112 | parser.add_argument("--before", |
110 | default=False, action='store_const', const=True, | 113 | default=False, action='store_const', const=True, |
@@ -125,21 +128,32 @@ def parse_args(argv): | |||
125 | help="Do a different action than trading (add several times to chain)") | 128 | help="Do a different action than trading (add several times to chain)") |
126 | parser.add_argument("--parallel", action='store_true', default=True, dest="parallel") | 129 | parser.add_argument("--parallel", action='store_true', default=True, dest="parallel") |
127 | parser.add_argument("--no-parallel", action='store_false', dest="parallel") | 130 | parser.add_argument("--no-parallel", action='store_false', dest="parallel") |
128 | 131 | parser.add_argument("--report-db", action='store_true', default=True, dest="report_db", | |
129 | args = parser.parse_args(argv) | 132 | help="Store report to database (default)") |
130 | 133 | parser.add_argument("--no-report-db", action='store_false', dest="report_db", | |
131 | if not os.path.exists(args.config): | 134 | help="Don't store report to database") |
132 | print("no config file found, exiting") | 135 | parser.add_argument("--report-path", required=False, |
133 | sys.exit(1) | 136 | help="Where to store the reports (default: absent, don't store)") |
134 | 137 | parser.add_argument("--no-report-path", action='store_const', dest='report_path', const=None, | |
135 | return args | 138 | help="Don't store the report to file (default)") |
136 | 139 | parser.add_argument("--db-host", default="localhost", | |
137 | def process(market_config, market_id, user_id, args, report_path, pg_config): | 140 | help="Host access to database (default: localhost)") |
141 | parser.add_argument("--db-port", default=5432, | ||
142 | help="Port access to database (default: 5432)") | ||
143 | parser.add_argument("--db-user", default="cryptoportfolio", | ||
144 | help="User access to database (default: cryptoportfolio)") | ||
145 | parser.add_argument("--db-password", default="cryptoportfolio", | ||
146 | help="Password access to database (default: cryptoportfolio)") | ||
147 | parser.add_argument("--db-database", default="cryptoportfolio", | ||
148 | help="Database access to database (default: cryptoportfolio)") | ||
149 | |||
150 | return parser.parse_args(argv) | ||
151 | |||
152 | def process(market_config, market_id, user_id, args, pg_config): | ||
138 | try: | 153 | try: |
139 | market.Market\ | 154 | market.Market\ |
140 | .from_config(market_config, args, | 155 | .from_config(market_config, args, market_id=market_id, |
141 | pg_config=pg_config, market_id=market_id, | 156 | pg_config=pg_config, user_id=user_id)\ |
142 | user_id=user_id, report_path=report_path)\ | ||
143 | .process(args.action, before=args.before, after=args.after) | 157 | .process(args.action, before=args.before, after=args.after) |
144 | except Exception as e: | 158 | except Exception as e: |
145 | print("{}: {}".format(e.__class__.__name__, e)) | 159 | print("{}: {}".format(e.__class__.__name__, e)) |
@@ -147,7 +161,7 @@ def process(market_config, market_id, user_id, args, report_path, pg_config): | |||
147 | def main(argv): | 161 | def main(argv): |
148 | args = parse_args(argv) | 162 | args = parse_args(argv) |
149 | 163 | ||
150 | pg_config, report_path = parse_config(args.config) | 164 | pg_config = parse_config(args) |
151 | 165 | ||
152 | if args.parallel: | 166 | if args.parallel: |
153 | import threading | 167 | import threading |
@@ -159,7 +173,7 @@ def main(argv): | |||
159 | process_ = process | 173 | process_ = process |
160 | 174 | ||
161 | for market_id, market_config, user_id in fetch_markets(pg_config, args.user): | 175 | for market_id, market_config, user_id in fetch_markets(pg_config, args.user): |
162 | process_(market_config, market_id, user_id, args, report_path, pg_config) | 176 | process_(market_config, market_id, user_id, args, pg_config) |
163 | 177 | ||
164 | if __name__ == '__main__': # pragma: no cover | 178 | if __name__ == '__main__': # pragma: no cover |
165 | main(sys.argv[1:]) | 179 | main(sys.argv[1:]) |
@@ -25,11 +25,10 @@ class Market: | |||
25 | self.balances = BalanceStore(self) | 25 | self.balances = BalanceStore(self) |
26 | self.processor = Processor(self) | 26 | self.processor = Processor(self) |
27 | 27 | ||
28 | for key in ["user_id", "market_id", "report_path", "pg_config"]: | 28 | for key in ["user_id", "market_id", "pg_config"]: |
29 | setattr(self, key, kwargs.get(key, None)) | 29 | setattr(self, key, kwargs.get(key, None)) |
30 | 30 | ||
31 | self.report.log_market(self.args, self.user_id, self.market_id, | 31 | self.report.log_market(self.args, self.user_id, self.market_id) |
32 | self.report_path, self.debug) | ||
33 | 32 | ||
34 | @classmethod | 33 | @classmethod |
35 | def from_config(cls, config, args, **kwargs): | 34 | def from_config(cls, config, args, **kwargs): |
@@ -42,14 +41,14 @@ class Market: | |||
42 | def store_report(self): | 41 | def store_report(self): |
43 | self.report.merge(Portfolio.report) | 42 | self.report.merge(Portfolio.report) |
44 | date = datetime.now() | 43 | date = datetime.now() |
45 | if self.report_path is not None: | 44 | if self.args.report_path is not None: |
46 | self.store_file_report(date) | 45 | self.store_file_report(date) |
47 | if self.pg_config is not None: | 46 | if self.pg_config is not None and self.args.report_db: |
48 | self.store_database_report(date) | 47 | self.store_database_report(date) |
49 | 48 | ||
50 | def store_file_report(self, date): | 49 | def store_file_report(self, date): |
51 | try: | 50 | try: |
52 | report_file = "{}/{}_{}".format(self.report_path, date.isoformat(), self.user_id) | 51 | report_file = "{}/{}_{}".format(self.args.report_path, date.isoformat(), self.user_id) |
53 | with open(report_file + ".json", "w") as f: | 52 | with open(report_file + ".json", "w") as f: |
54 | f.write(self.report.to_json()) | 53 | f.write(self.report.to_json()) |
55 | with open(report_file + ".log", "w") as f: | 54 | with open(report_file + ".log", "w") as f: |
diff --git a/requirements.txt b/requirements.txt index 1bc76ec..2451c80 100644 --- a/requirements.txt +++ b/requirements.txt | |||
@@ -5,3 +5,4 @@ requests_mock==1.4.0 | |||
5 | psycopg2==2.7.4 | 5 | psycopg2==2.7.4 |
6 | retry==0.9.2 | 6 | retry==0.9.2 |
7 | cachetools==2.0.1 | 7 | cachetools==2.0.1 |
8 | configargparse==0.12.0 | ||
@@ -222,15 +222,13 @@ class ReportStore: | |||
222 | "action": action, | 222 | "action": action, |
223 | }) | 223 | }) |
224 | 224 | ||
225 | def log_market(self, args, user_id, market_id, report_path, debug): | 225 | def log_market(self, args, user_id, market_id): |
226 | self.add_log({ | 226 | self.add_log({ |
227 | "type": "market", | 227 | "type": "market", |
228 | "commit": "$Format:%H$", | 228 | "commit": "$Format:%H$", |
229 | "args": vars(args), | 229 | "args": vars(args), |
230 | "user_id": user_id, | 230 | "user_id": user_id, |
231 | "market_id": market_id, | 231 | "market_id": market_id, |
232 | "report_path": report_path, | ||
233 | "debug": debug, | ||
234 | }) | 232 | }) |
235 | 233 | ||
236 | class BalanceStore: | 234 | class BalanceStore: |
@@ -23,8 +23,9 @@ for test_type in limits: | |||
23 | class WebMockTestCase(unittest.TestCase): | 23 | class WebMockTestCase(unittest.TestCase): |
24 | import time | 24 | import time |
25 | 25 | ||
26 | def market_args(self, debug=False, quiet=False): | 26 | def market_args(self, debug=False, quiet=False, report_path=None, **kwargs): |
27 | return type('Args', (object,), { "debug": debug, "quiet": quiet })() | 27 | return main.configargparse.Namespace(report_path=report_path, |
28 | debug=debug, quiet=quiet, **kwargs) | ||
28 | 29 | ||
29 | def setUp(self): | 30 | def setUp(self): |
30 | super().setUp() | 31 | super().setUp() |
@@ -1632,7 +1633,8 @@ class MarketTest(WebMockTestCase): | |||
1632 | 1633 | ||
1633 | def test_store_file_report(self): | 1634 | def test_store_file_report(self): |
1634 | file_open = mock.mock_open() | 1635 | file_open = mock.mock_open() |
1635 | m = market.Market(self.ccxt, self.market_args(), report_path="present", user_id=1) | 1636 | m = market.Market(self.ccxt, |
1637 | self.market_args(report_path="present"), user_id=1) | ||
1636 | with self.subTest(file="present"),\ | 1638 | with self.subTest(file="present"),\ |
1637 | mock.patch("market.open", file_open),\ | 1639 | mock.patch("market.open", file_open),\ |
1638 | mock.patch.object(m, "report") as report,\ | 1640 | mock.patch.object(m, "report") as report,\ |
@@ -1649,7 +1651,7 @@ class MarketTest(WebMockTestCase): | |||
1649 | file_open().write.assert_any_call("Foo\nBar") | 1651 | file_open().write.assert_any_call("Foo\nBar") |
1650 | m.report.to_json.assert_called_once_with() | 1652 | m.report.to_json.assert_called_once_with() |
1651 | 1653 | ||
1652 | m = market.Market(self.ccxt, self.market_args(), report_path="error", user_id=1) | 1654 | m = market.Market(self.ccxt, self.market_args(report_path="error"), user_id=1) |
1653 | with self.subTest(file="error"),\ | 1655 | with self.subTest(file="error"),\ |
1654 | mock.patch("market.open") as file_open,\ | 1656 | mock.patch("market.open") as file_open,\ |
1655 | mock.patch.object(m, "report") as report,\ | 1657 | mock.patch.object(m, "report") as report,\ |
@@ -1697,7 +1699,7 @@ class MarketTest(WebMockTestCase): | |||
1697 | self.assertEqual(stdout_mock.getvalue(), "impossible to store report to database: Exception; Bouh\n") | 1699 | self.assertEqual(stdout_mock.getvalue(), "impossible to store report to database: Exception; Bouh\n") |
1698 | 1700 | ||
1699 | def test_store_report(self): | 1701 | def test_store_report(self): |
1700 | m = market.Market(self.ccxt, self.market_args(), user_id=1) | 1702 | m = market.Market(self.ccxt, self.market_args(report_db=False), user_id=1) |
1701 | with self.subTest(file=None, pg_config=None),\ | 1703 | with self.subTest(file=None, pg_config=None),\ |
1702 | mock.patch.object(m, "report") as report,\ | 1704 | mock.patch.object(m, "report") as report,\ |
1703 | mock.patch.object(m, "store_database_report") as db_report,\ | 1705 | mock.patch.object(m, "store_database_report") as db_report,\ |
@@ -1709,7 +1711,7 @@ class MarketTest(WebMockTestCase): | |||
1709 | db_report.assert_not_called() | 1711 | db_report.assert_not_called() |
1710 | 1712 | ||
1711 | report.reset_mock() | 1713 | report.reset_mock() |
1712 | m = market.Market(self.ccxt, self.market_args(), report_path="present", user_id=1) | 1714 | m = market.Market(self.ccxt, self.market_args(report_db=False, report_path="present"), user_id=1) |
1713 | with self.subTest(file="present", pg_config=None),\ | 1715 | with self.subTest(file="present", pg_config=None),\ |
1714 | mock.patch.object(m, "report") as report,\ | 1716 | mock.patch.object(m, "report") as report,\ |
1715 | mock.patch.object(m, "store_file_report") as file_report,\ | 1717 | mock.patch.object(m, "store_file_report") as file_report,\ |
@@ -1725,7 +1727,23 @@ class MarketTest(WebMockTestCase): | |||
1725 | db_report.assert_not_called() | 1727 | db_report.assert_not_called() |
1726 | 1728 | ||
1727 | report.reset_mock() | 1729 | report.reset_mock() |
1728 | m = market.Market(self.ccxt, self.market_args(), pg_config="present", user_id=1) | 1730 | m = market.Market(self.ccxt, self.market_args(report_db=True, report_path="present"), user_id=1) |
1731 | with self.subTest(file="present", pg_config=None, report_db=True),\ | ||
1732 | mock.patch.object(m, "report") as report,\ | ||
1733 | mock.patch.object(m, "store_file_report") as file_report,\ | ||
1734 | mock.patch.object(m, "store_database_report") as db_report,\ | ||
1735 | mock.patch.object(market, "datetime") as time_mock: | ||
1736 | |||
1737 | time_mock.now.return_value = datetime.datetime(2018, 2, 25) | ||
1738 | |||
1739 | m.store_report() | ||
1740 | |||
1741 | report.merge.assert_called_with(store.Portfolio.report) | ||
1742 | file_report.assert_called_once_with(datetime.datetime(2018, 2, 25)) | ||
1743 | db_report.assert_not_called() | ||
1744 | |||
1745 | report.reset_mock() | ||
1746 | m = market.Market(self.ccxt, self.market_args(report_db=True), pg_config="present", user_id=1) | ||
1729 | with self.subTest(file=None, pg_config="present"),\ | 1747 | with self.subTest(file=None, pg_config="present"),\ |
1730 | mock.patch.object(m, "report") as report,\ | 1748 | mock.patch.object(m, "report") as report,\ |
1731 | mock.patch.object(m, "store_file_report") as file_report,\ | 1749 | mock.patch.object(m, "store_file_report") as file_report,\ |
@@ -1741,8 +1759,8 @@ class MarketTest(WebMockTestCase): | |||
1741 | db_report.assert_called_once_with(datetime.datetime(2018, 2, 25)) | 1759 | db_report.assert_called_once_with(datetime.datetime(2018, 2, 25)) |
1742 | 1760 | ||
1743 | report.reset_mock() | 1761 | report.reset_mock() |
1744 | m = market.Market(self.ccxt, self.market_args(), | 1762 | m = market.Market(self.ccxt, self.market_args(report_db=True, report_path="present"), |
1745 | pg_config="pg_config", report_path="present", user_id=1) | 1763 | pg_config="pg_config", user_id=1) |
1746 | with self.subTest(file="present", pg_config="present"),\ | 1764 | with self.subTest(file="present", pg_config="present"),\ |
1747 | mock.patch.object(m, "report") as report,\ | 1765 | mock.patch.object(m, "report") as report,\ |
1748 | mock.patch.object(m, "store_file_report") as file_report,\ | 1766 | mock.patch.object(m, "store_file_report") as file_report,\ |
@@ -4383,20 +4401,14 @@ class ReportStoreTest(WebMockTestCase): | |||
4383 | @mock.patch.object(market.ReportStore, "add_log") | 4401 | @mock.patch.object(market.ReportStore, "add_log") |
4384 | def test_log_market(self, add_log): | 4402 | def test_log_market(self, add_log): |
4385 | report_store = market.ReportStore(self.m) | 4403 | report_store = market.ReportStore(self.m) |
4386 | class Args: | ||
4387 | def __init__(self): | ||
4388 | self.debug = True | ||
4389 | self.quiet = False | ||
4390 | 4404 | ||
4391 | report_store.log_market(Args(), 4, 1, "report", True) | 4405 | report_store.log_market(self.market_args(debug=True, quiet=False), 4, 1) |
4392 | add_log.assert_called_once_with({ | 4406 | add_log.assert_called_once_with({ |
4393 | "type": "market", | 4407 | "type": "market", |
4394 | "commit": "$Format:%H$", | 4408 | "commit": "$Format:%H$", |
4395 | "args": { "debug": True, "quiet": False }, | 4409 | "args": { "report_path": None, "debug": True, "quiet": False }, |
4396 | "user_id": 4, | 4410 | "user_id": 4, |
4397 | "market_id": 1, | 4411 | "market_id": 1, |
4398 | "report_path": "report", | ||
4399 | "debug": True | ||
4400 | }) | 4412 | }) |
4401 | 4413 | ||
4402 | @mock.patch.object(market.ReportStore, "print_log") | 4414 | @mock.patch.object(market.ReportStore, "print_log") |
@@ -4603,16 +4615,16 @@ class MainTest(WebMockTestCase): | |||
4603 | args_mock.after = "after" | 4615 | args_mock.after = "after" |
4604 | self.assertEqual("", stdout_mock.getvalue()) | 4616 | self.assertEqual("", stdout_mock.getvalue()) |
4605 | 4617 | ||
4606 | main.process("config", 3, 1, args_mock, "report_path", "pg_config") | 4618 | main.process("config", 3, 1, args_mock, "pg_config") |
4607 | 4619 | ||
4608 | market_mock.from_config.assert_has_calls([ | 4620 | market_mock.from_config.assert_has_calls([ |
4609 | mock.call("config", args_mock, pg_config="pg_config", market_id=3, user_id=1, report_path="report_path"), | 4621 | mock.call("config", args_mock, pg_config="pg_config", market_id=3, user_id=1), |
4610 | mock.call().process("action", before="before", after="after"), | 4622 | mock.call().process("action", before="before", after="after"), |
4611 | ]) | 4623 | ]) |
4612 | 4624 | ||
4613 | with self.subTest(exception=True): | 4625 | with self.subTest(exception=True): |
4614 | market_mock.from_config.side_effect = Exception("boo") | 4626 | market_mock.from_config.side_effect = Exception("boo") |
4615 | main.process(3, "config", 1, "report_path", args_mock, "pg_config") | 4627 | main.process(3, "config", 1, args_mock, "pg_config") |
4616 | self.assertEqual("Exception: boo\n", stdout_mock.getvalue()) | 4628 | self.assertEqual("Exception: boo\n", stdout_mock.getvalue()) |
4617 | 4629 | ||
4618 | def test_main(self): | 4630 | def test_main(self): |
@@ -4624,24 +4636,23 @@ class MainTest(WebMockTestCase): | |||
4624 | 4636 | ||
4625 | args_mock = mock.Mock() | 4637 | args_mock = mock.Mock() |
4626 | args_mock.parallel = False | 4638 | args_mock.parallel = False |
4627 | args_mock.config = "config" | ||
4628 | args_mock.user = "user" | 4639 | args_mock.user = "user" |
4629 | parse_args.return_value = args_mock | 4640 | parse_args.return_value = args_mock |
4630 | 4641 | ||
4631 | parse_config.return_value = ["pg_config", "report_path"] | 4642 | parse_config.return_value = "pg_config" |
4632 | 4643 | ||
4633 | fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]] | 4644 | fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]] |
4634 | 4645 | ||
4635 | main.main(["Foo", "Bar"]) | 4646 | main.main(["Foo", "Bar"]) |
4636 | 4647 | ||
4637 | parse_args.assert_called_with(["Foo", "Bar"]) | 4648 | parse_args.assert_called_with(["Foo", "Bar"]) |
4638 | parse_config.assert_called_with("config") | 4649 | parse_config.assert_called_with(args_mock) |
4639 | fetch_markets.assert_called_with("pg_config", "user") | 4650 | fetch_markets.assert_called_with("pg_config", "user") |
4640 | 4651 | ||
4641 | self.assertEqual(2, process.call_count) | 4652 | self.assertEqual(2, process.call_count) |
4642 | process.assert_has_calls([ | 4653 | process.assert_has_calls([ |
4643 | mock.call("config1", 3, 1, args_mock, "report_path", "pg_config"), | 4654 | mock.call("config1", 3, 1, args_mock, "pg_config"), |
4644 | mock.call("config2", 1, 2, args_mock, "report_path", "pg_config"), | 4655 | mock.call("config2", 1, 2, args_mock, "pg_config"), |
4645 | ]) | 4656 | ]) |
4646 | with self.subTest(parallel=True): | 4657 | with self.subTest(parallel=True): |
4647 | with mock.patch("main.parse_args") as parse_args,\ | 4658 | with mock.patch("main.parse_args") as parse_args,\ |
@@ -4652,79 +4663,66 @@ class MainTest(WebMockTestCase): | |||
4652 | 4663 | ||
4653 | args_mock = mock.Mock() | 4664 | args_mock = mock.Mock() |
4654 | args_mock.parallel = True | 4665 | args_mock.parallel = True |
4655 | args_mock.config = "config" | ||
4656 | args_mock.user = "user" | 4666 | args_mock.user = "user" |
4657 | parse_args.return_value = args_mock | 4667 | parse_args.return_value = args_mock |
4658 | 4668 | ||
4659 | parse_config.return_value = ["pg_config", "report_path"] | 4669 | parse_config.return_value = "pg_config" |
4660 | 4670 | ||
4661 | fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]] | 4671 | fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]] |
4662 | 4672 | ||
4663 | main.main(["Foo", "Bar"]) | 4673 | main.main(["Foo", "Bar"]) |
4664 | 4674 | ||
4665 | parse_args.assert_called_with(["Foo", "Bar"]) | 4675 | parse_args.assert_called_with(["Foo", "Bar"]) |
4666 | parse_config.assert_called_with("config") | 4676 | parse_config.assert_called_with(args_mock) |
4667 | fetch_markets.assert_called_with("pg_config", "user") | 4677 | fetch_markets.assert_called_with("pg_config", "user") |
4668 | 4678 | ||
4669 | start.assert_called_once_with() | 4679 | start.assert_called_once_with() |
4670 | self.assertEqual(2, process.call_count) | 4680 | self.assertEqual(2, process.call_count) |
4671 | process.assert_has_calls([ | 4681 | process.assert_has_calls([ |
4672 | mock.call.__bool__(), | 4682 | mock.call.__bool__(), |
4673 | mock.call("config1", 3, 1, args_mock, "report_path", "pg_config"), | 4683 | mock.call("config1", 3, 1, args_mock, "pg_config"), |
4674 | mock.call.__bool__(), | 4684 | mock.call.__bool__(), |
4675 | mock.call("config2", 1, 2, args_mock, "report_path", "pg_config"), | 4685 | mock.call("config2", 1, 2, args_mock, "pg_config"), |
4676 | ]) | 4686 | ]) |
4677 | 4687 | ||
4678 | @mock.patch.object(main.sys, "exit") | 4688 | @mock.patch.object(main.sys, "exit") |
4679 | @mock.patch("main.configparser") | ||
4680 | @mock.patch("main.os") | 4689 | @mock.patch("main.os") |
4681 | def test_parse_config(self, os, configparser, exit): | 4690 | def test_parse_config(self, os, exit): |
4682 | with self.subTest(pg_config=True, report_path=None): | 4691 | with self.subTest(report_path=None): |
4683 | config_mock = mock.MagicMock() | 4692 | args = main.configargparse.Namespace(**{ |
4684 | configparser.ConfigParser.return_value = config_mock | 4693 | "db_host": "host", |
4685 | def config(element): | 4694 | "db_port": "port", |
4686 | return element == "postgresql" | 4695 | "db_user": "user", |
4687 | 4696 | "db_password": "password", | |
4688 | config_mock.__contains__.side_effect = config | 4697 | "db_database": "database", |
4689 | config_mock.__getitem__.return_value = "pg_config" | 4698 | "report_path": None, |
4690 | 4699 | }) | |
4691 | result = main.parse_config("configfile") | ||
4692 | |||
4693 | config_mock.read.assert_called_with("configfile") | ||
4694 | |||
4695 | self.assertEqual(["pg_config", None], result) | ||
4696 | |||
4697 | with self.subTest(pg_config=True, report_path="present"): | ||
4698 | config_mock = mock.MagicMock() | ||
4699 | configparser.ConfigParser.return_value = config_mock | ||
4700 | 4700 | ||
4701 | config_mock.__contains__.return_value = True | 4701 | result = main.parse_config(args) |
4702 | config_mock.__getitem__.side_effect = [ | 4702 | self.assertEqual({ "host": "host", "port": "port", "user": |
4703 | {"report_path": "report_path"}, | 4703 | "user", "password": "password", "database": "database" |
4704 | {"report_path": "report_path"}, | 4704 | }, result) |
4705 | "pg_config", | 4705 | with self.assertRaises(AttributeError): |
4706 | ] | 4706 | args.db_password |
4707 | |||
4708 | with self.subTest(report_path="present"): | ||
4709 | args = main.configargparse.Namespace(**{ | ||
4710 | "db_host": "host", | ||
4711 | "db_port": "port", | ||
4712 | "db_user": "user", | ||
4713 | "db_password": "password", | ||
4714 | "db_database": "database", | ||
4715 | "report_path": "report_path", | ||
4716 | }) | ||
4707 | 4717 | ||
4708 | os.path.exists.return_value = False | 4718 | os.path.exists.return_value = False |
4709 | result = main.parse_config("configfile") | ||
4710 | 4719 | ||
4711 | config_mock.read.assert_called_with("configfile") | 4720 | result = main.parse_config(args) |
4712 | self.assertEqual(["pg_config", "report_path"], result) | 4721 | |
4713 | os.path.exists.assert_called_once_with("report_path") | 4722 | os.path.exists.assert_called_once_with("report_path") |
4714 | os.makedirs.assert_called_once_with("report_path") | 4723 | os.makedirs.assert_called_once_with("report_path") |
4715 | 4724 | ||
4716 | with self.subTest(pg_config=False),\ | 4725 | def test_parse_args(self): |
4717 | mock.patch('sys.stdout', new_callable=StringIO) as stdout_mock: | ||
4718 | config_mock = mock.MagicMock() | ||
4719 | configparser.ConfigParser.return_value = config_mock | ||
4720 | result = main.parse_config("configfile") | ||
4721 | |||
4722 | config_mock.read.assert_called_with("configfile") | ||
4723 | exit.assert_called_once_with(1) | ||
4724 | self.assertEqual("no configuration for postgresql in config file\n", stdout_mock.getvalue()) | ||
4725 | |||
4726 | @mock.patch.object(main.sys, "exit") | ||
4727 | def test_parse_args(self, exit): | ||
4728 | with self.subTest(config="config.ini"): | 4726 | with self.subTest(config="config.ini"): |
4729 | args = main.parse_args([]) | 4727 | args = main.parse_args([]) |
4730 | self.assertEqual("config.ini", args.config) | 4728 | self.assertEqual("config.ini", args.config) |
@@ -4737,13 +4735,10 @@ class MainTest(WebMockTestCase): | |||
4737 | self.assertTrue(args.after) | 4735 | self.assertTrue(args.after) |
4738 | self.assertTrue(args.debug) | 4736 | self.assertTrue(args.debug) |
4739 | 4737 | ||
4740 | exit.assert_not_called() | 4738 | with self.subTest(config="inexistant"), \ |
4741 | 4739 | self.assertRaises(SystemExit), \ | |
4742 | with self.subTest(config="inexistant"),\ | 4740 | mock.patch('sys.stderr', new_callable=StringIO) as stdout_mock: |
4743 | mock.patch('sys.stdout', new_callable=StringIO) as stdout_mock: | ||
4744 | args = main.parse_args(["--config", "foo.bar"]) | 4741 | args = main.parse_args(["--config", "foo.bar"]) |
4745 | exit.assert_called_once_with(1) | ||
4746 | self.assertEqual("no config file found, exiting\n", stdout_mock.getvalue()) | ||
4747 | 4742 | ||
4748 | @mock.patch.object(main, "psycopg2") | 4743 | @mock.patch.object(main, "psycopg2") |
4749 | def test_fetch_markets(self, psycopg2): | 4744 | def test_fetch_markets(self, psycopg2): |