]> git.immae.eu Git - perso/Immae/Projets/Cryptomonnaies/Cryptoportfolio/Trader.git/commitdiff
Merge branch 'refactor_db' into dev
authorIsmaël Bouya <ismael.bouya@normalesup.org>
Wed, 2 May 2018 22:23:26 +0000 (00:23 +0200)
committerIsmaël Bouya <ismael.bouya@normalesup.org>
Wed, 2 May 2018 22:23:26 +0000 (00:23 +0200)
dbs.py [new file with mode: 0644]
main.py
market.py
store.py
test.py
tests/helper.py
tests/test_dbs.py [new file with mode: 0644]
tests/test_main.py
tests/test_market.py
tests/test_store.py

diff --git a/dbs.py b/dbs.py
new file mode 100644 (file)
index 0000000..b32afa3
--- /dev/null
+++ b/dbs.py
@@ -0,0 +1,55 @@
+import psycopg2
+import redis as _redis
+
+redis = None
+psql = None
+
+def redis_connected():
+    global redis
+    if redis is None:
+        return False
+    else:
+        try:
+            return redis.ping()
+        except Exception:
+            return False
+
+def psql_connected():
+    global psql
+    return psql is not None and psql.closed == 0
+
+def connect_redis(args):
+    global redis
+
+    redis_config = {
+            "host": args.redis_host,
+            "port": args.redis_port,
+            "db": args.redis_database,
+            }
+    if redis_config["host"].startswith("/"):
+        redis_config["unix_socket_path"] = redis_config.pop("host")
+        del(redis_config["port"])
+    del(args.redis_host)
+    del(args.redis_port)
+    del(args.redis_database)
+
+    if redis is None:
+        redis = _redis.Redis(**redis_config)
+
+def connect_psql(args):
+    global psql
+    pg_config = {
+            "host": args.db_host,
+            "port": args.db_port,
+            "user": args.db_user,
+            "password": args.db_password,
+            "database": args.db_database,
+            }
+    del(args.db_host)
+    del(args.db_port)
+    del(args.db_user)
+    del(args.db_password)
+    del(args.db_database)
+
+    if psql is None:
+        psql = psycopg2.connect(**pg_config)
diff --git a/main.py b/main.py
index a4612075e53cae929e4b5906b61d8a630345ac85..ee25182f24d54bd3326bd976d579c07c0ee896df 100644 (file)
--- a/main.py
+++ b/main.py
@@ -1,5 +1,5 @@
 import configargparse
-import psycopg2
+import dbs
 import os
 import sys
 
@@ -63,15 +63,12 @@ def get_user_market(config_path, user_id, debug=False):
     if debug:
         args.append("--debug")
     args = parse_args(args)
-    pg_config, redis_config = parse_config(args)
-    market_id, market_config, user_id = list(fetch_markets(pg_config, str(user_id)))[0]
-    return market.Market.from_config(market_config, args,
-            pg_config=pg_config, market_id=market_id,
-            user_id=user_id)
+    parse_config(args)
+    market_id, market_config, user_id = list(fetch_markets(str(user_id)))[0]
+    return market.Market.from_config(market_config, args, user_id=user_id)
 
-def fetch_markets(pg_config, user):
-    connection = psycopg2.connect(**pg_config)
-    cursor = connection.cursor()
+def fetch_markets(user):
+    cursor = dbs.psql.cursor()
 
     if user is None:
         cursor.execute("SELECT id,config,user_id FROM market_configs")
@@ -82,30 +79,11 @@ def fetch_markets(pg_config, user):
         yield row
 
 def parse_config(args):
-    pg_config = {
-            "host": args.db_host,
-            "port": args.db_port,
-            "user": args.db_user,
-            "password": args.db_password,
-            "database": args.db_database,
-            }
-    del(args.db_host)
-    del(args.db_port)
-    del(args.db_user)
-    del(args.db_password)
-    del(args.db_database)
-
-    redis_config = {
-            "host": args.redis_host,
-            "port": args.redis_port,
-            "db": args.redis_database,
-            }
-    if redis_config["host"].startswith("/"):
-        redis_config["unix_socket_path"] = redis_config.pop("host")
-        del(redis_config["port"])
-    del(args.redis_host)
-    del(args.redis_port)
-    del(args.redis_database)
+    if args.db_host is not None:
+        dbs.connect_psql(args)
+
+    if args.redis_host is not None:
+        dbs.connect_redis(args)
 
     report_path = args.report_path
 
@@ -113,8 +91,6 @@ def parse_config(args):
             os.path.exists(report_path):
         os.makedirs(report_path)
 
-    return pg_config, redis_config
-
 def parse_args(argv):
     parser = configargparse.ArgumentParser(
             description="Run the trade bot.")
@@ -176,11 +152,10 @@ def parse_args(argv):
         parsed.action = ["sell_all"]
     return parsed
 
-def process(market_config, market_id, user_id, args, pg_config, redis_config):
+def process(market_config, market_id, user_id, args):
     try:
         market.Market\
                 .from_config(market_config, args, market_id=market_id,
-                        pg_config=pg_config, redis_config=redis_config,
                         user_id=user_id)\
                 .process(args.action, before=args.before, after=args.after)
     except Exception as e:
@@ -189,7 +164,7 @@ def process(market_config, market_id, user_id, args, pg_config, redis_config):
 def main(argv):
     args = parse_args(argv)
 
-    pg_config, redis_config = parse_config(args)
+    parse_config(args)
 
     market.Portfolio.report.set_verbose(not args.quiet)
 
@@ -205,8 +180,8 @@ def main(argv):
     else:
         process_ = process
 
-    for market_id, market_config, user_id in fetch_markets(pg_config, args.user):
-        process_(market_config, market_id, user_id, args, pg_config, redis_config)
+    for market_id, market_config, user_id in fetch_markets(args.user):
+        process_(market_config, market_id, user_id, args)
 
     if args.parallel:
         for thread in threads:
index fc6f9f667fb320e469f7fa103d398d2c40d184b2..5876071875750df01a1d658698216b2c30670bb0 100644 (file)
--- a/market.py
+++ b/market.py
@@ -1,8 +1,7 @@
 from ccxt import ExchangeError, NotSupported, RequestTimeout, InvalidNonce
 import ccxt_wrapper as ccxt
 import time
-import psycopg2
-import redis
+import dbs
 from store import *
 from cachetools.func import ttl_cache
 from datetime import datetime
