]> git.immae.eu Git - perso/Immae/Projets/Cryptomonnaies/Cryptoportfolio/Trader.git/blobdiff - test.py
Add parallelization
[perso/Immae/Projets/Cryptomonnaies/Cryptoportfolio/Trader.git] / test.py
diff --git a/test.py b/test.py
index c0e6a8aa68aa30c7a06dc42f39164de1d7a6ca63..f61e739d9a3e664cd30bfca1d185d13846c46cee 100644 (file)
--- a/test.py
+++ b/test.py
@@ -7,6 +7,7 @@ from unittest import mock
 import requests
 import requests_mock
 from io import StringIO
+import threading
 import portfolio, market, main, store
 
 limits = ["acceptance", "unit"]
@@ -33,10 +34,14 @@ class WebMockTestCase(unittest.TestCase):
 
         self.patchers = [
                 mock.patch.multiple(market.Portfolio,
-                    last_date=None,
-                    data=None,
-                    liquidities={},
-                    report=mock.Mock()),
+                    data=store.LockedVar(None),
+                    liquidities=store.LockedVar({}),
+                    last_date=store.LockedVar(None),
+                    report=mock.Mock(),
+                    worker=None,
+                    worker_notify=None,
+                    worker_started=False,
+                    callback=None),
                 mock.patch.multiple(portfolio.Computation,
                     computations=portfolio.Computation.computations),
                 ]
@@ -441,6 +446,99 @@ class poloniexETest(unittest.TestCase):
 
             create_order.assert_called_once_with("symbol", "type", "side", "amount", price="price", params="params")
 
+@unittest.skipUnless("unit" in limits, "Unit skipped")
+class NoopLockTest(unittest.TestCase):
+    def test_with(self):
+        noop_lock = store.NoopLock()
+        with noop_lock:
+            self.assertTrue(True)
+
+@unittest.skipUnless("unit" in limits, "Unit skipped")
+class LockedVar(unittest.TestCase):
+
+    def test_values(self):
+        locked_var = store.LockedVar("Foo")
+        self.assertIsInstance(locked_var.lock, store.NoopLock)
+        self.assertEqual("Foo", locked_var.val)
+
+    def test_get(self):
+        with self.subTest(desc="Normal case"):
+            locked_var = store.LockedVar("Foo")
+            self.assertEqual("Foo", locked_var.get())
+        with self.subTest(desc="Dict"):
+            locked_var = store.LockedVar({"foo": "bar"})
+            self.assertEqual({"foo": "bar"}, locked_var.get())
+            self.assertEqual("bar", locked_var.get("foo"))
+            self.assertIsNone(locked_var.get("other"))
+
+    def test_set(self):
+        locked_var = store.LockedVar("Foo")
+        locked_var.set("Bar")
+        self.assertEqual("Bar", locked_var.get())
+
+    def test__getattr(self):
+        dummy = type('Dummy', (object,), {})()
+        dummy.attribute = "Hey"
+
+        locked_var = store.LockedVar(dummy)
+        self.assertEqual("Hey", locked_var.attribute)
+        with self.assertRaises(AttributeError):
+            locked_var.other
+
+    def test_start_lock(self):
+        locked_var = store.LockedVar("Foo")
+        locked_var.start_lock()
+        self.assertEqual("lock", locked_var.lock.__class__.__name__)
+
+        thread1 = threading.Thread(target=locked_var.set, args=["Bar1"])
+        thread2 = threading.Thread(target=locked_var.set, args=["Bar2"])
+        thread3 = threading.Thread(target=locked_var.set, args=["Bar3"])
+
+        with locked_var.lock:
+            thread1.start()
+            thread2.start()
+            thread3.start()
+
+            self.assertEqual("Foo", locked_var.val)
+        thread1.join()
+        thread2.join()
+        thread3.join()
+        self.assertEqual("Bar", locked_var.get()[0:3])
+
+    def test_wait_for_notification(self):
+        with self.assertRaises(RuntimeError):
+            store.Portfolio.wait_for_notification()
+
+        with mock.patch.object(store.Portfolio, "get_cryptoportfolio") as get,\
+                mock.patch.object(store.Portfolio, "report") as report,\
+                mock.patch.object(store.time, "sleep") as sleep:
+            store.Portfolio.start_worker(poll=3)
+
+            store.Portfolio.worker_notify.set()
+
+            store.Portfolio.callback.wait()
+
+            report.print_log.assert_called_once_with("Fetching cryptoportfolio")
+            get.assert_called_once_with(refetch=True)
+            sleep.assert_called_once_with(3)
+            self.assertFalse(store.Portfolio.worker_notify.is_set())
+            self.assertTrue(store.Portfolio.worker.is_alive())
+
+            store.Portfolio.callback.clear()
+            store.Portfolio.worker_started = False
+            store.Portfolio.worker_notify.set()
+            store.Portfolio.callback.wait()
+
+            self.assertFalse(store.Portfolio.worker.is_alive())
+
+    def test_notify_and_wait(self):
+        with mock.patch.object(store.Portfolio, "callback") as callback,\
+                mock.patch.object(store.Portfolio, "worker_notify") as worker_notify:
+            store.Portfolio.notify_and_wait()
+            callback.clear.assert_called_once_with()
+            worker_notify.set.assert_called_once_with()
+            callback.wait.assert_called_once_with()
+
 @unittest.skipUnless("unit" in limits, "Unit skipped")
 class PortfolioTest(WebMockTestCase):
     def setUp(self):
