From: Ismaƫl Bouya Date: Sat, 24 Mar 2018 14:18:31 +0000 (+0100) Subject: Add retry facility for api call timeouts X-Git-Tag: v1.0^2~1^2~1 X-Git-Url: https://git.immae.eu/?p=perso%2FImmae%2FProjets%2FCryptomonnaies%2FCryptoportfolio%2FTrader.git;a=commitdiff_plain;h=c7c1e0b26821fdd5622f81fb456f1028d4c9ab09 Add retry facility for api call timeouts Fixes https://git.immae.eu/mantisbt/view.php?id=40 --- diff --git a/ccxt_wrapper.py b/ccxt_wrapper.py index d37c306..c500659 100644 --- a/ccxt_wrapper.py +++ b/ccxt_wrapper.py @@ -1,12 +1,42 @@ from ccxt import * import decimal import time +from retry.api import retry_call +import re def _cw_exchange_sum(self, *args): return sum([arg for arg in args if isinstance(arg, (float, int, decimal.Decimal))]) Exchange.sum = _cw_exchange_sum class poloniexE(poloniex): + RETRIABLE_CALLS = [ + re.compile(r"^return"), + re.compile(r"^cancel"), + re.compile(r"^closeMarginPosition$"), + re.compile(r"^getMarginPosition$"), + ] + + def request(self, path, api='public', method='GET', params={}, headers=None, body=None): + """ + Wrapped to allow retry of non-posting requests" + """ + + origin_request = super(poloniexE, self).request + kwargs = { + "api": api, + "method": method, + "params": params, + "headers": headers, + "body": body + } + + retriable = any(re.match(call, path) for call in self.RETRIABLE_CALLS) + if api == "public" or method == "GET" or retriable: + return retry_call(origin_request, fargs=[path], fkwargs=kwargs, + tries=10, delay=1, exceptions=(RequestTimeout,)) + else: + return origin_request(path, **kwargs) + @staticmethod def nanoseconds(): return int(time.time() * 1000000000) diff --git a/test.py b/test.py index 637a305..40c64a9 100644 --- a/test.py +++ b/test.py @@ -80,6 +80,58 @@ class poloniexETest(unittest.TestCase): time.return_value = 123456.7890123456 self.assertEqual(123456789012345, self.s.nonce()) + def test_request(self): + with mock.patch.object(market.ccxt.poloniex, "request") as request,\ + mock.patch("market.ccxt.retry_call") as retry_call: + with self.subTest(wrapped=True): + with self.subTest(desc="public"): + self.s.request("foo") + retry_call.assert_called_with(request, + delay=1, tries=10, fargs=["foo"], + fkwargs={'api': 'public', 'method': 'GET', 'params': {}, 'headers': None, 'body': None}, + exceptions=(market.ccxt.RequestTimeout,)) + request.assert_not_called() + + with self.subTest(desc="private GET"): + self.s.request("foo", api="private") + retry_call.assert_called_with(request, + delay=1, tries=10, fargs=["foo"], + fkwargs={'api': 'private', 'method': 'GET', 'params': {}, 'headers': None, 'body': None}, + exceptions=(market.ccxt.RequestTimeout,)) + request.assert_not_called() + + with self.subTest(desc="private POST regexp"): + self.s.request("returnFoo", api="private", method="POST") + retry_call.assert_called_with(request, + delay=1, tries=10, fargs=["returnFoo"], + fkwargs={'api': 'private', 'method': 'POST', 'params': {}, 'headers': None, 'body': None}, + exceptions=(market.ccxt.RequestTimeout,)) + request.assert_not_called() + + with self.subTest(desc="private POST non-regexp"): + self.s.request("getMarginPosition", api="private", method="POST") + retry_call.assert_called_with(request, + delay=1, tries=10, fargs=["getMarginPosition"], + fkwargs={'api': 'private', 'method': 'POST', 'params': {}, 'headers': None, 'body': None}, + exceptions=(market.ccxt.RequestTimeout,)) + request.assert_not_called() + retry_call.reset_mock() + request.reset_mock() + with self.subTest(wrapped=False): + with self.subTest(desc="private POST non-matching regexp"): + self.s.request("marginBuy", api="private", method="POST") + request.assert_called_with("marginBuy", + api="private", method="POST", params={}, + headers=None, body=None) + retry_call.assert_not_called() + + with self.subTest(desc="private POST non-matching non-regexp"): + self.s.request("closeMarginPositionOther", api="private", method="POST") + request.assert_called_with("closeMarginPositionOther", + api="private", method="POST", params={}, + headers=None, body=None) + retry_call.assert_not_called() + def test_order_precision(self): self.assertEqual(8, self.s.order_precision("FOO"))