@@ -27,7 +26,7 @@ class Market:
         self.balances = BalanceStore(self)
         self.processor = Processor(self)
 
-        for key in ["user_id", "market_id", "pg_config", "redis_config"]:
+        for key in ["user_id", "market_id"]:
             setattr(self, key, kwargs.get(key, None))
 
         self.report.log_market(self.args)
@@ -45,9 +44,9 @@ class Market:
         date = datetime.datetime.now()
         if self.args.report_path is not None:
             self.store_file_report(date)
-        if self.pg_config is not None and self.args.report_db:
+        if dbs.psql_connected() and self.args.report_db:
             self.store_database_report(date)
-        if self.redis_config is not None and self.args.report_redis:
+        if dbs.redis_connected() and self.args.report_redis:
             self.store_redis_report(date)
 
     def store_file_report(self, date):
@@ -64,29 +63,26 @@ class Market:
         try:
             report_query = 'INSERT INTO reports("date", "market_config_id", "debug") VALUES (%s, %s, %s) RETURNING id;'
             line_query = 'INSERT INTO report_lines("date", "report_id", "type", "payload") VALUES (%s, %s, %s, %s);'
-            connection = psycopg2.connect(**self.pg_config)
-            cursor = connection.cursor()
+            cursor = dbs.psql.cursor()
             cursor.execute(report_query, (date, self.market_id, self.debug))
             report_id = cursor.fetchone()[0]
             for date, type_, payload in self.report.to_json_array():
                 cursor.execute(line_query, (date, report_id, type_, payload))
 
-            connection.commit()
+            dbs.psql.commit()
             cursor.close()
-            connection.close()
         except Exception as e:
             print("impossible to store report to database: {}; {}".format(e.__class__.__name__, e))
 
     def store_redis_report(self, date):
         try:
-            conn = redis.Redis(**self.redis_config)
             for type_, log in self.report.to_json_redis():
                 key = "/cryptoportfolio/{}/{}/{}".format(self.market_id, date.isoformat(), type_)
-                conn.set(key, log, ex=31*24*60*60)
+                dbs.redis.set(key, log, ex=31*24*60*60)
                 key = "/cryptoportfolio/{}/latest/{}".format(self.market_id, type_)
-                conn.set(key, log)
+                dbs.redis.set(key, log)
             key = "/cryptoportfolio/{}/latest/date".format(self.market_id)
-            conn.set(key, date.isoformat())
+            dbs.redis.set(key, date.isoformat())
         except Exception as e:
             print("impossible to store report to redis: {}; {}".format(e.__class__.__name__, e))
 
@@ -259,6 +255,7 @@ class Processor:
                     "name": "print_balances",
                     "number": 1,
                     "fetch_balances": ["begin"],
+                    "fetch_balances_args": { "add_portfolio": True },
                     "print_tickers": { "base_currency": "BTC" },
                     }
                 ],
@@ -390,15 +387,19 @@ class Processor:
     def process_step(self, scenario_name, step, kwargs):
         process_name = "process_{}__{}_{}".format(scenario_name, step["number"], step["name"])
         self.market.report.log_stage("{}_begin".format(process_name))
+
+        fetch_args = step.get("fetch_balances_args", {})
         if "begin" in step.get("fetch_balances", []):
-            self.market.balances.fetch_balances(tag="{}_begin".format(process_name), log_tickers=True)
+            self.market.balances.fetch_balances(tag="{}_begin".format(process_name),
+                    log_tickers=True, **fetch_args)
 
         for action in self.ordered_actions:
             if action in step:
                 self.run_action(action, step[action], kwargs)
 
         if "end" in step.get("fetch_balances", []):
-            self.market.balances.fetch_balances(tag="{}_end".format(process_name), log_tickers=True)
+            self.market.balances.fetch_balances(tag="{}_end".format(process_name),
+                    log_tickers=True, **fetch_args)
         self.market.report.log_stage("{}_end".format(process_name))
 
     def method_arguments(self, action):
index cd0bf7babe2e7240588f4605571365f98096305b..76cfec88f4648e55cd9f1b084204e46f5ad730b4 100644 (file)
--- a/store.py
+++ b/store.py
@@ -7,6 +7,7 @@ import datetime
 import inspect
 from json import JSONDecodeError
 from simplejson.errors import JSONDecodeError as SimpleJSONDecodeError
+import dbs
 
 __all__ = ["Portfolio", "BalanceStore", "ReportStore", "TradeStore"]
 
@@ -302,13 +303,16 @@ class BalanceStore:
                 compute_value, type)
         return amounts
 
