]> git.immae.eu Git - perso/Immae/Projets/Cryptomonnaies/Cryptoportfolio/Trader.git/commitdiff
Merge branch 'retry_timeout' into dev
authorIsmaël Bouya <ismael.bouya@normalesup.org>
Sat, 24 Mar 2018 15:07:11 +0000 (16:07 +0100)
committerIsmaël Bouya <ismael.bouya@normalesup.org>
Sat, 24 Mar 2018 15:07:11 +0000 (16:07 +0100)
ccxt_wrapper.py
market.py
test.py

index d37c306882aaae72fa7c56470eada65f77ef5fa1..4ed37d9376e53f9935c21db0c40944eb9daef95b 100644 (file)
@@ -1,12 +1,57 @@
 from ccxt import *
 import decimal
 import time
+from retry.api import retry_call
+import re
 
 def _cw_exchange_sum(self, *args):
     return sum([arg for arg in args if isinstance(arg, (float, int, decimal.Decimal))])
 Exchange.sum = _cw_exchange_sum
 
 class poloniexE(poloniex):
+    RETRIABLE_CALLS = [
+            re.compile(r"^return"),
+            re.compile(r"^cancel"),
+            re.compile(r"^closeMarginPosition$"),
+            re.compile(r"^getMarginPosition$"),
+            ]
+
+    def request(self, path, api='public', method='GET', params={}, headers=None, body=None):
+        """
+        Wrapped to allow retry of non-posting requests"
+        """
+
+        origin_request = super(poloniexE, self).request
+        kwargs = {
+                "api": api,
+                "method": method,
+                "params": params,
+                "headers": headers,
+                "body": body
+                }
+
+        retriable = any(re.match(call, path) for call in self.RETRIABLE_CALLS)
+        if api == "public" or method == "GET" or retriable:
+            return retry_call(origin_request, fargs=[path], fkwargs=kwargs,
+                    tries=10, delay=1, exceptions=(RequestTimeout,))
+        else:
+            return origin_request(path, **kwargs)
+
+    def __init__(self, *args, **kwargs):
+        super(poloniexE, self).__init__(*args, **kwargs)
+
+        # For requests logging
+        self.session.origin_request = self.session.request
+        self.session._parent = self
+
+        def request_wrap(self, *args, **kwargs):
+            r = self.origin_request(*args, **kwargs)
+            self._parent._market.report.log_http_request(args[0],
+                    args[1], kwargs["data"], kwargs["headers"], r)
+            return r
+        self.session.request = request_wrap.__get__(self.session,
+                self.session.__class__)
+
     @staticmethod
     def nanoseconds():
         return int(time.time() * 1000000000)
index 496ec45843319f2145f955e811ea49e1842e4761..055967cd355a0c05c37bdd64e2e6c4f494949073 100644 (file)
--- a/market.py
+++ b/market.py
@@ -33,18 +33,6 @@ class Market:
 
         ccxt_instance = ccxt.poloniexE(config)
 
-        # For requests logging
-        ccxt_instance.session.origin_request = ccxt_instance.session.request
-        ccxt_instance.session._parent = ccxt_instance
-
-        def request_wrap(self, *args, **kwargs):
-            r = self.origin_request(*args, **kwargs)
-            self._parent._market.report.log_http_request(args[0],
-                    args[1], kwargs["data"], kwargs["headers"], r)
-            return r
-        ccxt_instance.session.request = request_wrap.__get__(ccxt_instance.session,
-                ccxt_instance.session.__class__)
-
         return cls(ccxt_instance, args, **kwargs)
 
     def store_report(self):
diff --git a/test.py b/test.py
index 637a3054c4d8449b3fb0963e4180e62b0b837567..18616c1c848620d93a19d61c72b4edcd7c773ea1 100644 (file)
--- a/test.py
+++ b/test.py
@@ -70,6 +70,18 @@ class poloniexETest(unittest.TestCase):
         self.wm.stop()
         super(poloniexETest, self).tearDown()
 