@@ -453,86 +551,131 @@ class PortfolioTest(WebMockTestCase):
 
     @mock.patch.object(market.Portfolio, "parse_cryptoportfolio")
     def test_get_cryptoportfolio(self, parse_cryptoportfolio):
-        self.wm.get(market.Portfolio.URL, [
-            {"text":'{ "foo": "bar" }', "status_code": 200},
-            {"text": "System Error", "status_code": 500},
-            {"exc": requests.exceptions.ConnectTimeout},
-            ])
-        market.Portfolio.get_cryptoportfolio()
-        self.assertIn("foo", market.Portfolio.data)
-        self.assertEqual("bar", market.Portfolio.data["foo"])
-        self.assertTrue(self.wm.called)
-        self.assertEqual(1, self.wm.call_count)
-        market.Portfolio.report.log_error.assert_not_called()
-        market.Portfolio.report.log_http_request.assert_called_once()
-        parse_cryptoportfolio.assert_called_once_with()
-        market.Portfolio.report.log_http_request.reset_mock()
-        parse_cryptoportfolio.reset_mock()
-        market.Portfolio.data = None
-
-        market.Portfolio.get_cryptoportfolio()
-        self.assertIsNone(market.Portfolio.data)
-        self.assertEqual(2, self.wm.call_count)
-        parse_cryptoportfolio.assert_not_called()
-        market.Portfolio.report.log_error.assert_not_called()
-        market.Portfolio.report.log_http_request.assert_called_once()
-        market.Portfolio.report.log_http_request.reset_mock()
-        parse_cryptoportfolio.reset_mock()
-
-        market.Portfolio.data = "Foo"
-        market.Portfolio.get_cryptoportfolio()
-        self.assertEqual(2, self.wm.call_count)
-        parse_cryptoportfolio.assert_not_called()
-
-        market.Portfolio.get_cryptoportfolio(refetch=True)
-        self.assertEqual("Foo", market.Portfolio.data)
-        self.assertEqual(3, self.wm.call_count)
-        market.Portfolio.report.log_error.assert_called_once_with("get_cryptoportfolio",
-                exception=mock.ANY)
-        market.Portfolio.report.log_http_request.assert_not_called()
+        with self.subTest(parallel=False):
+            self.wm.get(market.Portfolio.URL, [
+                {"text":'{ "foo": "bar" }', "status_code": 200},
+                {"text": "System Error", "status_code": 500},
+                {"exc": requests.exceptions.ConnectTimeout},
+                ])
+            market.Portfolio.get_cryptoportfolio()
+            self.assertIn("foo", market.Portfolio.data.get())
+            self.assertEqual("bar", market.Portfolio.data.get()["foo"])
+            self.assertTrue(self.wm.called)
+            self.assertEqual(1, self.wm.call_count)
+            market.Portfolio.report.log_error.assert_not_called()
+            market.Portfolio.report.log_http_request.assert_called_once()
+            parse_cryptoportfolio.assert_called_once_with()
+            market.Portfolio.report.log_http_request.reset_mock()
+            parse_cryptoportfolio.reset_mock()
+            market.Portfolio.data = store.LockedVar(None)
+
+            market.Portfolio.get_cryptoportfolio()
+            self.assertIsNone(market.Portfolio.data.get())
+            self.assertEqual(2, self.wm.call_count)
+            parse_cryptoportfolio.assert_not_called()
+            market.Portfolio.report.log_error.assert_not_called()
+            market.Portfolio.report.log_http_request.assert_called_once()
+            market.Portfolio.report.log_http_request.reset_mock()
+            parse_cryptoportfolio.reset_mock()
+
+            market.Portfolio.data = store.LockedVar("Foo")
+            market.Portfolio.get_cryptoportfolio()
+            self.assertEqual(2, self.wm.call_count)
+            parse_cryptoportfolio.assert_not_called()
+
+            market.Portfolio.get_cryptoportfolio(refetch=True)
+            self.assertEqual("Foo", market.Portfolio.data.get())
+            self.assertEqual(3, self.wm.call_count)
+            market.Portfolio.report.log_error.assert_called_once_with("get_cryptoportfolio",
+                    exception=mock.ANY)
+            market.Portfolio.report.log_http_request.assert_not_called()
+        with self.subTest(parallel=True):
+            with mock.patch.object(market.Portfolio, "is_worker_thread") as is_worker,\
+                    mock.patch.object(market.Portfolio, "notify_and_wait") as notify:
+                with self.subTest(worker=True):
+                    market.Portfolio.data = store.LockedVar(None)
+                    market.Portfolio.worker = mock.Mock()
+                    is_worker.return_value = True
+                    self.wm.get(market.Portfolio.URL, [
+                        {"text":'{ "foo": "bar" }', "status_code": 200},
+                        ])
+                    market.Portfolio.get_cryptoportfolio()
+                    self.assertIn("foo", market.Portfolio.data.get())
+                parse_cryptoportfolio.reset_mock()
+                with self.subTest(worker=False):
+                    market.Portfolio.data = store.LockedVar(None)
+                    market.Portfolio.worker = mock.Mock()
+                    is_worker.return_value = False
+                    market.Portfolio.get_cryptoportfolio()
+                    notify.assert_called_once_with()
+                    parse_cryptoportfolio.assert_not_called()
 
     def test_parse_cryptoportfolio(self):
