diff options
author | Ismaël Bouya <ismael.bouya@normalesup.org> | 2018-03-24 16:07:11 +0100 |
---|---|---|
committer | Ismaël Bouya <ismael.bouya@normalesup.org> | 2018-03-24 16:07:11 +0100 |
commit | 45fffd4963005a1f3957868f9ddb1aa7ec66c0e3 (patch) | |
tree | 80b7ac28d5d2ed26a7c52ae02d3dc251396b92fb | |
parent | 472787b6360221588423d03fe3e73d92c09a7c9d (diff) | |
parent | 445b4a7712fb7fe45e17b6b76356dd3be42dd900 (diff) | |
download | Trader-45fffd4963005a1f3957868f9ddb1aa7ec66c0e3.tar.gz Trader-45fffd4963005a1f3957868f9ddb1aa7ec66c0e3.tar.zst Trader-45fffd4963005a1f3957868f9ddb1aa7ec66c0e3.zip |
Merge branch 'retry_timeout' into dev
-rw-r--r-- | ccxt_wrapper.py | 45 | ||||
-rw-r--r-- | market.py | 12 | ||||
-rw-r--r-- | test.py | 70 |
3 files changed, 109 insertions, 18 deletions
diff --git a/ccxt_wrapper.py b/ccxt_wrapper.py index d37c306..4ed37d9 100644 --- a/ccxt_wrapper.py +++ b/ccxt_wrapper.py | |||
@@ -1,12 +1,57 @@ | |||
1 | from ccxt import * | 1 | from ccxt import * |
2 | import decimal | 2 | import decimal |
3 | import time | 3 | import time |
4 | from retry.api import retry_call | ||
5 | import re | ||
4 | 6 | ||
5 | def _cw_exchange_sum(self, *args): | 7 | def _cw_exchange_sum(self, *args): |
6 | return sum([arg for arg in args if isinstance(arg, (float, int, decimal.Decimal))]) | 8 | return sum([arg for arg in args if isinstance(arg, (float, int, decimal.Decimal))]) |
7 | Exchange.sum = _cw_exchange_sum | 9 | Exchange.sum = _cw_exchange_sum |
8 | 10 | ||
9 | class poloniexE(poloniex): | 11 | class poloniexE(poloniex): |
12 | RETRIABLE_CALLS = [ | ||
13 | re.compile(r"^return"), | ||
14 | re.compile(r"^cancel"), | ||
15 | re.compile(r"^closeMarginPosition$"), | ||
16 | re.compile(r"^getMarginPosition$"), | ||
17 | ] | ||
18 | |||
19 | def request(self, path, api='public', method='GET', params={}, headers=None, body=None): | ||
20 | """ | ||
21 | Wrapped to allow retry of non-posting requests" | ||
22 | """ | ||
23 | |||
24 | origin_request = super(poloniexE, self).request | ||
25 | kwargs = { | ||
26 | "api": api, | ||
27 | "method": method, | ||
28 | "params": params, | ||
29 | "headers": headers, | ||
30 | "body": body | ||
31 | } | ||
32 | |||
33 | retriable = any(re.match(call, path) for call in self.RETRIABLE_CALLS) | ||
34 | if api == "public" or method == "GET" or retriable: | ||
35 | return retry_call(origin_request, fargs=[path], fkwargs=kwargs, | ||
36 | tries=10, delay=1, exceptions=(RequestTimeout,)) | ||
37 | else: | ||
38 | return origin_request(path, **kwargs) | ||
39 | |||
40 | def __init__(self, *args, **kwargs): | ||
41 | super(poloniexE, self).__init__(*args, **kwargs) | ||
42 | |||
43 | # For requests logging | ||
44 | self.session.origin_request = self.session.request | ||
45 | self.session._parent = self | ||
46 | |||
47 | def request_wrap(self, *args, **kwargs): | ||
48 | r = self.origin_request(*args, **kwargs) | ||
49 | self._parent._market.report.log_http_request(args[0], | ||
50 | args[1], kwargs["data"], kwargs["headers"], r) | ||
51 | return r | ||
52 | self.session.request = request_wrap.__get__(self.session, | ||
53 | self.session.__class__) | ||
54 | |||
10 | @staticmethod | 55 | @staticmethod |
11 | def nanoseconds(): | 56 | def nanoseconds(): |
12 | return int(time.time() * 1000000000) | 57 | return int(time.time() * 1000000000) |
@@ -33,18 +33,6 @@ class Market: | |||
33 | 33 | ||
34 | ccxt_instance = ccxt.poloniexE(config) | 34 | ccxt_instance = ccxt.poloniexE(config) |
35 | 35 | ||
36 | # For requests logging | ||
37 | ccxt_instance.session.origin_request = ccxt_instance.session.request | ||
38 | ccxt_instance.session._parent = ccxt_instance | ||
39 | |||
40 | def request_wrap(self, *args, **kwargs): | ||
41 | r = self.origin_request(*args, **kwargs) | ||
42 | self._parent._market.report.log_http_request(args[0], | ||
43 | args[1], kwargs["data"], kwargs["headers"], r) | ||
44 | return r | ||
45 | ccxt_instance.session.request = request_wrap.__get__(ccxt_instance.session, | ||
46 | ccxt_instance.session.__class__) | ||
47 | |||
48 | return cls(ccxt_instance, args, **kwargs) | 36 | return cls(ccxt_instance, args, **kwargs) |
49 | 37 | ||
50 | def store_report(self): | 38 | def store_report(self): |
@@ -70,6 +70,18 @@ class poloniexETest(unittest.TestCase): | |||
70 | self.wm.stop() | 70 | self.wm.stop() |
71 | super(poloniexETest, self).tearDown() | 71 | super(poloniexETest, self).tearDown() |
72 | 72 | ||
73 | def test__init(self): | ||
74 | with mock.patch("market.ccxt.poloniexE.session") as session: | ||
75 | session.request.return_value = "response" | ||
76 | ccxt = market.ccxt.poloniexE() | ||
77 | ccxt._market = mock.Mock | ||
78 | ccxt._market.report = mock.Mock() | ||
79 | |||
80 | ccxt.session.request("GET", "URL", data="data", | ||
81 | headers="headers") | ||
82 | ccxt._market.report.log_http_request.assert_called_with('GET', 'URL', 'data', | ||
83 | 'headers', 'response') | ||
84 | |||
73 | def test_nanoseconds(self): | 85 | def test_nanoseconds(self): |
74 | with mock.patch.object(market.ccxt.time, "time") as time: | 86 | with mock.patch.object(market.ccxt.time, "time") as time: |
75 | time.return_value = 123456.7890123456 | 87 | time.return_value = 123456.7890123456 |
@@ -80,6 +92,58 @@ class poloniexETest(unittest.TestCase): | |||
80 | time.return_value = 123456.7890123456 | 92 | time.return_value = 123456.7890123456 |
81 | self.assertEqual(123456789012345, self.s.nonce()) | 93 | self.assertEqual(123456789012345, self.s.nonce()) |
82 | 94 | ||
95 | def test_request(self): | ||
96 | with mock.patch.object(market.ccxt.poloniex, "request") as request,\ | ||
97 | mock.patch("market.ccxt.retry_call") as retry_call: | ||
98 | with self.subTest(wrapped=True): | ||
99 | with self.subTest(desc="public"): | ||
100 | self.s.request("foo") | ||
101 | retry_call.assert_called_with(request, | ||
102 | delay=1, tries=10, fargs=["foo"], | ||
103 | fkwargs={'api': 'public', 'method': 'GET', 'params': {}, 'headers': None, 'body': None}, | ||
104 | exceptions=(market.ccxt.RequestTimeout,)) | ||
105 | request.assert_not_called() | ||
106 | |||
107 | with self.subTest(desc="private GET"): | ||
108 | self.s.request("foo", api="private") | ||
109 | retry_call.assert_called_with(request, | ||
110 | delay=1, tries=10, fargs=["foo"], | ||
111 | fkwargs={'api': 'private', 'method': 'GET', 'params': {}, 'headers': None, 'body': None}, | ||
112 | exceptions=(market.ccxt.RequestTimeout,)) | ||
113 | request.assert_not_called() | ||
114 | |||
115 | with self.subTest(desc="private POST regexp"): | ||
116 | self.s.request("returnFoo", api="private", method="POST") | ||
117 | retry_call.assert_called_with(request, | ||
118 | delay=1, tries=10, fargs=["returnFoo"], | ||
119 | fkwargs={'api': 'private', 'method': 'POST', 'params': {}, 'headers': None, 'body': None}, | ||
120 | exceptions=(market.ccxt.RequestTimeout,)) | ||
121 | request.assert_not_called() | ||
122 | |||
123 | with self.subTest(desc="private POST non-regexp"): | ||
124 | self.s.request("getMarginPosition", api="private", method="POST") | ||
125 | retry_call.assert_called_with(request, | ||
126 | delay=1, tries=10, fargs=["getMarginPosition"], | ||
127 | fkwargs={'api': 'private', 'method': 'POST', 'params': {}, 'headers': None, 'body': None}, | ||
128 | exceptions=(market.ccxt.RequestTimeout,)) | ||
129 | request.assert_not_called() | ||
130 | retry_call.reset_mock() | ||
131 | request.reset_mock() | ||
132 | with self.subTest(wrapped=False): | ||
133 | with self.subTest(desc="private POST non-matching regexp"): | ||
134 | self.s.request("marginBuy", api="private", method="POST") | ||
135 | request.assert_called_with("marginBuy", | ||
136 | api="private", method="POST", params={}, | ||
137 | headers=None, body=None) | ||
138 | retry_call.assert_not_called() | ||
139 | |||
140 | with self.subTest(desc="private POST non-matching non-regexp"): | ||
141 | self.s.request("closeMarginPositionOther", api="private", method="POST") | ||
142 | request.assert_called_with("closeMarginPositionOther", | ||
143 | api="private", method="POST", params={}, | ||
144 | headers=None, body=None) | ||
145 | retry_call.assert_not_called() | ||
146 | |||
83 | def test_order_precision(self): | 147 | def test_order_precision(self): |
84 | self.assertEqual(8, self.s.order_precision("FOO")) | 148 | self.assertEqual(8, self.s.order_precision("FOO")) |
85 | 149 | ||
@@ -1125,17 +1189,11 @@ class MarketTest(WebMockTestCase): | |||
1125 | def test_from_config(self, ccxt): | 1189 | def test_from_config(self, ccxt): |
1126 | with mock.patch("market.ReportStore"): | 1190 | with mock.patch("market.ReportStore"): |
1127 | ccxt.poloniexE.return_value = self.ccxt | 1191 | ccxt.poloniexE.return_value = self.ccxt |
1128 | self.ccxt.session.request.return_value = "response" | ||
1129 | 1192 | ||
1130 | m = market.Market.from_config({"key": "key", "secred": "secret"}, self.market_args()) | 1193 | m = market.Market.from_config({"key": "key", "secred": "secret"}, self.market_args()) |
1131 | 1194 | ||
1132 | self.assertEqual(self.ccxt, m.ccxt) | 1195 | self.assertEqual(self.ccxt, m.ccxt) |
1133 | 1196 | ||
1134 | self.ccxt.session.request("GET", "URL", data="data", | ||
1135 | headers="headers") | ||
1136 | m.report.log_http_request.assert_called_with('GET', 'URL', 'data', | ||
1137 | 'headers', 'response') | ||
1138 | |||
1139 | m = market.Market.from_config({"key": "key", "secred": "secret"}, self.market_args(debug=True)) | 1197 | m = market.Market.from_config({"key": "key", "secred": "secret"}, self.market_args(debug=True)) |
1140 | self.assertEqual(True, m.debug) | 1198 | self.assertEqual(True, m.debug) |
1141 | 1199 | ||