+    def test__init(self):
+        with mock.patch("market.ccxt.poloniexE.session") as session:
+            session.request.return_value = "response"
+            ccxt = market.ccxt.poloniexE()
+            ccxt._market = mock.Mock
+            ccxt._market.report = mock.Mock()
+
+            ccxt.session.request("GET", "URL", data="data",
+                    headers="headers")
+            ccxt._market.report.log_http_request.assert_called_with('GET', 'URL', 'data',
+                    'headers', 'response')
+
     def test_nanoseconds(self):
         with mock.patch.object(market.ccxt.time, "time") as time:
             time.return_value = 123456.7890123456
@@ -80,6 +92,58 @@ class poloniexETest(unittest.TestCase):
             time.return_value = 123456.7890123456
             self.assertEqual(123456789012345, self.s.nonce())
 
+    def test_request(self):
+        with mock.patch.object(market.ccxt.poloniex, "request") as request,\
+                mock.patch("market.ccxt.retry_call") as retry_call:
+            with self.subTest(wrapped=True):
+                with self.subTest(desc="public"):
+                    self.s.request("foo")
+                    retry_call.assert_called_with(request,
+                            delay=1, tries=10, fargs=["foo"],
+                            fkwargs={'api': 'public', 'method': 'GET', 'params': {}, 'headers': None, 'body': None},
+                            exceptions=(market.ccxt.RequestTimeout,))
+                    request.assert_not_called()
+
+                with self.subTest(desc="private GET"):
+                    self.s.request("foo", api="private")
+                    retry_call.assert_called_with(request,
+                            delay=1, tries=10, fargs=["foo"],
+                            fkwargs={'api': 'private', 'method': 'GET', 'params': {}, 'headers': None, 'body': None},
+                            exceptions=(market.ccxt.RequestTimeout,))
+                    request.assert_not_called()
+
+                with self.subTest(desc="private POST regexp"):
+                    self.s.request("returnFoo", api="private", method="POST")
+                    retry_call.assert_called_with(request,
+                            delay=1, tries=10, fargs=["returnFoo"],
+                            fkwargs={'api': 'private', 'method': 'POST', 'params': {}, 'headers': None, 'body': None},
+                            exceptions=(market.ccxt.RequestTimeout,))
+                    request.assert_not_called()
+
+                with self.subTest(desc="private POST non-regexp"):
+                    self.s.request("getMarginPosition", api="private", method="POST")
+                    retry_call.assert_called_with(request,
+                            delay=1, tries=10, fargs=["getMarginPosition"],
+                            fkwargs={'api': 'private', 'method': 'POST', 'params': {}, 'headers': None, 'body': None},
+                            exceptions=(market.ccxt.RequestTimeout,))
+                    request.assert_not_called()
+            retry_call.reset_mock()
+            request.reset_mock()
+            with self.subTest(wrapped=False):
+                with self.subTest(desc="private POST non-matching regexp"):
+                    self.s.request("marginBuy", api="private", method="POST")
+                    request.assert_called_with("marginBuy",
+                            api="private", method="POST", params={},
+                            headers=None, body=None)
+                    retry_call.assert_not_called()
+
+                with self.subTest(desc="private POST non-matching non-regexp"):
+                    self.s.request("closeMarginPositionOther", api="private", method="POST")
+                    request.assert_called_with("closeMarginPositionOther",
+                            api="private", method="POST", params={},
+                            headers=None, body=None)
+                    retry_call.assert_not_called()
+
     def test_order_precision(self):
         self.assertEqual(8, self.s.order_precision("FOO"))
 
@@ -1125,17 +1189,11 @@ class MarketTest(WebMockTestCase):
     def test_from_config(self, ccxt):
         with mock.patch("market.ReportStore"):
             ccxt.poloniexE.return_value = self.ccxt
-            self.ccxt.session.request.return_value = "response"
 
             m = market.Market.from_config({"key": "key", "secred": "secret"}, self.market_args())
 
             self.assertEqual(self.ccxt, m.ccxt)
 
-            self.ccxt.session.request("GET", "URL", data="data",
-                    headers="headers")
-            m.report.log_http_request.assert_called_with('GET', 'URL', 'data',
-                    'headers', 'response')
-
         m = market.Market.from_config({"key": "key", "secred": "secret"}, self.market_args(debug=True))
         self.assertEqual(True, m.debug)