-        market.Portfolio.data = store.json.loads(self.json_response, parse_int=D,
-                parse_float=D)
-        market.Portfolio.parse_cryptoportfolio()
-
-        self.assertListEqual(
-                ["medium", "high"],
-                list(market.Portfolio.liquidities.keys()))
-
-        liquidities = market.Portfolio.liquidities
-        self.assertEqual(10, len(liquidities["medium"].keys()))
-        self.assertEqual(10, len(liquidities["high"].keys()))
-
-        expected = {
-                'BTC':  (D("0.2857"), "long"),
-                'DGB':  (D("0.1015"), "long"),
-                'DOGE': (D("0.1805"), "long"),
-                'SC':   (D("0.0623"), "long"),
-                'ZEC':  (D("0.3701"), "long"),
-                }
-        date = portfolio.datetime(2018, 1, 8)
-        self.assertDictEqual(expected, liquidities["high"][date])
-
-        expected = {
-                'BTC':  (D("1.1102e-16"), "long"),
-                'ETC':  (D("0.1"), "long"),
-                'FCT':  (D("0.1"), "long"),
-                'GAS':  (D("0.1"), "long"),
-                'NAV':  (D("0.1"), "long"),
-                'OMG':  (D("0.1"), "long"),
-                'OMNI': (D("0.1"), "long"),
-                'PPC':  (D("0.1"), "long"),
-                'RIC':  (D("0.1"), "long"),
-                'VIA':  (D("0.1"), "long"),
-                'XCP':  (D("0.1"), "long"),
-                }
-        self.assertDictEqual(expected, liquidities["medium"][date])
-        self.assertEqual(portfolio.datetime(2018, 1, 15), market.Portfolio.last_date)
+        with self.subTest(description="Normal case"):
+            market.Portfolio.data = store.LockedVar(store.json.loads(
+                self.json_response, parse_int=D, parse_float=D))
+            market.Portfolio.parse_cryptoportfolio()
+
+            self.assertListEqual(
+                    ["medium", "high"],
+                    list(market.Portfolio.liquidities.get().keys()))
+
+            liquidities = market.Portfolio.liquidities.get()
+            self.assertEqual(10, len(liquidities["medium"].keys()))
+            self.assertEqual(10, len(liquidities["high"].keys()))
+
+            expected = {
+                    'BTC':  (D("0.2857"), "long"),
+                    'DGB':  (D("0.1015"), "long"),
+                    'DOGE': (D("0.1805"), "long"),
+                    'SC':   (D("0.0623"), "long"),
+                    'ZEC':  (D("0.3701"), "long"),
+                    }
+            date = portfolio.datetime(2018, 1, 8)
+            self.assertDictEqual(expected, liquidities["high"][date])
+
+            expected = {
+                    'BTC':  (D("1.1102e-16"), "long"),
+                    'ETC':  (D("0.1"), "long"),
+                    'FCT':  (D("0.1"), "long"),
+                    'GAS':  (D("0.1"), "long"),
+                    'NAV':  (D("0.1"), "long"),
+                    'OMG':  (D("0.1"), "long"),
+                    'OMNI': (D("0.1"), "long"),
+                    'PPC':  (D("0.1"), "long"),
+                    'RIC':  (D("0.1"), "long"),
+                    'VIA':  (D("0.1"), "long"),
+                    'XCP':  (D("0.1"), "long"),
+                    }
+            self.assertDictEqual(expected, liquidities["medium"][date])
+            self.assertEqual(portfolio.datetime(2018, 1, 15), market.Portfolio.last_date.get())
+
+        with self.subTest(description="Missing weight"):
+            data = store.json.loads(self.json_response, parse_int=D, parse_float=D)
+            del(data["portfolio_2"]["weights"])
+            market.Portfolio.data = store.LockedVar(data)
+
+            market.Portfolio.parse_cryptoportfolio()
+            self.assertListEqual(
+                    ["medium", "high"],
+                    list(market.Portfolio.liquidities.get().keys()))
+            self.assertEqual({}, market.Portfolio.liquidities.get("medium"))
+
+        with self.subTest(description="All missing weights"):
+            data = store.json.loads(self.json_response, parse_int=D, parse_float=D)
+            del(data["portfolio_1"]["weights"])
+            del(data["portfolio_2"]["weights"])
+            market.Portfolio.data = store.LockedVar(data)
+
+            market.Portfolio.parse_cryptoportfolio()
+            self.assertEqual({}, market.Portfolio.liquidities.get("medium"))
+            self.assertEqual({}, market.Portfolio.liquidities.get("high"))
+            self.assertEqual(datetime.datetime(1,1,1), market.Portfolio.last_date.get())
+
 
     @mock.patch.object(market.Portfolio, "get_cryptoportfolio")
     def test_repartition(self, get_cryptoportfolio):