-    def fetch_balances(self, tag=None, log_tickers=False,
+    def fetch_balances(self, tag=None, add_portfolio=False, log_tickers=False,
             ticker_currency="BTC", ticker_compute_value="average", ticker_type="total"):
         all_balances = self.market.ccxt.fetch_all_balances()
         for currency, balance in all_balances.items():
             if balance["exchange_total"] != 0 or balance["margin_total"] != 0 or \
                     currency in self.all:
                 self.all[currency] = portfolio.Balance(currency, balance)
+        if add_portfolio:
+            for currency in Portfolio.repartition(from_cache=True):
+                self.all.setdefault(currency, portfolio.Balance(currency, {}))
         if log_tickers:
             tickers = self.in_currency(ticker_currency, compute_value=ticker_compute_value, type=ticker_type)
             self.market.report.log_balances(tag=tag,
@@ -508,7 +512,9 @@ class Portfolio:
             cls.get_cryptoportfolio(refetch=True)
 
     @classmethod
-    def repartition(cls, liquidity="medium"):
+    def repartition(cls, liquidity="medium", from_cache=False):
+        if from_cache:
+            cls.retrieve_cryptoportfolio()
         cls.get_cryptoportfolio()
         liquidities = cls.liquidities.get(liquidity)
         return liquidities[cls.last_date.get()]
@@ -530,11 +536,38 @@ class Portfolio:
         try:
             cls.data.set(r.json(parse_int=D, parse_float=D))
             cls.parse_cryptoportfolio()
+            cls.store_cryptoportfolio()
         except (JSONDecodeError, SimpleJSONDecodeError):
             cls.data.set(None)
             cls.last_date.set(None)
             cls.liquidities.set({})
 
+    @classmethod
+    def retrieve_cryptoportfolio(cls):
+        if dbs.redis_connected():
+            repartition = dbs.redis.get("/cryptoportfolio/repartition/latest")
+            date = dbs.redis.get("/cryptoportfolio/repartition/date")
+            if date is not None and repartition is not None:
+                date = datetime.datetime.strptime(date.decode(), "%Y-%m-%d")
+                repartition = json.loads(repartition, parse_int=D, parse_float=D)
+                repartition = { k: { date: v } for k, v in repartition.items() }
+
+                cls.data.set("")
+                cls.last_date.set(date)
+                cls.liquidities.set(repartition)
+
+    @classmethod
+    def store_cryptoportfolio(cls):
+        if dbs.redis_connected():
+            hash_ = {}
+            for liquidity, repartitions in cls.liquidities.items():
+                hash_[liquidity] = repartitions[cls.last_date.get()]
+            dump = json.dumps(hash_)
+            key = "/cryptoportfolio/repartition/latest"
+            dbs.redis.set(key, dump)
+            key = "/cryptoportfolio/repartition/date"
+            dbs.redis.set(key, cls.last_date.date().isoformat())
+
     @classmethod
     def parse_cryptoportfolio(cls):
         def filter_weights(weight_hash):
diff --git a/test.py b/test.py
index d7743b26ec11493b7cc62c51de698fae65842f2d..ed8943461fe07bae36a5c7fd87f0c13cf89ce01b 100644 (file)
--- a/test.py
+++ b/test.py
@@ -9,6 +9,7 @@ if "unit" in limits:
     from tests.test_market import *
     from tests.test_store import *
     from tests.test_portfolio import *
+    from tests.test_dbs import *
 
 if "acceptance" in limits:
     from tests.test_acceptance import *
index b85bf3ac58d8967edf0c7d691e3de2870b3f1374..935e0601f1ad9bbd82433ea841fd89474f1ad2c5 100644 (file)
@@ -4,7 +4,7 @@ from decimal import Decimal as D
 from unittest import mock
 import requests_mock
 from io import StringIO
-import portfolio, market, main, store
+import portfolio, market, main, store, dbs
 
 __all__ = ["limits", "unittest", "WebMockTestCase", "mock", "D",
         "StringIO"]
@@ -48,6 +48,8 @@ class WebMockTestCase(unittest.TestCase):
                     callback=None),
                 mock.patch.multiple(portfolio.Computation,
                     computations=portfolio.Computation.computations),
+                mock.patch.multiple(dbs,
+                    redis=None, psql=None)
                 ]
         for patcher in self.patchers:
             patcher.start()
diff --git a/tests/test_dbs.py b/tests/test_dbs.py
new file mode 100644 (file)
index 0000000..157c423
--- /dev/null
@@ -0,0 +1,108 @@
+from .helper import *
+import dbs, main
+
+@unittest.skipUnless("unit" in limits, "Unit skipped")
+class DbsTest(WebMockTestCase):
+    @mock.patch.object(dbs, "psycopg2")
+    def test_connect_psql(self, psycopg2):
+        args = main.configargparse.Namespace(**{
+            "db_host": "host",
+            "db_port": "port",
+            "db_user": "user",
+            "db_password": "password",
+            "db_database": "database",
+            })
+        psycopg2.connect.return_value = "pg_connection"
+        dbs.connect_psql(args)
+
+        psycopg2.connect.assert_called_once_with(host="host",
+                port="port", user="user", password="password",
+                database="database")
+        self.assertEqual("pg_connection", dbs.psql)
+        with self.assertRaises(AttributeError):
+            args.db_password
+
+        psycopg2.connect.reset_mock()
+        args = main.configargparse.Namespace(**{
+            "db_host": "host",
+            "db_port": "port",
+            "db_user": "user",
+            "db_password": "password",
+            "db_database": "database",
+            })
+        dbs.connect_psql(args)
+        psycopg2.connect.assert_not_called()
+
+    @mock.patch.object(dbs, "_redis")
+    def test_connect_redis(self, redis):
+        with self.subTest(redis_host="tcp"):
+            args = main.configargparse.Namespace(**{
+                "redis_host": "host",
+                "redis_port": "port",
+                "redis_database": "database",
+                })
+            redis.Redis.return_value = "redis_connection"
+            dbs.connect_redis(args)
+
+            redis.Redis.assert_called_once_with(host="host",
+                    port="port", db="database")
+            self.assertEqual("redis_connection", dbs.redis)
+            with self.assertRaises(AttributeError):
+                args.redis_database
+
+            redis.Redis.reset_mock()
+            args = main.configargparse.Namespace(**{
+                "redis_host": "host",
+                "redis_port": "port",
+                "redis_database": "database",
+                })
+            dbs.connect_redis(args)
+            redis.Redis.assert_not_called()
+
+        dbs.redis = None
+        with self.subTest(redis_host="socket"):
+            args = main.configargparse.Namespace(**{
+                "redis_host": "/run/foo",
+                "redis_port": "port",
+                "redis_database": "database",
+                })
+            redis.Redis.return_value = "redis_socket"
+            dbs.connect_redis(args)
+
+            redis.Redis.assert_called_once_with(unix_socket_path="/run/foo", db="database")
+            self.assertEqual("redis_socket", dbs.redis)
+
+    def test_redis_connected(self):
+        with self.subTest(redis=None):
+            dbs.redis = None
+            self.assertFalse(dbs.redis_connected())
+
+        with self.subTest(redis="mocked_true"):
+            dbs.redis = mock.Mock()
+            dbs.redis.ping.return_value = True
+            self.assertTrue(dbs.redis_connected())
+
+        with self.subTest(redis="mocked_false"):
+            dbs.redis = mock.Mock()
+            dbs.redis.ping.return_value = False
+            self.assertFalse(dbs.redis_connected())
+
+        with self.subTest(redis="mocked_raise"):
+            dbs.redis = mock.Mock()
+            dbs.redis.ping.side_effect = Exception("bouh")
+            self.assertFalse(dbs.redis_connected())
+
+    def test_psql_connected(self):
+        with self.subTest(psql=None):
+            dbs.psql = None
+            self.assertFalse(dbs.psql_connected())
+
+        with self.subTest(psql="connected"):
+            dbs.psql = mock.Mock()
+            dbs.psql.closed = 0
+            self.assertTrue(dbs.psql_connected())
+
+        with self.subTest(psql="not connected"):
+            dbs.psql = mock.Mock()
+            dbs.psql.closed = 3
+            self.assertFalse(dbs.psql_connected())
index 55b1382551aa34da39ec9b35fcd8b4bd7e8c7ee5..1864c062af75fd1bd7d7c0eabca80127e6a2dc13 100644 (file)
@@ -103,7 +103,6 @@ class MainTest(WebMockTestCase):
                 mock.patch("main.parse_config") as main_parse_config:
             with self.subTest(debug=False):
                 main_parse_args.return_value = self.market_args()
