diff options
-rw-r--r-- | main.py | 22 | ||||
-rw-r--r-- | market.py | 49 | ||||
-rw-r--r-- | store.py | 20 | ||||
-rw-r--r-- | test.py | 165 |
4 files changed, 202 insertions, 54 deletions
@@ -62,7 +62,7 @@ def make_order(market, value, currency, action="acquire", | |||
62 | 62 | ||
63 | def get_user_market(config_path, user_id, debug=False): | 63 | def get_user_market(config_path, user_id, debug=False): |
64 | pg_config, report_path = parse_config(config_path) | 64 | pg_config, report_path = parse_config(config_path) |
65 | market_config = list(fetch_markets(pg_config, str(user_id)))[0][0] | 65 | market_config = list(fetch_markets(pg_config, str(user_id)))[0][1] |
66 | args = type('Args', (object,), { "debug": debug, "quiet": False })() | 66 | args = type('Args', (object,), { "debug": debug, "quiet": False })() |
67 | return market.Market.from_config(market_config, args, user_id=user_id, report_path=report_path) | 67 | return market.Market.from_config(market_config, args, user_id=user_id, report_path=report_path) |
68 | 68 | ||
@@ -71,9 +71,9 @@ def fetch_markets(pg_config, user): | |||
71 | cursor = connection.cursor() | 71 | cursor = connection.cursor() |
72 | 72 | ||
73 | if user is None: | 73 | if user is None: |
74 | cursor.execute("SELECT config,user_id FROM market_configs") | 74 | cursor.execute("SELECT id,config,user_id FROM market_configs") |
75 | else: | 75 | else: |
76 | cursor.execute("SELECT config,user_id FROM market_configs WHERE user_id = %s", user) | 76 | cursor.execute("SELECT id,config,user_id FROM market_configs WHERE user_id = %s", user) |
77 | 77 | ||
78 | for row in cursor: | 78 | for row in cursor: |
79 | yield row | 79 | yield row |
@@ -132,10 +132,12 @@ def parse_args(argv): | |||
132 | 132 | ||
133 | return args | 133 | return args |
134 | 134 | ||
135 | def process(market_config, user_id, report_path, args): | 135 | def process(market_id, market_config, user_id, report_path, args, pg_config): |
136 | try: | 136 | try: |
137 | market.Market\ | 137 | market.Market\ |
138 | .from_config(market_config, args, user_id=user_id, report_path=report_path)\ | 138 | .from_config(market_config, args, |
139 | pg_config=pg_config, market_id=market_id, | ||
140 | user_id=user_id, report_path=report_path)\ | ||
139 | .process(args.action, before=args.before, after=args.after) | 141 | .process(args.action, before=args.before, after=args.after) |
140 | except Exception as e: | 142 | except Exception as e: |
141 | print("{}: {}".format(e.__class__.__name__, e)) | 143 | print("{}: {}".format(e.__class__.__name__, e)) |
@@ -149,11 +151,13 @@ def main(argv): | |||
149 | import threading | 151 | import threading |
150 | market.Portfolio.start_worker() | 152 | market.Portfolio.start_worker() |
151 | 153 | ||
152 | for market_config, user_id in fetch_markets(pg_config, args.user): | 154 | for row in fetch_markets(pg_config, args.user): |
153 | threading.Thread(target=process, args=[market_config, user_id, report_path, args]).start() | 155 | threading.Thread(target=process, args=[ |
156 | *row, report_path, args, pg_config | ||
157 | ]).start() | ||
154 | else: | 158 | else: |
155 | for market_config, user_id in fetch_markets(pg_config, args.user): | 159 | for row in fetch_markets(pg_config, args.user): |
156 | process(market_config, user_id, report_path, args) | 160 | process(*row, report_path, args, pg_config) |
157 | 161 | ||
158 | if __name__ == '__main__': # pragma: no cover | 162 | if __name__ == '__main__': # pragma: no cover |
159 | main(sys.argv[1:]) | 163 | main(sys.argv[1:]) |
@@ -1,6 +1,7 @@ | |||
1 | from ccxt import ExchangeError, NotSupported | 1 | from ccxt import ExchangeError, NotSupported |
2 | import ccxt_wrapper as ccxt | 2 | import ccxt_wrapper as ccxt |
3 | import time | 3 | import time |
4 | import psycopg2 | ||
4 | from store import * | 5 | from store import * |
5 | from cachetools.func import ttl_cache | 6 | from cachetools.func import ttl_cache |
6 | from datetime import datetime | 7 | from datetime import datetime |
@@ -13,7 +14,9 @@ class Market: | |||
13 | trades = None | 14 | trades = None |
14 | balances = None | 15 | balances = None |
15 | 16 | ||
16 | def __init__(self, ccxt_instance, args, user_id=None, report_path=None): | 17 | def __init__(self, ccxt_instance, args, |
18 | user_id=None, market_id=None, | ||
19 | report_path=None, pg_config=None): | ||
17 | self.args = args | 20 | self.args = args |
18 | self.debug = args.debug | 21 | self.debug = args.debug |
19 | self.ccxt = ccxt_instance | 22 | self.ccxt = ccxt_instance |
@@ -24,10 +27,13 @@ class Market: | |||
24 | self.processor = Processor(self) | 27 | self.processor = Processor(self) |
25 | 28 | ||
26 | self.user_id = user_id | 29 | self.user_id = user_id |
30 | self.market_id = market_id | ||
27 | self.report_path = report_path | 31 | self.report_path = report_path |
32 | self.pg_config = pg_config | ||
28 | 33 | ||
29 | @classmethod | 34 | @classmethod |
30 | def from_config(cls, config, args, user_id=None, report_path=None): | 35 | def from_config(cls, config, args, |
36 | user_id=None, market_id=None, report_path=None, pg_config=None): | ||
31 | config["apiKey"] = config.pop("key", None) | 37 | config["apiKey"] = config.pop("key", None) |
32 | 38 | ||
33 | ccxt_instance = ccxt.poloniexE(config) | 39 | ccxt_instance = ccxt.poloniexE(config) |
@@ -44,20 +50,45 @@ class Market: | |||
44 | ccxt_instance.session.request = request_wrap.__get__(ccxt_instance.session, | 50 | ccxt_instance.session.request = request_wrap.__get__(ccxt_instance.session, |
45 | ccxt_instance.session.__class__) | 51 | ccxt_instance.session.__class__) |
46 | 52 | ||
47 | return cls(ccxt_instance, args, user_id=user_id, report_path=report_path) | 53 | return cls(ccxt_instance, args, |
54 | user_id=user_id, market_id=market_id, | ||
55 | pg_config=pg_config, report_path=report_path) | ||
48 | 56 | ||
49 | def store_report(self): | 57 | def store_report(self): |
50 | self.report.merge(Portfolio.report) | 58 | self.report.merge(Portfolio.report) |
59 | date = datetime.now() | ||
60 | if self.report_path is not None: | ||
61 | self.store_file_report(date) | ||
62 | if self.pg_config is not None: | ||
63 | self.store_database_report(date) | ||
64 | |||
65 | def store_file_report(self, date): | ||
51 | try: | 66 | try: |
52 | if self.report_path is not None: | 67 | report_file = "{}/{}_{}".format(self.report_path, date.isoformat(), self.user_id) |
53 | report_file = "{}/{}_{}".format(self.report_path, datetime.now().isoformat(), self.user_id) | 68 | with open(report_file + ".json", "w") as f: |
54 | with open(report_file + ".json", "w") as f: | 69 | f.write(self.report.to_json()) |
55 | f.write(self.report.to_json()) | 70 | with open(report_file + ".log", "w") as f: |
56 | with open(report_file + ".log", "w") as f: | 71 | f.write("\n".join(map(lambda x: x[1], self.report.print_logs))) |
57 | f.write("\n".join(map(lambda x: x[1], self.report.print_logs))) | ||
58 | except Exception as e: | 72 | except Exception as e: |
59 | print("impossible to store report file: {}; {}".format(e.__class__.__name__, e)) | 73 | print("impossible to store report file: {}; {}".format(e.__class__.__name__, e)) |
60 | 74 | ||
75 | def store_database_report(self, date): | ||
76 | try: | ||
77 | report_query = 'INSERT INTO reports("date", "market_config_id", "debug") VALUES (%s, %s, %s) RETURNING id;' | ||
78 | line_query = 'INSERT INTO report_lines("date", "report_id", "type", "payload") VALUES (%s, %s, %s, %s);' | ||
79 | connection = psycopg2.connect(**self.pg_config) | ||
80 | cursor = connection.cursor() | ||
81 | cursor.execute(report_query, (date, self.market_id, self.debug)) | ||
82 | report_id = cursor.fetchone()[0] | ||
83 | for date, type_, payload in self.report.to_json_array(): | ||
84 | cursor.execute(line_query, (date, report_id, type_, payload)) | ||
85 | |||
86 | connection.commit() | ||
87 | cursor.close() | ||
88 | connection.close() | ||
89 | except Exception as e: | ||
90 | print("impossible to store report to database: {}; {}".format(e.__class__.__name__, e)) | ||
91 | |||
61 | def process(self, actions, before=False, after=False): | 92 | def process(self, actions, before=False, after=False): |
62 | try: | 93 | try: |
63 | if len(actions or []) == 0: | 94 | if len(actions or []) == 0: |
@@ -36,12 +36,22 @@ class ReportStore: | |||
36 | hash_["date"] = datetime.now() | 36 | hash_["date"] = datetime.now() |
37 | self.logs.append(hash_) | 37 | self.logs.append(hash_) |
38 | 38 | ||
39 | @staticmethod | ||
40 | def default_json_serial(obj): | ||
41 | if isinstance(obj, (datetime, date)): | ||
42 | return obj.isoformat() | ||
43 | return str(obj) | ||
44 | |||
39 | def to_json(self): | 45 | def to_json(self): |
40 | def default_json_serial(obj): | 46 | return json.dumps(self.logs, default=self.default_json_serial, indent=" ") |
41 | if isinstance(obj, (datetime, date)): | 47 | |
42 | return obj.isoformat() | 48 | def to_json_array(self): |
43 | return str(obj) | 49 | for log in (x.copy() for x in self.logs): |
44 | return json.dumps(self.logs, default=default_json_serial, indent=" ") | 50 | yield ( |
51 | log.pop("date"), | ||
52 | log.pop("type"), | ||
53 | json.dumps(log, default=self.default_json_serial, indent=" ") | ||
54 | ) | ||
45 | 55 | ||
46 | def set_verbose(self, verbose_print): | 56 | def set_verbose(self, verbose_print): |
47 | self.verbose_print = verbose_print | 57 | self.verbose_print = verbose_print |
@@ -1386,18 +1386,7 @@ class MarketTest(WebMockTestCase): | |||
1386 | self.ccxt.transfer_balance.assert_any_call("USDT", 100, "exchange", "margin") | 1386 | self.ccxt.transfer_balance.assert_any_call("USDT", 100, "exchange", "margin") |
1387 | self.ccxt.transfer_balance.assert_any_call("ETC", 5, "margin", "exchange") | 1387 | self.ccxt.transfer_balance.assert_any_call("ETC", 5, "margin", "exchange") |
1388 | 1388 | ||
1389 | def test_store_report(self): | 1389 | def test_store_file_report(self): |
1390 | |||
1391 | file_open = mock.mock_open() | ||
1392 | m = market.Market(self.ccxt, self.market_args(), user_id=1) | ||
1393 | with self.subTest(file=None),\ | ||
1394 | mock.patch.object(m, "report") as report,\ | ||
1395 | mock.patch("market.open", file_open): | ||
1396 | m.store_report() | ||
1397 | report.merge.assert_called_with(store.Portfolio.report) | ||
1398 | file_open.assert_not_called() | ||
1399 | |||
1400 | report.reset_mock() | ||
1401 | file_open = mock.mock_open() | 1390 | file_open = mock.mock_open() |
1402 | m = market.Market(self.ccxt, self.market_args(), report_path="present", user_id=1) | 1391 | m = market.Market(self.ccxt, self.market_args(), report_path="present", user_id=1) |
1403 | with self.subTest(file="present"),\ | 1392 | with self.subTest(file="present"),\ |
@@ -1405,20 +1394,16 @@ class MarketTest(WebMockTestCase): | |||
1405 | mock.patch.object(m, "report") as report,\ | 1394 | mock.patch.object(m, "report") as report,\ |
1406 | mock.patch.object(market, "datetime") as time_mock: | 1395 | mock.patch.object(market, "datetime") as time_mock: |
1407 | 1396 | ||
1408 | time_mock.now.return_value = datetime.datetime(2018, 2, 25) | ||
1409 | report.print_logs = [[time_mock.now(), "Foo"], [time_mock.now(), "Bar"]] | 1397 | report.print_logs = [[time_mock.now(), "Foo"], [time_mock.now(), "Bar"]] |
1410 | report.to_json.return_value = "json_content" | 1398 | report.to_json.return_value = "json_content" |
1411 | 1399 | ||
1412 | m.store_report() | 1400 | m.store_file_report(datetime.datetime(2018, 2, 25)) |
1413 | 1401 | ||
1414 | file_open.assert_any_call("present/2018-02-25T00:00:00_1.json", "w") | 1402 | file_open.assert_any_call("present/2018-02-25T00:00:00_1.json", "w") |
1415 | file_open.assert_any_call("present/2018-02-25T00:00:00_1.log", "w") | 1403 | file_open.assert_any_call("present/2018-02-25T00:00:00_1.log", "w") |
1416 | file_open().write.assert_any_call("json_content") | 1404 | file_open().write.assert_any_call("json_content") |
1417 | file_open().write.assert_any_call("Foo\nBar") | 1405 | file_open().write.assert_any_call("Foo\nBar") |
1418 | m.report.to_json.assert_called_once_with() | 1406 | m.report.to_json.assert_called_once_with() |
1419 | report.merge.assert_called_with(store.Portfolio.report) | ||
1420 | |||
1421 | report.reset_mock() | ||
1422 | 1407 | ||
1423 | m = market.Market(self.ccxt, self.market_args(), report_path="error", user_id=1) | 1408 | m = market.Market(self.ccxt, self.market_args(), report_path="error", user_id=1) |
1424 | with self.subTest(file="error"),\ | 1409 | with self.subTest(file="error"),\ |
@@ -1427,10 +1412,106 @@ class MarketTest(WebMockTestCase): | |||
1427 | mock.patch('sys.stdout', new_callable=StringIO) as stdout_mock: | 1412 | mock.patch('sys.stdout', new_callable=StringIO) as stdout_mock: |
1428 | file_open.side_effect = FileNotFoundError | 1413 | file_open.side_effect = FileNotFoundError |
1429 | 1414 | ||
1415 | m.store_file_report(datetime.datetime(2018, 2, 25)) | ||
1416 | |||
1417 | self.assertRegex(stdout_mock.getvalue(), "impossible to store report file: FileNotFoundError;") | ||
1418 | |||
1419 | @mock.patch.object(market, "psycopg2") | ||
1420 | def test_store_database_report(self, psycopg2): | ||
1421 | connect_mock = mock.Mock() | ||
1422 | cursor_mock = mock.MagicMock() | ||
1423 | |||
1424 | connect_mock.cursor.return_value = cursor_mock | ||
1425 | psycopg2.connect.return_value = connect_mock | ||
1426 | m = market.Market(self.ccxt, self.market_args(), | ||
1427 | pg_config={"config": "pg_config"}, user_id=1) | ||
1428 | cursor_mock.fetchone.return_value = [42] | ||
1429 | |||
1430 | with self.subTest(error=False),\ | ||
1431 | mock.patch.object(m, "report") as report: | ||
1432 | report.to_json_array.return_value = [ | ||
1433 | ("date1", "type1", "payload1"), | ||
1434 | ("date2", "type2", "payload2"), | ||
1435 | ] | ||
1436 | m.store_database_report(datetime.datetime(2018, 3, 24)) | ||
1437 | connect_mock.assert_has_calls([ | ||
1438 | mock.call.cursor(), | ||
1439 | mock.call.cursor().execute('INSERT INTO reports("date", "market_config_id", "debug") VALUES (%s, %s, %s) RETURNING id;', (datetime.datetime(2018, 3, 24), None, False)), | ||
1440 | mock.call.cursor().fetchone(), | ||
1441 | mock.call.cursor().execute('INSERT INTO report_lines("date", "report_id", "type", "payload") VALUES (%s, %s, %s, %s);', ('date1', 42, 'type1', 'payload1')), | ||
1442 | mock.call.cursor().execute('INSERT INTO report_lines("date", "report_id", "type", "payload") VALUES (%s, %s, %s, %s);', ('date2', 42, 'type2', 'payload2')), | ||
1443 | mock.call.commit(), | ||
1444 | mock.call.cursor().close(), | ||
1445 | mock.call.close() | ||
1446 | ]) | ||
1447 | |||
1448 | connect_mock.reset_mock() | ||
1449 | with self.subTest(error=True),\ | ||
1450 | mock.patch('sys.stdout', new_callable=StringIO) as stdout_mock: | ||
1451 | psycopg2.connect.side_effect = Exception("Bouh") | ||
1452 | m.store_database_report(datetime.datetime(2018, 3, 24)) | ||
1453 | self.assertEqual(stdout_mock.getvalue(), "impossible to store report to database: Exception; Bouh\n") | ||
1454 | |||
1455 | def test_store_report(self): | ||
1456 | m = market.Market(self.ccxt, self.market_args(), user_id=1) | ||
1457 | with self.subTest(file=None, pg_config=None),\ | ||
1458 | mock.patch.object(m, "report") as report,\ | ||
1459 | mock.patch.object(m, "store_database_report") as db_report,\ | ||
1460 | mock.patch.object(m, "store_file_report") as file_report: | ||
1461 | m.store_report() | ||
1462 | report.merge.assert_called_with(store.Portfolio.report) | ||
1463 | |||
1464 | file_report.assert_not_called() | ||
1465 | db_report.assert_not_called() | ||
1466 | |||
1467 | report.reset_mock() | ||
1468 | m = market.Market(self.ccxt, self.market_args(), report_path="present", user_id=1) | ||
1469 | with self.subTest(file="present", pg_config=None),\ | ||
1470 | mock.patch.object(m, "report") as report,\ | ||
1471 | mock.patch.object(m, "store_file_report") as file_report,\ | ||
1472 | mock.patch.object(m, "store_database_report") as db_report,\ | ||
1473 | mock.patch.object(market, "datetime") as time_mock: | ||
1474 | |||
1475 | time_mock.now.return_value = datetime.datetime(2018, 2, 25) | ||
1476 | |||
1430 | m.store_report() | 1477 | m.store_report() |
1431 | 1478 | ||
1432 | report.merge.assert_called_with(store.Portfolio.report) | 1479 | report.merge.assert_called_with(store.Portfolio.report) |
1433 | self.assertRegex(stdout_mock.getvalue(), "impossible to store report file: FileNotFoundError;") | 1480 | file_report.assert_called_once_with(datetime.datetime(2018, 2, 25)) |
1481 | db_report.assert_not_called() | ||
1482 | |||
1483 | report.reset_mock() | ||
1484 | m = market.Market(self.ccxt, self.market_args(), pg_config="present", user_id=1) | ||
1485 | with self.subTest(file=None, pg_config="present"),\ | ||
1486 | mock.patch.object(m, "report") as report,\ | ||
1487 | mock.patch.object(m, "store_file_report") as file_report,\ | ||
1488 | mock.patch.object(m, "store_database_report") as db_report,\ | ||
1489 | mock.patch.object(market, "datetime") as time_mock: | ||
1490 | |||
1491 | time_mock.now.return_value = datetime.datetime(2018, 2, 25) | ||
1492 | |||
1493 | m.store_report() | ||
1494 | |||
1495 | report.merge.assert_called_with(store.Portfolio.report) | ||
1496 | file_report.assert_not_called() | ||
1497 | db_report.assert_called_once_with(datetime.datetime(2018, 2, 25)) | ||
1498 | |||
1499 | report.reset_mock() | ||
1500 | m = market.Market(self.ccxt, self.market_args(), | ||
1501 | pg_config="pg_config", report_path="present", user_id=1) | ||
1502 | with self.subTest(file="present", pg_config="present"),\ | ||
1503 | mock.patch.object(m, "report") as report,\ | ||
1504 | mock.patch.object(m, "store_file_report") as file_report,\ | ||
1505 | mock.patch.object(m, "store_database_report") as db_report,\ | ||
1506 | mock.patch.object(market, "datetime") as time_mock: | ||
1507 | |||
1508 | time_mock.now.return_value = datetime.datetime(2018, 2, 25) | ||
1509 | |||
1510 | m.store_report() | ||
1511 | |||
1512 | report.merge.assert_called_with(store.Portfolio.report) | ||
1513 | file_report.assert_called_once_with(datetime.datetime(2018, 2, 25)) | ||
1514 | db_report.assert_called_once_with(datetime.datetime(2018, 2, 25)) | ||
1434 | 1515 | ||
1435 | def test_print_orders(self): | 1516 | def test_print_orders(self): |
1436 | m = market.Market(self.ccxt, self.market_args()) | 1517 | m = market.Market(self.ccxt, self.market_args()) |
@@ -3050,6 +3131,14 @@ class ReportStoreTest(WebMockTestCase): | |||
3050 | report_store.print_log(portfolio.Amount("BTC", 1)) | 3131 | report_store.print_log(portfolio.Amount("BTC", 1)) |
3051 | self.assertEqual(stdout_mock.getvalue(), "") | 3132 | self.assertEqual(stdout_mock.getvalue(), "") |
3052 | 3133 | ||
3134 | def test_default_json_serial(self): | ||
3135 | report_store = market.ReportStore(self.m) | ||
3136 | |||
3137 | self.assertEqual("2018-02-24T00:00:00", | ||
3138 | report_store.default_json_serial(portfolio.datetime(2018, 2, 24))) | ||
3139 | self.assertEqual("1.00000000 BTC", | ||
3140 | report_store.default_json_serial(portfolio.Amount("BTC", 1))) | ||
3141 | |||
3053 | def test_to_json(self): | 3142 | def test_to_json(self): |
3054 | report_store = market.ReportStore(self.m) | 3143 | report_store = market.ReportStore(self.m) |
3055 | report_store.logs.append({"foo": "bar"}) | 3144 | report_store.logs.append({"foo": "bar"}) |
@@ -3059,6 +3148,20 @@ class ReportStoreTest(WebMockTestCase): | |||
3059 | report_store.logs.append({"amount": portfolio.Amount("BTC", 1)}) | 3148 | report_store.logs.append({"amount": portfolio.Amount("BTC", 1)}) |
3060 | self.assertEqual('[\n {\n "foo": "bar"\n },\n {\n "date": "2018-02-24T00:00:00"\n },\n {\n "amount": "1.00000000 BTC"\n }\n]', report_store.to_json()) | 3149 | self.assertEqual('[\n {\n "foo": "bar"\n },\n {\n "date": "2018-02-24T00:00:00"\n },\n {\n "amount": "1.00000000 BTC"\n }\n]', report_store.to_json()) |
3061 | 3150 | ||
3151 | def test_to_json_array(self): | ||
3152 | report_store = market.ReportStore(self.m) | ||
3153 | report_store.logs.append({ | ||
3154 | "date": "date1", "type": "type1", "foo": "bar", "bla": "bla" | ||
3155 | }) | ||
3156 | report_store.logs.append({ | ||
3157 | "date": "date2", "type": "type2", "foo": "bar", "bla": "bla" | ||
3158 | }) | ||
3159 | logs = list(report_store.to_json_array()) | ||
3160 | |||
3161 | self.assertEqual(2, len(logs)) | ||
3162 | self.assertEqual(("date1", "type1", '{\n "foo": "bar",\n "bla": "bla"\n}'), logs[0]) | ||
3163 | self.assertEqual(("date2", "type2", '{\n "foo": "bar",\n "bla": "bla"\n}'), logs[1]) | ||
3164 | |||
3062 | @mock.patch.object(market.ReportStore, "print_log") | 3165 | @mock.patch.object(market.ReportStore, "print_log") |
3063 | @mock.patch.object(market.ReportStore, "add_log") | 3166 | @mock.patch.object(market.ReportStore, "add_log") |
3064 | def test_log_stage(self, add_log, print_log): | 3167 | def test_log_stage(self, add_log, print_log): |
@@ -3552,7 +3655,7 @@ class MainTest(WebMockTestCase): | |||
3552 | mock.patch("main.parse_config") as main_parse_config: | 3655 | mock.patch("main.parse_config") as main_parse_config: |
3553 | with self.subTest(debug=False): | 3656 | with self.subTest(debug=False): |
3554 | main_parse_config.return_value = ["pg_config", "report_path"] | 3657 | main_parse_config.return_value = ["pg_config", "report_path"] |
3555 | main_fetch_markets.return_value = [({"key": "market_config"},)] | 3658 | main_fetch_markets.return_value = [(1, {"key": "market_config"}, 3)] |
3556 | m = main.get_user_market("config_path.ini", 1) | 3659 | m = main.get_user_market("config_path.ini", 1) |
3557 | 3660 | ||
3558 | self.assertIsInstance(m, market.Market) | 3661 | self.assertIsInstance(m, market.Market) |
@@ -3560,7 +3663,7 @@ class MainTest(WebMockTestCase): | |||
3560 | 3663 | ||
3561 | with self.subTest(debug=True): | 3664 | with self.subTest(debug=True): |
3562 | main_parse_config.return_value = ["pg_config", "report_path"] | 3665 | main_parse_config.return_value = ["pg_config", "report_path"] |
3563 | main_fetch_markets.return_value = [({"key": "market_config"},)] | 3666 | main_fetch_markets.return_value = [(1, {"key": "market_config"}, 3)] |
3564 | m = main.get_user_market("config_path.ini", 1, debug=True) | 3667 | m = main.get_user_market("config_path.ini", 1, debug=True) |
3565 | 3668 | ||
3566 | self.assertIsInstance(m, market.Market) | 3669 | self.assertIsInstance(m, market.Market) |
@@ -3579,16 +3682,16 @@ class MainTest(WebMockTestCase): | |||
3579 | args_mock.after = "after" | 3682 | args_mock.after = "after" |
3580 | self.assertEqual("", stdout_mock.getvalue()) | 3683 | self.assertEqual("", stdout_mock.getvalue()) |
3581 | 3684 | ||
3582 | main.process("config", 1, "report_path", args_mock) | 3685 | main.process(3, "config", 1, "report_path", args_mock, "pg_config") |
3583 | 3686 | ||
3584 | market_mock.from_config.assert_has_calls([ | 3687 | market_mock.from_config.assert_has_calls([ |
3585 | mock.call("config", args_mock, user_id=1, report_path="report_path"), | 3688 | mock.call("config", args_mock, pg_config="pg_config", market_id=3, user_id=1, report_path="report_path"), |
3586 | mock.call().process("action", before="before", after="after"), | 3689 | mock.call().process("action", before="before", after="after"), |
3587 | ]) | 3690 | ]) |
3588 | 3691 | ||
3589 | with self.subTest(exception=True): | 3692 | with self.subTest(exception=True): |
3590 | market_mock.from_config.side_effect = Exception("boo") | 3693 | market_mock.from_config.side_effect = Exception("boo") |
3591 | main.process("config", 1, "report_path", args_mock) | 3694 | main.process(3, "config", 1, "report_path", args_mock, "pg_config") |
3592 | self.assertEqual("Exception: boo\n", stdout_mock.getvalue()) | 3695 | self.assertEqual("Exception: boo\n", stdout_mock.getvalue()) |
3593 | 3696 | ||
3594 | def test_main(self): | 3697 | def test_main(self): |
@@ -3606,7 +3709,7 @@ class MainTest(WebMockTestCase): | |||
3606 | 3709 | ||
3607 | parse_config.return_value = ["pg_config", "report_path"] | 3710 | parse_config.return_value = ["pg_config", "report_path"] |
3608 | 3711 | ||
3609 | fetch_markets.return_value = [["config1", 1], ["config2", 2]] | 3712 | fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]] |
3610 | 3713 | ||
3611 | main.main(["Foo", "Bar"]) | 3714 | main.main(["Foo", "Bar"]) |
3612 | 3715 | ||
@@ -3616,8 +3719,8 @@ class MainTest(WebMockTestCase): | |||
3616 | 3719 | ||
3617 | self.assertEqual(2, process.call_count) | 3720 | self.assertEqual(2, process.call_count) |
3618 | process.assert_has_calls([ | 3721 | process.assert_has_calls([ |
3619 | mock.call("config1", 1, "report_path", args_mock), | 3722 | mock.call(3, "config1", 1, "report_path", args_mock, "pg_config"), |
3620 | mock.call("config2", 2, "report_path", args_mock), | 3723 | mock.call(1, "config2", 2, "report_path", args_mock, "pg_config"), |
3621 | ]) | 3724 | ]) |
3622 | with self.subTest(parallel=True): | 3725 | with self.subTest(parallel=True): |
3623 | with mock.patch("main.parse_args") as parse_args,\ | 3726 | with mock.patch("main.parse_args") as parse_args,\ |
@@ -3634,7 +3737,7 @@ class MainTest(WebMockTestCase): | |||
3634 | 3737 | ||
3635 | parse_config.return_value = ["pg_config", "report_path"] | 3738 | parse_config.return_value = ["pg_config", "report_path"] |
3636 | 3739 | ||
3637 | fetch_markets.return_value = [["config1", 1], ["config2", 2]] | 3740 | fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]] |
3638 | 3741 | ||
3639 | main.main(["Foo", "Bar"]) | 3742 | main.main(["Foo", "Bar"]) |
3640 | 3743 | ||
@@ -3646,9 +3749,9 @@ class MainTest(WebMockTestCase): | |||
3646 | self.assertEqual(2, process.call_count) | 3749 | self.assertEqual(2, process.call_count) |
3647 | process.assert_has_calls([ | 3750 | process.assert_has_calls([ |
3648 | mock.call.__bool__(), | 3751 | mock.call.__bool__(), |
3649 | mock.call("config1", 1, "report_path", args_mock), | 3752 | mock.call(3, "config1", 1, "report_path", args_mock, "pg_config"), |
3650 | mock.call.__bool__(), | 3753 | mock.call.__bool__(), |
3651 | mock.call("config2", 2, "report_path", args_mock), | 3754 | mock.call(1, "config2", 2, "report_path", args_mock, "pg_config"), |
3652 | ]) | 3755 | ]) |
3653 | 3756 | ||
3654 | @mock.patch.object(main.sys, "exit") | 3757 | @mock.patch.object(main.sys, "exit") |
@@ -3734,7 +3837,7 @@ class MainTest(WebMockTestCase): | |||
3734 | rows = list(main.fetch_markets({"foo": "bar"}, None)) | 3837 | rows = list(main.fetch_markets({"foo": "bar"}, None)) |
3735 | 3838 | ||
3736 | psycopg2.connect.assert_called_once_with(foo="bar") | 3839 | psycopg2.connect.assert_called_once_with(foo="bar") |
3737 | cursor_mock.execute.assert_called_once_with("SELECT config,user_id FROM market_configs") | 3840 | cursor_mock.execute.assert_called_once_with("SELECT id,config,user_id FROM market_configs") |
3738 | 3841 | ||
3739 | self.assertEqual(["row_1", "row_2"], rows) | 3842 | self.assertEqual(["row_1", "row_2"], rows) |
3740 | 3843 | ||
@@ -3744,7 +3847,7 @@ class MainTest(WebMockTestCase): | |||
3744 | rows = list(main.fetch_markets({"foo": "bar"}, 1)) | 3847 | rows = list(main.fetch_markets({"foo": "bar"}, 1)) |
3745 | 3848 | ||
3746 | psycopg2.connect.assert_called_once_with(foo="bar") | 3849 | psycopg2.connect.assert_called_once_with(foo="bar") |
3747 | cursor_mock.execute.assert_called_once_with("SELECT config,user_id FROM market_configs WHERE user_id = %s", 1) | 3850 | cursor_mock.execute.assert_called_once_with("SELECT id,config,user_id FROM market_configs WHERE user_id = %s", 1) |
3748 | 3851 | ||
3749 | self.assertEqual(["row_1", "row_2"], rows) | 3852 | self.assertEqual(["row_1", "row_2"], rows) |
3750 | 3853 | ||