-        market.Portfolio.liquidities = {
+        market.Portfolio.liquidities = store.LockedVar({
                 "medium": {
                     "2018-03-01": "medium_2018-03-01",
                     "2018-03-08": "medium_2018-03-08",
@@ -541,8 +684,8 @@ class PortfolioTest(WebMockTestCase):
                     "2018-03-01": "high_2018-03-01",
                     "2018-03-08": "high_2018-03-08",
                     }
-                }
-        market.Portfolio.last_date = "2018-03-08"
+                })
+        market.Portfolio.last_date = store.LockedVar("2018-03-08")
 
         self.assertEqual("medium_2018-03-08", market.Portfolio.repartition())
         get_cryptoportfolio.assert_called_once_with()
@@ -559,9 +702,9 @@ class PortfolioTest(WebMockTestCase):
             else:
                 self.assertFalse(refetch)
             self.call_count += 1
-            market.Portfolio.last_date = store.datetime.now()\
+            market.Portfolio.last_date = store.LockedVar(store.datetime.now()\
                 - store.timedelta(10)\
-                + store.timedelta(self.call_count)
+                + store.timedelta(self.call_count))
         get_cryptoportfolio.side_effect = _get
 
         market.Portfolio.wait_for_recent()
@@ -572,7 +715,7 @@ class PortfolioTest(WebMockTestCase):
 
         sleep.reset_mock()
         get_cryptoportfolio.reset_mock()