-                main_parse_config.return_value = ["pg_config", "redis_config"]
                 main_fetch_markets.return_value = [(1, {"key": "market_config"}, 3)]
                 m = main.get_user_market("config_path.ini", 1)
 
@@ -114,7 +113,6 @@ class MainTest(WebMockTestCase):
             main_parse_args.reset_mock()
             with self.subTest(debug=True):
                 main_parse_args.return_value = self.market_args(debug=True)
-                main_parse_config.return_value = ["pg_config", "redis_config"]
                 main_fetch_markets.return_value = [(1, {"key": "market_config"}, 3)]
                 m = main.get_user_market("config_path.ini", 1, debug=True)
 
@@ -135,16 +133,16 @@ class MainTest(WebMockTestCase):
             args_mock.after = "after"
             self.assertEqual("", stdout_mock.getvalue())
 
-            main.process("config", 3, 1, args_mock, "pg_config", "redis_config")
+            main.process("config", 3, 1, args_mock)
 
             market_mock.from_config.assert_has_calls([
-                mock.call("config", args_mock, pg_config="pg_config", redis_config="redis_config", market_id=3, user_id=1),
+                mock.call("config", args_mock, market_id=3, user_id=1),
                 mock.call().process("action", before="before", after="after"),
                 ])
 
             with self.subTest(exception=True):
                 market_mock.from_config.side_effect = Exception("boo")
-                main.process(3, "config", 1, args_mock, "pg_config", "redis_config")
+                main.process(3, "config", 1, args_mock)
                 self.assertEqual("Exception: boo\n", stdout_mock.getvalue())
 
     def test_main(self):
@@ -159,20 +157,18 @@ class MainTest(WebMockTestCase):
                 args_mock.user = "user"
                 parse_args.return_value = args_mock
 
-                parse_config.return_value = ["pg_config", "redis_config"]
-
                 fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]]
 
                 main.main(["Foo", "Bar"])
 
                 parse_args.assert_called_with(["Foo", "Bar"])
                 parse_config.assert_called_with(args_mock)
-                fetch_markets.assert_called_with("pg_config", "user")
+                fetch_markets.assert_called_with("user")
 
                 self.assertEqual(2, process.call_count)
                 process.assert_has_calls([
-                    mock.call("config1", 3, 1, args_mock, "pg_config", "redis_config"),
-                    mock.call("config2", 1, 2, args_mock, "pg_config", "redis_config"),
+                    mock.call("config1", 3, 1, args_mock),
+                    mock.call("config2", 1, 2, args_mock),
                     ])
         with self.subTest(parallel=True):
             with mock.patch("main.parse_args") as parse_args,\
@@ -187,24 +183,22 @@ class MainTest(WebMockTestCase):
                 args_mock.user = "user"
                 parse_args.return_value = args_mock
 
-                parse_config.return_value = ["pg_config", "redis_config"]
-
                 fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]]
 
                 main.main(["Foo", "Bar"])
 
                 parse_args.assert_called_with(["Foo", "Bar"])
                 parse_config.assert_called_with(args_mock)
-                fetch_markets.assert_called_with("pg_config", "user")
+                fetch_markets.assert_called_with("user")
 
                 stop.assert_called_once_with()
                 start.assert_called_once_with()
                 self.assertEqual(2, process.call_count)
                 process.assert_has_calls([
                     mock.call.__bool__(),
-                    mock.call("config1", 3, 1, args_mock, "pg_config", "redis_config"),
+                    mock.call("config1", 3, 1, args_mock),
                     mock.call.__bool__(),
-                    mock.call("config2", 1, 2, args_mock, "pg_config", "redis_config"),
+                    mock.call("config2", 1, 2, args_mock),
                     ])
         with self.subTest(quiet=True):
             with mock.patch("main.parse_args") as parse_args,\
@@ -219,8 +213,6 @@ class MainTest(WebMockTestCase):
                 args_mock.user = "user"
                 parse_args.return_value = args_mock
 
-                parse_config.return_value = ["pg_config", "redis_config"]
-
                 fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]]
 
                 main.main(["Foo", "Bar"])
@@ -240,8 +232,6 @@ class MainTest(WebMockTestCase):
                 args_mock.user = "user"
                 parse_args.return_value = args_mock
 
-                parse_config.return_value = ["pg_config", "redis_config"]
-
                 fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]]
 
                 main.main(["Foo", "Bar"])
@@ -252,63 +242,57 @@ class MainTest(WebMockTestCase):
     @mock.patch.object(main.sys, "exit")
     @mock.patch("main.os")
     def test_parse_config(self, os, exit):
-        with self.subTest(report_path=None):
+        with self.subTest(report_path=None),\
+                mock.patch.object(main.dbs, "connect_psql") as psql,\
+                mock.patch.object(main.dbs, "connect_redis") as redis:
             args = main.configargparse.Namespace(**{
                 "db_host": "host",
-                "db_port": "port",
-                "db_user": "user",
-                "db_password": "password",
-                "db_database": "database",
                 "redis_host": "rhost",
-                "redis_port": "rport",
-                "redis_database": "rdb",
                 "report_path": None,
                 })
 
-            db_config, redis_config = main.parse_config(args)
-            self.assertEqual({ "host": "host", "port": "port", "user":
-                "user", "password": "password", "database": "database"
-                }, db_config)
-            self.assertEqual({ "host": "rhost", "port": "rport", "db":
-                "rdb"}, redis_config)
+            main.parse_config(args)
+            psql.assert_called_once_with(args)
+            redis.assert_called_once_with(args)
+
+        with self.subTest(report_path=None, db=None),\
+                mock.patch.object(main.dbs, "connect_psql") as psql,\
+                mock.patch.object(main.dbs, "connect_redis") as redis:
+            args = main.configargparse.Namespace(**{
+                "db_host": None,
+                "redis_host": "rhost",
+                "report_path": None,
+                })
 
-            with self.assertRaises(AttributeError):
-                args.db_password
-            with self.assertRaises(AttributeError):
-                args.redis_host
+            main.parse_config(args)
+            psql.assert_not_called()
+            redis.assert_called_once_with(args)
 
