From 516a2517aa428596199e56cc105c7b0132064ade Mon Sep 17 00:00:00 2001 From: =?utf8?q?Isma=C3=ABl=20Bouya?= Date: Mon, 26 Feb 2018 11:13:38 +0100 Subject: [PATCH] Add user and action for main actions --- helper.py | 31 +++++++++++++++++++++++-------- main.py | 4 ++-- test.py | 52 ++++++++++++++++++++++++++++++++++++++++++++-------- 3 files changed, 69 insertions(+), 18 deletions(-) diff --git a/helper.py b/helper.py index 4b9ce0d..6d28c3f 100644 --- a/helper.py +++ b/helper.py @@ -24,6 +24,11 @@ def main_parse_args(argv): parser.add_argument("--debug", default=False, action='store_const', const=True, help="Run in debug mode") + parser.add_argument("--user", + default=None, required=False, help="Only run for that user") + parser.add_argument("--action", + default=None, required=False, + help="Do a different action than trading") args = parser.parse_args(argv) @@ -51,21 +56,31 @@ def main_parse_config(config_file): return [config["postgresql"], report_path] -def main_fetch_markets(pg_config): +def main_fetch_markets(pg_config, user): connection = psycopg2.connect(**pg_config) cursor = connection.cursor() - cursor.execute("SELECT config,user_id FROM market_configs") + if user is None: + cursor.execute("SELECT config,user_id FROM market_configs") + else: + cursor.execute("SELECT config,user_id FROM market_configs WHERE user_id = %s", user) for row in cursor: yield row -def main_process_market(user_market, before=False, after=False): - if before: - process_sell_all__1_all_sell(user_market) - if after: - portfolio.Portfolio.wait_for_recent(user_market) - process_sell_all__2_all_buy(user_market) +def main_process_market(user_market, action, before=False, after=False): + if action is None: + if before: + process_sell_all__1_all_sell(user_market) + if after: + portfolio.Portfolio.wait_for_recent(user_market) + process_sell_all__2_all_buy(user_market) + elif action == "print_balances": + print_balances(user_market) + elif action == "print_orders": + print_orders(user_market) + else: + raise NotImplementedError("Unknown action {}".format(action)) def main_store_report(report_path, user_id, user_market): try: diff --git a/main.py b/main.py index e7cdcf0..3cb7f4a 100644 --- a/main.py +++ b/main.py @@ -5,11 +5,11 @@ args = helper.main_parse_args(sys.argv[1:]) pg_config, report_path = helper.main_parse_config(args.config) -for market_config, user_id in helper.main_fetch_markets(pg_config): +for market_config, user_id in helper.main_fetch_markets(pg_config, args.user): try: market_config["apiKey"] = market_config.pop("key") user_market = market.Market.from_config(market_config, debug=args.debug) - helper.main_process_market(user_market, before=args.before, after=args.after) + helper.main_process_market(user_market, args.action, before=args.before, after=args.after) except Exception as e: print("{}: {}".format(e.__class__.__name__, e)) finally: diff --git a/test.py b/test.py index a4ec8d2..4ed0477 100644 --- a/test.py +++ b/test.py @@ -2562,7 +2562,7 @@ class HelperTest(WebMockTestCase): @mock.patch("portfolio.Portfolio.wait_for_recent") def test_main_process_market(self, wait, buy, sell): with self.subTest(before=False, after=False): - helper.main_process_market("user") + helper.main_process_market("user", None) wait.assert_not_called() buy.assert_not_called() @@ -2572,7 +2572,7 @@ class HelperTest(WebMockTestCase): wait.reset_mock() sell.reset_mock() with self.subTest(before=True, after=False): - helper.main_process_market("user", before=True) + helper.main_process_market("user", None, before=True) wait.assert_not_called() buy.assert_not_called() @@ -2582,7 +2582,7 @@ class HelperTest(WebMockTestCase): wait.reset_mock() sell.reset_mock() with self.subTest(before=False, after=True): - helper.main_process_market("user", after=True) + helper.main_process_market("user", None, after=True) wait.assert_called_once_with("user") buy.assert_called_once_with("user") @@ -2592,12 +2592,37 @@ class HelperTest(WebMockTestCase): wait.reset_mock() sell.reset_mock() with self.subTest(before=True, after=True): - helper.main_process_market("user", before=True, after=True) + helper.main_process_market("user", None, before=True, after=True) wait.assert_called_once_with("user") buy.assert_called_once_with("user") sell.assert_called_once_with("user") + buy.reset_mock() + wait.reset_mock() + sell.reset_mock() + with self.subTest(action="print_balances"),\ + mock.patch("helper.print_balances") as print_balances: + helper.main_process_market("user", "print_balances") + + buy.assert_not_called() + wait.assert_not_called() + sell.assert_not_called() + print_balances.assert_called_once_with("user") + + with self.subTest(action="print_orders"),\ + mock.patch("helper.print_orders") as print_orders: + helper.main_process_market("user", "print_orders") + + buy.assert_not_called() + wait.assert_not_called() + sell.assert_not_called() + print_orders.assert_called_once_with("user") + + with self.subTest(action="unknown"),\ + self.assertRaises(NotImplementedError): + helper.main_process_market("user", "unknown") + @mock.patch.object(helper, "psycopg2") def test_fetch_markets(self, psycopg2): connect_mock = mock.Mock() @@ -2607,12 +2632,23 @@ class HelperTest(WebMockTestCase): connect_mock.cursor.return_value = cursor_mock psycopg2.connect.return_value = connect_mock - rows = list(helper.main_fetch_markets({"foo": "bar"})) + with self.subTest(user=None): + rows = list(helper.main_fetch_markets({"foo": "bar"}, None)) + + psycopg2.connect.assert_called_once_with(foo="bar") + cursor_mock.execute.assert_called_once_with("SELECT config,user_id FROM market_configs") + + self.assertEqual(["row_1", "row_2"], rows) + + psycopg2.connect.reset_mock() + cursor_mock.execute.reset_mock() + with self.subTest(user=1): + rows = list(helper.main_fetch_markets({"foo": "bar"}, 1)) - psycopg2.connect.assert_called_once_with(foo="bar") - cursor_mock.execute.assert_called_once_with("SELECT config,user_id FROM market_configs") + psycopg2.connect.assert_called_once_with(foo="bar") + cursor_mock.execute.assert_called_once_with("SELECT config,user_id FROM market_configs WHERE user_id = %s", 1) - self.assertEqual(["row_1", "row_2"], rows) + self.assertEqual(["row_1", "row_2"], rows) @mock.patch.object(helper.sys, "exit") def test_main_parse_args(self, exit): -- 2.41.0