-        market.Portfolio.last_date = None
+        market.Portfolio.last_date = store.LockedVar(None)
         self.call_count = 0
         market.Portfolio.wait_for_recent(delta=15)
         sleep.assert_not_called()
@@ -580,13 +723,45 @@ class PortfolioTest(WebMockTestCase):
 
         sleep.reset_mock()
         get_cryptoportfolio.reset_mock()
-        market.Portfolio.last_date = None
+        market.Portfolio.last_date = store.LockedVar(None)
         self.call_count = 0
         market.Portfolio.wait_for_recent(delta=1)
         sleep.assert_called_with(30)
         self.assertEqual(9, sleep.call_count)
         self.assertEqual(10, get_cryptoportfolio.call_count)
 
+    def test_is_worker_thread(self):
+        with self.subTest(worker=None):
+            self.assertFalse(store.Portfolio.is_worker_thread())
+
+        with self.subTest(worker="not self"),\
+                mock.patch("threading.current_thread") as current_thread:
+            current = mock.Mock()
+            current_thread.return_value = current
+            store.Portfolio.worker = mock.Mock()
+            self.assertFalse(store.Portfolio.is_worker_thread())
+
+        with self.subTest(worker="self"),\
+                mock.patch("threading.current_thread") as current_thread:
+            current = mock.Mock()
+            current_thread.return_value = current
+            store.Portfolio.worker = current
+            self.assertTrue(store.Portfolio.is_worker_thread())
+
+    def test_start_worker(self):
+        with mock.patch.object(store.Portfolio, "wait_for_notification") as notification:
+            store.Portfolio.start_worker()
+            notification.assert_called_once_with(poll=30)
+
+            self.assertEqual("lock", store.Portfolio.last_date.lock.__class__.__name__)
+            self.assertEqual("lock", store.Portfolio.liquidities.lock.__class__.__name__)
+            store.Portfolio.report.start_lock.assert_called_once_with()
+
+            self.assertIsNotNone(store.Portfolio.worker)
+            self.assertIsNotNone(store.Portfolio.worker_notify)
+            self.assertIsNotNone(store.Portfolio.callback)
+            self.assertTrue(store.Portfolio.worker_started)
+
 @unittest.skipUnless("unit" in limits, "Unit skipped")
 class AmountTest(WebMockTestCase):
     def test_values(self):