-        with self.subTest(redis_host="socket"):
+        with self.subTest(report_path=None, redis=None),\
+                mock.patch.object(main.dbs, "connect_psql") as psql,\
+                mock.patch.object(main.dbs, "connect_redis") as redis:
             args = main.configargparse.Namespace(**{
                 "db_host": "host",
-                "db_port": "port",
-                "db_user": "user",
-                "db_password": "password",
-                "db_database": "database",
-                "redis_host": "/run/foo",
-                "redis_port": "rport",
-                "redis_database": "rdb",
+                "redis_host": None,
                 "report_path": None,
                 })
 
-            db_config, redis_config = main.parse_config(args)
-            self.assertEqual({ "unix_socket_path": "/run/foo", "db": "rdb"}, redis_config)
+            main.parse_config(args)
+            redis.assert_not_called()
+            psql.assert_called_once_with(args)
 
-        with self.subTest(report_path="present"):
+        with self.subTest(report_path="present"),\
+                mock.patch.object(main.dbs, "connect_psql") as psql,\
+                mock.patch.object(main.dbs, "connect_redis") as redis:
             args = main.configargparse.Namespace(**{
                 "db_host": "host",
-                "db_port": "port",
-                "db_user": "user",
-                "db_password": "password",
-                "db_database": "database",
                 "redis_host": "rhost",
-                "redis_port": "rport",
-                "redis_database": "rdb",
                 "report_path": "report_path",
                 })
 
             os.path.exists.return_value = False
 
-            result = main.parse_config(args)
+            main.parse_config(args)
 
             os.path.exists.assert_called_once_with("report_path")
             os.makedirs.assert_called_once_with("report_path")
@@ -331,29 +315,24 @@ class MainTest(WebMockTestCase):
                 mock.patch('sys.stderr', new_callable=StringIO) as stdout_mock:
             args = main.parse_args(["--config", "foo.bar"])
 
-    @mock.patch.object(main, "psycopg2")
-    def test_fetch_markets(self, psycopg2):
-        connect_mock = mock.Mock()
+    @mock.patch.object(main.dbs, "psql")
+    def test_fetch_markets(self, psql):
         cursor_mock = mock.MagicMock()
         cursor_mock.__iter__.return_value = ["row_1", "row_2"]
 
-        connect_mock.cursor.return_value = cursor_mock
-        psycopg2.connect.return_value = connect_mock
+        psql.cursor.return_value = cursor_mock
 
         with self.subTest(user=None):
-            rows = list(main.fetch_markets({"foo": "bar"}, None))
+            rows = list(main.fetch_markets(None))
 
-            psycopg2.connect.assert_called_once_with(foo="bar")
             cursor_mock.execute.assert_called_once_with("SELECT id,config,user_id FROM market_configs")
 
             self.assertEqual(["row_1", "row_2"], rows)
 
-        psycopg2.connect.reset_mock()
         cursor_mock.execute.reset_mock()
         with self.subTest(user=1):
-            rows = list(main.fetch_markets({"foo": "bar"}, 1))
+            rows = list(main.fetch_markets(1))
 
-            psycopg2.connect.assert_called_once_with(foo="bar")
             cursor_mock.execute.assert_called_once_with("SELECT id,config,user_id FROM market_configs WHERE user_id = %s", 1)
 
             self.assertEqual(["row_1", "row_2"], rows)
index 6a3322c579d43815fb15b8da538063427936a814..46fad53aab1344cdc71b35c68f86e60b32d1a85e 100644 (file)
@@ -1,5 +1,5 @@
 from .helper import *
-import market, store, portfolio
+import market, store, portfolio, dbs
 import datetime
 
 @unittest.skipUnless("unit" in limits, "Unit skipped")
@@ -595,13 +595,11 @@ class MarketTest(WebMockTestCase):
 
             self.assertRegex(stdout_mock.getvalue(), "impossible to store report file: FileNotFoundError;")
 
-    @mock.patch.object(market, "psycopg2")
-    def test_store_database_report(self, psycopg2):
-        connect_mock = mock.Mock()
+    @mock.patch.object(dbs, "psql")
+    def test_store_database_report(self, psql):
         cursor_mock = mock.MagicMock()
 
-        connect_mock.cursor.return_value = cursor_mock
-        psycopg2.connect.return_value = connect_mock
+        psql.cursor.return_value = cursor_mock
         m = market.Market(self.ccxt, self.market_args(),
                 pg_config={"config": "pg_config"}, user_id=1)
         cursor_mock.fetchone.return_value = [42]
@@ -613,7 +611,7 @@ class MarketTest(WebMockTestCase):
                     ("date2", "type2", "payload2"),
                     ]
             m.store_database_report(datetime.datetime(2018, 3, 24))
