diff options
Diffstat (limited to 'tests/test_main.py')
-rw-r--r-- | tests/test_main.py | 290 |
1 files changed, 290 insertions, 0 deletions
diff --git a/tests/test_main.py b/tests/test_main.py new file mode 100644 index 0000000..6396c07 --- /dev/null +++ b/tests/test_main.py | |||
@@ -0,0 +1,290 @@ | |||
1 | from .helper import * | ||
2 | import main, market | ||
3 | |||
4 | @unittest.skipUnless("unit" in limits, "Unit skipped") | ||
5 | class MainTest(WebMockTestCase): | ||
6 | def test_make_order(self): | ||
7 | self.m.get_ticker.return_value = { | ||
8 | "inverted": False, | ||
9 | "average": D("0.1"), | ||
10 | "bid": D("0.09"), | ||
11 | "ask": D("0.11"), | ||
12 | } | ||
13 | |||
14 | with self.subTest(description="nominal case"): | ||
15 | main.make_order(self.m, 10, "ETH") | ||
16 | |||
17 | self.m.report.log_stage.assert_has_calls([ | ||
18 | mock.call("make_order_begin"), | ||
19 | mock.call("make_order_end"), | ||
20 | ]) | ||
21 | self.m.balances.fetch_balances.assert_has_calls([ | ||
22 | mock.call(tag="make_order_begin"), | ||
23 | mock.call(tag="make_order_end"), | ||
24 | ]) | ||
25 | self.m.trades.all.append.assert_called_once() | ||
26 | trade = self.m.trades.all.append.mock_calls[0][1][0] | ||
27 | self.assertEqual(False, trade.orders[0].close_if_possible) | ||
28 | self.assertEqual(0, trade.value_from) | ||
29 | self.assertEqual("ETH", trade.currency) | ||
30 | self.assertEqual("BTC", trade.base_currency) | ||
31 | self.m.report.log_orders.assert_called_once_with([trade.orders[0]], None, "average") | ||
32 | self.m.trades.run_orders.assert_called_once_with() | ||
33 | self.m.follow_orders.assert_called_once_with() | ||
34 | |||
35 | order = trade.orders[0] | ||
36 | self.assertEqual(D("0.10"), order.rate) | ||
37 | |||
38 | self.m.reset_mock() | ||
39 | with self.subTest(compute_value="default"): | ||
40 | main.make_order(self.m, 10, "ETH", action="dispose", | ||
41 | compute_value="ask") | ||
42 | |||
43 | trade = self.m.trades.all.append.mock_calls[0][1][0] | ||
44 | order = trade.orders[0] | ||
45 | self.assertEqual(D("0.11"), order.rate) | ||
46 | |||
47 | self.m.reset_mock() | ||
48 | with self.subTest(follow=False): | ||
49 | result = main.make_order(self.m, 10, "ETH", follow=False) | ||
50 | |||
51 | self.m.report.log_stage.assert_has_calls([ | ||
52 | mock.call("make_order_begin"), | ||
53 | mock.call("make_order_end_not_followed"), | ||
54 | ]) | ||
55 | self.m.balances.fetch_balances.assert_called_once_with(tag="make_order_begin") | ||
56 | |||
57 | self.m.trades.all.append.assert_called_once() | ||
58 | trade = self.m.trades.all.append.mock_calls[0][1][0] | ||
59 | self.assertEqual(0, trade.value_from) | ||
60 | self.assertEqual("ETH", trade.currency) | ||
61 | self.assertEqual("BTC", trade.base_currency) | ||
62 | self.m.report.log_orders.assert_called_once_with([trade.orders[0]], None, "average") | ||
63 | self.m.trades.run_orders.assert_called_once_with() | ||
64 | self.m.follow_orders.assert_not_called() | ||
65 | self.assertEqual(trade.orders[0], result) | ||
66 | |||
67 | self.m.reset_mock() | ||
68 | with self.subTest(base_currency="USDT"): | ||
69 | main.make_order(self.m, 1, "BTC", base_currency="USDT") | ||
70 | |||
71 | trade = self.m.trades.all.append.mock_calls[0][1][0] | ||
72 | self.assertEqual("BTC", trade.currency) | ||
73 | self.assertEqual("USDT", trade.base_currency) | ||
74 | |||
75 | self.m.reset_mock() | ||
76 | with self.subTest(close_if_possible=True): | ||
77 | main.make_order(self.m, 10, "ETH", close_if_possible=True) | ||
78 | |||
79 | trade = self.m.trades.all.append.mock_calls[0][1][0] | ||
80 | self.assertEqual(True, trade.orders[0].close_if_possible) | ||
81 | |||
82 | self.m.reset_mock() | ||
83 | with self.subTest(action="dispose"): | ||
84 | main.make_order(self.m, 10, "ETH", action="dispose") | ||
85 | |||
86 | trade = self.m.trades.all.append.mock_calls[0][1][0] | ||
87 | self.assertEqual(0, trade.value_to) | ||
88 | self.assertEqual(1, trade.value_from.value) | ||
89 | self.assertEqual("ETH", trade.currency) | ||
90 | self.assertEqual("BTC", trade.base_currency) | ||
91 | |||
92 | self.m.reset_mock() | ||
93 | with self.subTest(compute_value="default"): | ||
94 | main.make_order(self.m, 10, "ETH", action="dispose", | ||
95 | compute_value="bid") | ||
96 | |||
97 | trade = self.m.trades.all.append.mock_calls[0][1][0] | ||
98 | self.assertEqual(D("0.9"), trade.value_from.value) | ||
99 | |||
100 | def test_get_user_market(self): | ||
101 | with mock.patch("main.fetch_markets") as main_fetch_markets,\ | ||
102 | mock.patch("main.parse_args") as main_parse_args,\ | ||
103 | mock.patch("main.parse_config") as main_parse_config: | ||
104 | with self.subTest(debug=False): | ||
105 | main_parse_args.return_value = self.market_args() | ||
106 | main_parse_config.return_value = "pg_config" | ||
107 | main_fetch_markets.return_value = [(1, {"key": "market_config"}, 3)] | ||
108 | m = main.get_user_market("config_path.ini", 1) | ||
109 | |||
110 | self.assertIsInstance(m, market.Market) | ||
111 | self.assertFalse(m.debug) | ||
112 | main_parse_args.assert_called_once_with(["--config", "config_path.ini"]) | ||
113 | |||
114 | main_parse_args.reset_mock() | ||
115 | with self.subTest(debug=True): | ||
116 | main_parse_args.return_value = self.market_args(debug=True) | ||
117 | main_parse_config.return_value = "pg_config" | ||
118 | main_fetch_markets.return_value = [(1, {"key": "market_config"}, 3)] | ||
119 | m = main.get_user_market("config_path.ini", 1, debug=True) | ||
120 | |||
121 | self.assertIsInstance(m, market.Market) | ||
122 | self.assertTrue(m.debug) | ||
123 | main_parse_args.assert_called_once_with(["--config", "config_path.ini", "--debug"]) | ||
124 | |||
125 | def test_process(self): | ||
126 | with mock.patch("market.Market") as market_mock,\ | ||
127 | mock.patch('sys.stdout', new_callable=StringIO) as stdout_mock: | ||
128 | |||
129 | args_mock = mock.Mock() | ||
130 | args_mock.action = "action" | ||
131 | args_mock.config = "config" | ||
132 | args_mock.user = "user" | ||
133 | args_mock.debug = "debug" | ||
134 | args_mock.before = "before" | ||
135 | args_mock.after = "after" | ||
136 | self.assertEqual("", stdout_mock.getvalue()) | ||
137 | |||
138 | main.process("config", 3, 1, args_mock, "pg_config") | ||
139 | |||
140 | market_mock.from_config.assert_has_calls([ | ||
141 | mock.call("config", args_mock, pg_config="pg_config", market_id=3, user_id=1), | ||
142 | mock.call().process("action", before="before", after="after"), | ||
143 | ]) | ||
144 | |||
145 | with self.subTest(exception=True): | ||
146 | market_mock.from_config.side_effect = Exception("boo") | ||
147 | main.process(3, "config", 1, args_mock, "pg_config") | ||
148 | self.assertEqual("Exception: boo\n", stdout_mock.getvalue()) | ||
149 | |||
150 | def test_main(self): | ||
151 | with self.subTest(parallel=False): | ||
152 | with mock.patch("main.parse_args") as parse_args,\ | ||
153 | mock.patch("main.parse_config") as parse_config,\ | ||
154 | mock.patch("main.fetch_markets") as fetch_markets,\ | ||
155 | mock.patch("main.process") as process: | ||
156 | |||
157 | args_mock = mock.Mock() | ||
158 | args_mock.parallel = False | ||
159 | args_mock.user = "user" | ||
160 | parse_args.return_value = args_mock | ||
161 | |||
162 | parse_config.return_value = "pg_config" | ||
163 | |||
164 | fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]] | ||
165 | |||
166 | main.main(["Foo", "Bar"]) | ||
167 | |||
168 | parse_args.assert_called_with(["Foo", "Bar"]) | ||
169 | parse_config.assert_called_with(args_mock) | ||
170 | fetch_markets.assert_called_with("pg_config", "user") | ||
171 | |||
172 | self.assertEqual(2, process.call_count) | ||
173 | process.assert_has_calls([ | ||
174 | mock.call("config1", 3, 1, args_mock, "pg_config"), | ||
175 | mock.call("config2", 1, 2, args_mock, "pg_config"), | ||
176 | ]) | ||
177 | with self.subTest(parallel=True): | ||
178 | with mock.patch("main.parse_args") as parse_args,\ | ||
179 | mock.patch("main.parse_config") as parse_config,\ | ||
180 | mock.patch("main.fetch_markets") as fetch_markets,\ | ||
181 | mock.patch("main.process") as process,\ | ||
182 | mock.patch("store.Portfolio.start_worker") as start: | ||
183 | |||
184 | args_mock = mock.Mock() | ||
185 | args_mock.parallel = True | ||
186 | args_mock.user = "user" | ||
187 | parse_args.return_value = args_mock | ||
188 | |||
189 | parse_config.return_value = "pg_config" | ||
190 | |||
191 | fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]] | ||
192 | |||
193 | main.main(["Foo", "Bar"]) | ||
194 | |||
195 | parse_args.assert_called_with(["Foo", "Bar"]) | ||
196 | parse_config.assert_called_with(args_mock) | ||
197 | fetch_markets.assert_called_with("pg_config", "user") | ||
198 | |||
199 | start.assert_called_once_with() | ||
200 | self.assertEqual(2, process.call_count) | ||
201 | process.assert_has_calls([ | ||
202 | mock.call.__bool__(), | ||
203 | mock.call("config1", 3, 1, args_mock, "pg_config"), | ||
204 | mock.call.__bool__(), | ||
205 | mock.call("config2", 1, 2, args_mock, "pg_config"), | ||
206 | ]) | ||
207 | |||
208 | @mock.patch.object(main.sys, "exit") | ||
209 | @mock.patch("main.os") | ||
210 | def test_parse_config(self, os, exit): | ||
211 | with self.subTest(report_path=None): | ||
212 | args = main.configargparse.Namespace(**{ | ||
213 | "db_host": "host", | ||
214 | "db_port": "port", | ||
215 | "db_user": "user", | ||
216 | "db_password": "password", | ||
217 | "db_database": "database", | ||
218 | "report_path": None, | ||
219 | }) | ||
220 | |||
221 | result = main.parse_config(args) | ||
222 | self.assertEqual({ "host": "host", "port": "port", "user": | ||
223 | "user", "password": "password", "database": "database" | ||
224 | }, result) | ||
225 | with self.assertRaises(AttributeError): | ||
226 | args.db_password | ||
227 | |||
228 | with self.subTest(report_path="present"): | ||
229 | args = main.configargparse.Namespace(**{ | ||
230 | "db_host": "host", | ||
231 | "db_port": "port", | ||
232 | "db_user": "user", | ||
233 | "db_password": "password", | ||
234 | "db_database": "database", | ||
235 | "report_path": "report_path", | ||
236 | }) | ||
237 | |||
238 | os.path.exists.return_value = False | ||
239 | |||
240 | result = main.parse_config(args) | ||
241 | |||
242 | os.path.exists.assert_called_once_with("report_path") | ||
243 | os.makedirs.assert_called_once_with("report_path") | ||
244 | |||
245 | def test_parse_args(self): | ||
246 | with self.subTest(config="config.ini"): | ||
247 | args = main.parse_args([]) | ||
248 | self.assertEqual("config.ini", args.config) | ||
249 | self.assertFalse(args.before) | ||
250 | self.assertFalse(args.after) | ||
251 | self.assertFalse(args.debug) | ||
252 | |||
253 | args = main.parse_args(["--before", "--after", "--debug"]) | ||
254 | self.assertTrue(args.before) | ||
255 | self.assertTrue(args.after) | ||
256 | self.assertTrue(args.debug) | ||
257 | |||
258 | with self.subTest(config="inexistant"), \ | ||
259 | self.assertRaises(SystemExit), \ | ||
260 | mock.patch('sys.stderr', new_callable=StringIO) as stdout_mock: | ||
261 | args = main.parse_args(["--config", "foo.bar"]) | ||
262 | |||
263 | @mock.patch.object(main, "psycopg2") | ||
264 | def test_fetch_markets(self, psycopg2): | ||
265 | connect_mock = mock.Mock() | ||
266 | cursor_mock = mock.MagicMock() | ||
267 | cursor_mock.__iter__.return_value = ["row_1", "row_2"] | ||
268 | |||
269 | connect_mock.cursor.return_value = cursor_mock | ||
270 | psycopg2.connect.return_value = connect_mock | ||
271 | |||
272 | with self.subTest(user=None): | ||
273 | rows = list(main.fetch_markets({"foo": "bar"}, None)) | ||
274 | |||
275 | psycopg2.connect.assert_called_once_with(foo="bar") | ||
276 | cursor_mock.execute.assert_called_once_with("SELECT id,config,user_id FROM market_configs") | ||
277 | |||
278 | self.assertEqual(["row_1", "row_2"], rows) | ||
279 | |||
280 | psycopg2.connect.reset_mock() | ||
281 | cursor_mock.execute.reset_mock() | ||
282 | with self.subTest(user=1): | ||
283 | rows = list(main.fetch_markets({"foo": "bar"}, 1)) | ||
284 | |||
285 | psycopg2.connect.assert_called_once_with(foo="bar") | ||
286 | cursor_mock.execute.assert_called_once_with("SELECT id,config,user_id FROM market_configs WHERE user_id = %s", 1) | ||
287 | |||
288 | self.assertEqual(["row_1", "row_2"], rows) | ||
289 | |||
290 | |||