@@ -3362,31 +3537,64 @@ class MainTest(WebMockTestCase):
                 self.assertEqual("Exception: boo\n", stdout_mock.getvalue())
 
     def test_main(self):
-        with mock.patch("main.parse_args") as parse_args,\
-                mock.patch("main.parse_config") as parse_config,\
-                mock.patch("main.fetch_markets") as fetch_markets,\
-                mock.patch("main.process") as process:
+        with self.subTest(parallel=False):
+            with mock.patch("main.parse_args") as parse_args,\
+                    mock.patch("main.parse_config") as parse_config,\
+                    mock.patch("main.fetch_markets") as fetch_markets,\
+                    mock.patch("main.process") as process:
 
-            args_mock = mock.Mock()
-            args_mock.config = "config"
-            args_mock.user = "user"
-            parse_args.return_value = args_mock
+                args_mock = mock.Mock()
+                args_mock.parallel = False
+                args_mock.config = "config"
+                args_mock.user = "user"
+                parse_args.return_value = args_mock
 
-            parse_config.return_value = ["pg_config", "report_path"]
+                parse_config.return_value = ["pg_config", "report_path"]
 
-            fetch_markets.return_value = [["config1", 1], ["config2", 2]]
+                fetch_markets.return_value = [["config1", 1], ["config2", 2]]
 
-            main.main(["Foo", "Bar"])
+                main.main(["Foo", "Bar"])
 
-            parse_args.assert_called_with(["Foo", "Bar"])
-            parse_config.assert_called_with("config")
-            fetch_markets.assert_called_with("pg_config", "user")
+                parse_args.assert_called_with(["Foo", "Bar"])
+                parse_config.assert_called_with("config")
+                fetch_markets.assert_called_with("pg_config", "user")
 
-            self.assertEqual(2, process.call_count)
-            process.assert_has_calls([
-                mock.call("config1", 1, "report_path", args_mock),
-                mock.call("config2", 2, "report_path", args_mock),
-                ])
+                self.assertEqual(2, process.call_count)
+                process.assert_has_calls([
+                    mock.call("config1", 1, "report_path", args_mock),
+                    mock.call("config2", 2, "report_path", args_mock),
+                    ])
+        with self.subTest(parallel=True):
+            with mock.patch("main.parse_args") as parse_args,\
+                    mock.patch("main.parse_config") as parse_config,\
+                    mock.patch("main.fetch_markets") as fetch_markets,\
+                    mock.patch("main.process") as process,\
+                    mock.patch("store.Portfolio.start_worker") as start:
+
+                args_mock = mock.Mock()
+                args_mock.parallel = True
+                args_mock.config = "config"
+                args_mock.user = "user"
+                parse_args.return_value = args_mock
+
+                parse_config.return_value = ["pg_config", "report_path"]
+
+                fetch_markets.return_value = [["config1", 1], ["config2", 2]]
+
+                main.main(["Foo", "Bar"])
+
+                parse_args.assert_called_with(["Foo", "Bar"])
+                parse_config.assert_called_with("config")
+                fetch_markets.assert_called_with("pg_config", "user")
+
+                start.assert_called_once_with()
+                self.assertEqual(2, process.call_count)
+                process.assert_has_calls([
+                    mock.call.__bool__(),
+                    mock.call("config1", 1, "report_path", args_mock),
+                    mock.call.__bool__(),
+                    mock.call("config2", 2, "report_path", args_mock),
+                    ])
 
     @mock.patch.object(main.sys, "exit")
     @mock.patch("main.configparser")
@@ -3551,7 +3759,7 @@ class ProcessorTest(WebMockTestCase):
 
         method, arguments = processor.method_arguments("wait_for_recent")
         self.assertEqual(market.Portfolio.wait_for_recent, method)
-        self.assertEqual(["delta"], arguments)
+        self.assertEqual(["delta", "poll"], arguments)
 
         method, arguments = processor.method_arguments("prepare_trades")
         self.assertEqual(m.prepare_trades, method)