-            connect_mock.assert_has_calls([
+            psql.assert_has_calls([
                 mock.call.cursor(),
                 mock.call.cursor().execute('INSERT INTO reports("date", "market_config_id", "debug") VALUES (%s, %s, %s) RETURNING id;', (datetime.datetime(2018, 3, 24), None, False)),
                 mock.call.cursor().fetchone(),
@@ -621,21 +619,16 @@ class MarketTest(WebMockTestCase):
                 mock.call.cursor().execute('INSERT INTO report_lines("date", "report_id", "type", "payload") VALUES (%s, %s, %s, %s);', ('date2', 42, 'type2', 'payload2')),
                 mock.call.commit(),
                 mock.call.cursor().close(),
-                mock.call.close()
                 ])
 
-        connect_mock.reset_mock()
         with self.subTest(error=True),\
                 mock.patch('sys.stdout', new_callable=StringIO) as stdout_mock:
-            psycopg2.connect.side_effect = Exception("Bouh")
+            psql.cursor.side_effect = Exception("Bouh")
             m.store_database_report(datetime.datetime(2018, 3, 24))
             self.assertEqual(stdout_mock.getvalue(), "impossible to store report to database: Exception; Bouh\n")
 
-    @mock.patch.object(market, "redis")
+    @mock.patch.object(dbs, "redis")
     def test_store_redis_report(self, redis):
-        connect_mock = mock.Mock()
-        redis.Redis.return_value = connect_mock
-
         m = market.Market(self.ccxt, self.market_args(),
                 redis_config={"config": "redis_config"}, market_id=1)
 
@@ -646,7 +639,7 @@ class MarketTest(WebMockTestCase):
                     ("type2", "payload2"),
                     ]
             m.store_redis_report(datetime.datetime(2018, 3, 24))
-            connect_mock.assert_has_calls([
+            redis.assert_has_calls([
                 mock.call.set("/cryptoportfolio/1/2018-03-24T00:00:00/type1", "payload1", ex=31*24*60*60),
                 mock.call.set("/cryptoportfolio/1/latest/type1", "payload1"),
                 mock.call.set("/cryptoportfolio/1/2018-03-24T00:00:00/type2", "payload2", ex=31*24*60*60),
@@ -654,20 +647,24 @@ class MarketTest(WebMockTestCase):
                 mock.call.set("/cryptoportfolio/1/latest/date", "2018-03-24T00:00:00"),
                 ])
 
-        connect_mock.reset_mock()
+        redis.reset_mock()
         with self.subTest(error=True),\
                 mock.patch('sys.stdout', new_callable=StringIO) as stdout_mock:
-            redis.Redis.side_effect = Exception("Bouh")
+            redis.set.side_effect = Exception("Bouh")
             m.store_redis_report(datetime.datetime(2018, 3, 24))
             self.assertEqual(stdout_mock.getvalue(), "impossible to store report to redis: Exception; Bouh\n")
 
     def test_store_report(self):
         m = market.Market(self.ccxt, self.market_args(report_db=False), user_id=1)
-        with self.subTest(file=None, pg_config=None),\
+        with self.subTest(file=None, pg_connected=None),\
+                mock.patch.object(dbs, "psql_connected") as psql,\
+                mock.patch.object(dbs, "redis_connected") as redis,\
                 mock.patch.object(m, "report") as report,\
                 mock.patch.object(m, "store_database_report") as db_report,\
                 mock.patch.object(m, "store_redis_report") as redis_report,\
                 mock.patch.object(m, "store_file_report") as file_report:
+            psql.return_value = False
+            redis.return_value = False
             m.store_report()
             report.merge.assert_called_with(store.Portfolio.report)
 
@@ -677,13 +674,16 @@ class MarketTest(WebMockTestCase):
 
         report.reset_mock()
         m = market.Market(self.ccxt, self.market_args(report_db=False, report_path="present"), user_id=1)
-        with self.subTest(file="present", pg_config=None),\
+        with self.subTest(file="present", pg_connected=None),\
+                mock.patch.object(dbs, "psql_connected") as psql,\
+                mock.patch.object(dbs, "redis_connected") as redis,\
                 mock.patch.object(m, "report") as report,\
                 mock.patch.object(m, "store_file_report") as file_report,\
                 mock.patch.object(m, "store_redis_report") as redis_report,\
                 mock.patch.object(m, "store_database_report") as db_report,\
                 mock.patch.object(market.datetime, "datetime") as time_mock:
-
+            psql.return_value = False
+            redis.return_value = False
             time_mock.now.return_value = datetime.datetime(2018, 2, 25)
 
             m.store_report()
@@ -695,13 +695,16 @@ class MarketTest(WebMockTestCase):
 
         report.reset_mock()
         m = market.Market(self.ccxt, self.market_args(report_db=True, report_path="present"), user_id=1)
-        with self.subTest(file="present", pg_config=None, report_db=True),\
+        with self.subTest(file="present", pg_connected=None, report_db=True),\
+                mock.patch.object(dbs, "psql_connected") as psql,\
+                mock.patch.object(dbs, "redis_connected") as redis,\
                 mock.patch.object(m, "report") as report,\
                 mock.patch.object(m, "store_file_report") as file_report,\
                 mock.patch.object(m, "store_redis_report") as redis_report,\
                 mock.patch.object(m, "store_database_report") as db_report,\
                 mock.patch.object(market.datetime, "datetime") as time_mock:
-
+            psql.return_value = False
+            redis.return_value = False
             time_mock.now.return_value = datetime.datetime(2018, 2, 25)
 
             m.store_report()
@@ -712,14 +715,17 @@ class MarketTest(WebMockTestCase):
             redis_report.assert_not_called()
 
         report.reset_mock()
-        m = market.Market(self.ccxt, self.market_args(report_db=True), pg_config="present", user_id=1)
-        with self.subTest(file=None, pg_config="present"),\
+        m = market.Market(self.ccxt, self.market_args(report_db=True), user_id=1)
+        with self.subTest(file=None, pg_connected=True),\
+                mock.patch.object(dbs, "psql_connected") as psql,\
+                mock.patch.object(dbs, "redis_connected") as redis,\
                 mock.patch.object(m, "report") as report,\
                 mock.patch.object(m, "store_file_report") as file_report,\
                 mock.patch.object(m, "store_redis_report") as redis_report,\
                 mock.patch.object(m, "store_database_report") as db_report,\
                 mock.patch.object(market.datetime, "datetime") as time_mock:
-
+            psql.return_value = True
+            redis.return_value = False
             time_mock.now.return_value = datetime.datetime(2018, 2, 25)
 
             m.store_report()
@@ -731,14 +737,17 @@ class MarketTest(WebMockTestCase):
 
         report.reset_mock()
         m = market.Market(self.ccxt, self.market_args(report_db=True, report_path="present"),
-                pg_config="pg_config", user_id=1)
-        with self.subTest(file="present", pg_config="present"),\
+                user_id=1)
+        with self.subTest(file="present", pg_connected=True),\
+                mock.patch.object(dbs, "psql_connected") as psql,\
+                mock.patch.object(dbs, "redis_connected") as redis,\
                 mock.patch.object(m, "report") as report,\
                 mock.patch.object(m, "store_file_report") as file_report,\
                 mock.patch.object(m, "store_redis_report") as redis_report,\
                 mock.patch.object(m, "store_database_report") as db_report,\
                 mock.patch.object(market.datetime, "datetime") as time_mock:
-
+            psql.return_value = True
+            redis.return_value = False
             time_mock.now.return_value = datetime.datetime(2018, 2, 25)
 
             m.store_report()
@@ -750,14 +759,17 @@ class MarketTest(WebMockTestCase):
 
         report.reset_mock()
         m = market.Market(self.ccxt, self.market_args(report_redis=False),
-                redis_config="redis_config", user_id=1)
-        with self.subTest(redis_config="present", report_redis=False),\
+                user_id=1)
+        with self.subTest(redis_connected=True, report_redis=False),\
+                mock.patch.object(dbs, "psql_connected") as psql,\
+                mock.patch.object(dbs, "redis_connected") as redis,\
                 mock.patch.object(m, "report") as report,\
                 mock.patch.object(m, "store_file_report") as file_report,\
                 mock.patch.object(m, "store_redis_report") as redis_report,\
                 mock.patch.object(m, "store_database_report") as db_report,\
                 mock.patch.object(market.datetime, "datetime") as time_mock:
-
+            psql.return_value = False
+            redis.return_value = True
             time_mock.now.return_value = datetime.datetime(2018, 2, 25)
 
             m.store_report()
@@ -766,13 +778,16 @@ class MarketTest(WebMockTestCase):
         report.reset_mock()
         m = market.Market(self.ccxt, self.market_args(report_redis=True),
                 user_id=1)
-        with self.subTest(redis_config="absent", report_redis=True),\
+        with self.subTest(redis_connected=False, report_redis=True),\
+                mock.patch.object(dbs, "psql_connected") as psql,\
+                mock.patch.object(dbs, "redis_connected") as redis,\
                 mock.patch.object(m, "report") as report,\
                 mock.patch.object(m, "store_file_report") as file_report,\
                 mock.patch.object(m, "store_redis_report") as redis_report,\
                 mock.patch.object(m, "store_database_report") as db_report,\
                 mock.patch.object(market.datetime, "datetime") as time_mock:
-
+            psql.return_value = False
+            redis.return_value = False
             time_mock.now.return_value = datetime.datetime(2018, 2, 25)
 
             m.store_report()
@@ -780,14 +795,17 @@ class MarketTest(WebMockTestCase):
 
         report.reset_mock()
         m = market.Market(self.ccxt, self.market_args(report_redis=True),
-                redis_config="redis_config", user_id=1)
-        with self.subTest(redis_config="present", report_redis=True),\
+                user_id=1)
+        with self.subTest(redis_connected=True, report_redis=True),\
+                mock.patch.object(dbs, "psql_connected") as psql,\
+                mock.patch.object(dbs, "redis_connected") as redis,\
                 mock.patch.object(m, "report") as report,\
                 mock.patch.object(m, "store_file_report") as file_report,\
                 mock.patch.object(m, "store_redis_report") as redis_report,\
                 mock.patch.object(m, "store_database_report") as db_report,\
                 mock.patch.object(market.datetime, "datetime") as time_mock:
-
+            psql.return_value = False
+            redis.return_value = True
             time_mock.now.return_value = datetime.datetime(2018, 2, 25)
 
             m.store_report()
@@ -1014,6 +1032,15 @@ class ProcessorTest(WebMockTestCase):
             processor.process_step("foo", step, {"foo":"bar"})
             self.m.balances.fetch_balances.assert_not_called()
 
+        self.m.reset_mock()
+        with mock.patch.object(processor, "run_action") as run_action:
+            step = processor.scenarios["print_balances"][0]
+
+            processor.process_step("foo", step, {"foo":"bar"})
+            self.m.balances.fetch_balances.assert_called_once_with(
+                    add_portfolio=True, log_tickers=True,
+                    tag='process_foo__1_print_balances_begin')
+
     def test_parse_args(self):
         processor = market.Processor(self.m)
 
index 12999d36e1b7e9bd64f0d7eda8ad002dd14b8411..ee7e06349c4816263728b5c0d0504bb397e5018e 100644 (file)
@@ -391,6 +391,18 @@ class BalanceStoreTest(WebMockTestCase):
                     tag=None, ticker_currency='FOO', tickers='tickers',
                     type='type')
 
+        balance_store = market.BalanceStore(self.m)
+
+        with self.subTest(add_portfolio=True),\
+                mock.patch.object(market.Portfolio, "repartition") as repartition:
+            repartition.return_value = {
+                    "DOGE": D("0.5"),
+                    "USDT": D("0.5"),
+                    }
+            balance_store.fetch_balances(add_portfolio=True)
+            self.assertListEqual(["USDT", "XVG", "XMR", "DOGE"], list(balance_store.currencies()))
+
+
     @mock.patch.object(market.Portfolio, "repartition")
     def test_dispatch_assets(self, repartition):
         self.m.ccxt.fetch_all_balances.return_value = self.fetch_balance
@@ -1101,7 +1113,8 @@ class PortfolioTest(WebMockTestCase):
         self.wm.get(market.Portfolio.URL, text=self.json_response)
 
     @mock.patch.object(market.Portfolio, "parse_cryptoportfolio")
-    def test_get_cryptoportfolio(self, parse_cryptoportfolio):
+    @mock.patch.object(market.Portfolio, "store_cryptoportfolio")
+    def test_get_cryptoportfolio(self, store_cryptoportfolio, parse_cryptoportfolio):
         with self.subTest(parallel=False):
             self.wm.get(market.Portfolio.URL, [
                 {"text":'{ "foo": "bar" }', "status_code": 200},
@@ -1116,23 +1129,28 @@ class PortfolioTest(WebMockTestCase):
             market.Portfolio.report.log_error.assert_not_called()
             market.Portfolio.report.log_http_request.assert_called_once()
             parse_cryptoportfolio.assert_called_once_with()
+            store_cryptoportfolio.assert_called_once_with()
             market.Portfolio.report.log_http_request.reset_mock()
             parse_cryptoportfolio.reset_mock()
+            store_cryptoportfolio.reset_mock()
             market.Portfolio.data = store.LockedVar(None)
 
             market.Portfolio.get_cryptoportfolio()
             self.assertIsNone(market.Portfolio.data.get())
             self.assertEqual(2, self.wm.call_count)
             parse_cryptoportfolio.assert_not_called()
+            store_cryptoportfolio.assert_not_called()
             market.Portfolio.report.log_error.assert_not_called()
             market.Portfolio.report.log_http_request.assert_called_once()
             market.Portfolio.report.log_http_request.reset_mock()
             parse_cryptoportfolio.reset_mock()
+            store_cryptoportfolio.reset_mock()
 
             market.Portfolio.data = store.LockedVar("Foo")
             market.Portfolio.get_cryptoportfolio()
             self.assertEqual(2, self.wm.call_count)
             parse_cryptoportfolio.assert_not_called()
+            store_cryptoportfolio.assert_not_called()
 
             market.Portfolio.get_cryptoportfolio(refetch=True)
             self.assertEqual("Foo", market.Portfolio.data.get())
@@ -1153,6 +1171,7 @@ class PortfolioTest(WebMockTestCase):
                     market.Portfolio.get_cryptoportfolio()
                     self.assertIn("foo", market.Portfolio.data.get())
                 parse_cryptoportfolio.reset_mock()
+                store_cryptoportfolio.reset_mock()
                 with self.subTest(worker=False):
                     market.Portfolio.data = store.LockedVar(None)
                     market.Portfolio.worker = mock.Mock()
@@ -1160,6 +1179,7 @@ class PortfolioTest(WebMockTestCase):
                     market.Portfolio.get_cryptoportfolio()
                     notify.assert_called_once_with()
                     parse_cryptoportfolio.assert_not_called()
+                    store_cryptoportfolio.assert_not_called()
 
     def test_parse_cryptoportfolio(self):
         with self.subTest(description="Normal case"):
@@ -1223,25 +1243,95 @@ class PortfolioTest(WebMockTestCase):
             self.assertEqual({}, market.Portfolio.liquidities.get("high"))
             self.assertEqual(datetime.datetime(1,1,1), market.Portfolio.last_date.get())
 
-
-    @mock.patch.object(market.Portfolio, "get_cryptoportfolio")
-    def test_repartition(self, get_cryptoportfolio):
-        market.Portfolio.liquidities = store.LockedVar({
+    @mock.patch.object(store.dbs, "redis_connected")
+    @mock.patch.object(store.dbs, "redis")
+    def test_store_cryptoportfolio(self, redis, redis_connected):
+        store.Portfolio.liquidities = store.LockedVar({
                 "medium": {
-                    "2018-03-01": "medium_2018-03-01",
-                    "2018-03-08": "medium_2018-03-08",
+                    datetime.datetime(2018,3,1): "medium_2018-03-01",
+                    datetime.datetime(2018,3,8): "medium_2018-03-08",
                     },
                 "high": {
-                    "2018-03-01": "high_2018-03-01",
-                    "2018-03-08": "high_2018-03-08",
+                    datetime.datetime(2018,3,1): "high_2018-03-01",
+                    datetime.datetime(2018,3,8): "high_2018-03-08",
                     }
                 })
-        market.Portfolio.last_date = store.LockedVar("2018-03-08")
+        store.Portfolio.last_date = store.LockedVar(datetime.datetime(2018,3,8))
+
+        with self.subTest(redis_connected=False):
+            redis_connected.return_value = False
+            store.Portfolio.store_cryptoportfolio()
+            redis.set.assert_not_called()
+
+        with self.subTest(redis_connected=True):
+            redis_connected.return_value = True
+            store.Portfolio.store_cryptoportfolio()
+            redis.set.assert_has_calls([
+                mock.call("/cryptoportfolio/repartition/latest", '{"medium": "medium_2018-03-08", "high": "high_2018-03-08"}'),
+                mock.call("/cryptoportfolio/repartition/date", "2018-03-08"),
+                ])
+
+    @mock.patch.object(store.dbs, "redis_connected")
+    @mock.patch.object(store.dbs, "redis")
+    def test_retrieve_cryptoportfolio(self, redis, redis_connected):
+        with self.subTest(redis_connected=False):
+            redis_connected.return_value = False
+            store.Portfolio.retrieve_cryptoportfolio()
+            redis.get.assert_not_called()
+            self.assertIsNone(store.Portfolio.data.get())
+
+        with self.subTest(redis_connected=True, value=None):
+            redis_connected.return_value = True
+            redis.get.return_value = None
+            store.Portfolio.retrieve_cryptoportfolio()
+            self.assertEqual(2, redis.get.call_count)
+
+        redis.reset_mock()
+        with self.subTest(redis_connected=True, value="present"):
+            redis_connected.return_value = True
+            redis.get.side_effect = [
+                    b'{ "medium": "medium_repartition", "high": "high_repartition" }',
+                    b"2018-03-08"
+                    ]
+            store.Portfolio.retrieve_cryptoportfolio()
+            self.assertEqual(2, redis.get.call_count)
+            self.assertEqual(datetime.datetime(2018,3,8), store.Portfolio.last_date.get())
+            self.assertEqual("", store.Portfolio.data.get())
+            expected_liquidities = {
+                    'high': { datetime.datetime(2018, 3, 8): 'high_repartition' },
+                    'medium': { datetime.datetime(2018, 3, 8): 'medium_repartition' },
+                    }
+            self.assertEqual(expected_liquidities, store.Portfolio.liquidities.get())
+
+    @mock.patch.object(market.Portfolio, "get_cryptoportfolio")
+    @mock.patch.object(market.Portfolio, "retrieve_cryptoportfolio")
+    def test_repartition(self, retrieve_cryptoportfolio, get_cryptoportfolio):
+        with self.subTest(from_cache=False):
+            market.Portfolio.liquidities = store.LockedVar({
+                    "medium": {
+                        "2018-03-01": "medium_2018-03-01",
+                        "2018-03-08": "medium_2018-03-08",
+                        },
+                    "high": {
+                        "2018-03-01": "high_2018-03-01",
+                        "2018-03-08": "high_2018-03-08",
+                        }
+                    })
+            market.Portfolio.last_date = store.LockedVar("2018-03-08")
+
+            self.assertEqual("medium_2018-03-08", market.Portfolio.repartition())
+            get_cryptoportfolio.assert_called_once_with()
+            retrieve_cryptoportfolio.assert_not_called()
+            self.assertEqual("medium_2018-03-08", market.Portfolio.repartition(liquidity="medium"))
+            self.assertEqual("high_2018-03-08", market.Portfolio.repartition(liquidity="high"))
+
+        retrieve_cryptoportfolio.reset_mock()
+        get_cryptoportfolio.reset_mock()
 
-        self.assertEqual("medium_2018-03-08", market.Portfolio.repartition())
-        get_cryptoportfolio.assert_called_once_with()
-        self.assertEqual("medium_2018-03-08", market.Portfolio.repartition(liquidity="medium"))
-        self.assertEqual("high_2018-03-08", market.Portfolio.repartition(liquidity="high"))
+        with self.subTest(from_cache=True):
+            self.assertEqual("medium_2018-03-08", market.Portfolio.repartition(from_cache=True))
+            get_cryptoportfolio.assert_called_once_with()
+            retrieve_cryptoportfolio.assert_called_once_with()
 
     @mock.patch.object(market.time, "sleep")
     @mock.patch.object(market.Portfolio, "get_cryptoportfolio")