aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--dbs.py55
-rw-r--r--main.py55
-rw-r--r--market.py31
-rw-r--r--store.py37
-rw-r--r--test.py1
-rw-r--r--tests/helper.py4
-rw-r--r--tests/test_dbs.py108
-rw-r--r--tests/test_main.py107
-rw-r--r--tests/test_market.py99
-rw-r--r--tests/test_store.py118
10 files changed, 443 insertions, 172 deletions
diff --git a/dbs.py b/dbs.py
new file mode 100644
index 0000000..b32afa3
--- /dev/null
+++ b/dbs.py
@@ -0,0 +1,55 @@
1import psycopg2
2import redis as _redis
3
4redis = None
5psql = None
6
7def redis_connected():
8 global redis
9 if redis is None:
10 return False
11 else:
12 try:
13 return redis.ping()
14 except Exception:
15 return False
16
17def psql_connected():
18 global psql
19 return psql is not None and psql.closed == 0
20
21def connect_redis(args):
22 global redis
23
24 redis_config = {
25 "host": args.redis_host,
26 "port": args.redis_port,
27 "db": args.redis_database,
28 }
29 if redis_config["host"].startswith("/"):
30 redis_config["unix_socket_path"] = redis_config.pop("host")
31 del(redis_config["port"])
32 del(args.redis_host)
33 del(args.redis_port)
34 del(args.redis_database)
35
36 if redis is None:
37 redis = _redis.Redis(**redis_config)
38
39def connect_psql(args):
40 global psql
41 pg_config = {
42 "host": args.db_host,
43 "port": args.db_port,
44 "user": args.db_user,
45 "password": args.db_password,
46 "database": args.db_database,
47 }
48 del(args.db_host)
49 del(args.db_port)
50 del(args.db_user)
51 del(args.db_password)
52 del(args.db_database)
53
54 if psql is None:
55 psql = psycopg2.connect(**pg_config)
diff --git a/main.py b/main.py
index a461207..ee25182 100644
--- a/main.py
+++ b/main.py
@@ -1,5 +1,5 @@
1import configargparse 1import configargparse
2import psycopg2 2import dbs
3import os 3import os
4import sys 4import sys
5 5
@@ -63,15 +63,12 @@ def get_user_market(config_path, user_id, debug=False):
63 if debug: 63 if debug:
64 args.append("--debug") 64 args.append("--debug")
65 args = parse_args(args) 65 args = parse_args(args)
66 pg_config, redis_config = parse_config(args) 66 parse_config(args)
67 market_id, market_config, user_id = list(fetch_markets(pg_config, str(user_id)))[0] 67 market_id, market_config, user_id = list(fetch_markets(str(user_id)))[0]
68 return market.Market.from_config(market_config, args, 68 return market.Market.from_config(market_config, args, user_id=user_id)
69 pg_config=pg_config, market_id=market_id,
70 user_id=user_id)
71 69
72def fetch_markets(pg_config, user): 70def fetch_markets(user):
73 connection = psycopg2.connect(**pg_config) 71 cursor = dbs.psql.cursor()
74 cursor = connection.cursor()
75 72
76 if user is None: 73 if user is None:
77 cursor.execute("SELECT id,config,user_id FROM market_configs") 74 cursor.execute("SELECT id,config,user_id FROM market_configs")
@@ -82,30 +79,11 @@ def fetch_markets(pg_config, user):
82 yield row 79 yield row
83 80
84def parse_config(args): 81def parse_config(args):
85 pg_config = { 82 if args.db_host is not None:
86 "host": args.db_host, 83 dbs.connect_psql(args)
87 "port": args.db_port, 84
88 "user": args.db_user, 85 if args.redis_host is not None:
89 "password": args.db_password, 86 dbs.connect_redis(args)
90 "database": args.db_database,
91 }
92 del(args.db_host)
93 del(args.db_port)
94 del(args.db_user)
95 del(args.db_password)
96 del(args.db_database)
97
98 redis_config = {
99 "host": args.redis_host,
100 "port": args.redis_port,
101 "db": args.redis_database,
102 }
103 if redis_config["host"].startswith("/"):
104 redis_config["unix_socket_path"] = redis_config.pop("host")
105 del(redis_config["port"])
106 del(args.redis_host)
107 del(args.redis_port)
108 del(args.redis_database)
109 87
110 report_path = args.report_path 88 report_path = args.report_path
111 89
@@ -113,8 +91,6 @@ def parse_config(args):
113 os.path.exists(report_path): 91 os.path.exists(report_path):
114 os.makedirs(report_path) 92 os.makedirs(report_path)
115 93
116 return pg_config, redis_config
117
118def parse_args(argv): 94def parse_args(argv):
119 parser = configargparse.ArgumentParser( 95 parser = configargparse.ArgumentParser(
120 description="Run the trade bot.") 96 description="Run the trade bot.")
@@ -176,11 +152,10 @@ def parse_args(argv):
176 parsed.action = ["sell_all"] 152 parsed.action = ["sell_all"]
177 return parsed 153 return parsed
178 154
179def process(market_config, market_id, user_id, args, pg_config, redis_config): 155def process(market_config, market_id, user_id, args):
180 try: 156 try:
181 market.Market\ 157 market.Market\
182 .from_config(market_config, args, market_id=market_id, 158 .from_config(market_config, args, market_id=market_id,
183 pg_config=pg_config, redis_config=redis_config,
184 user_id=user_id)\ 159 user_id=user_id)\
185 .process(args.action, before=args.before, after=args.after) 160 .process(args.action, before=args.before, after=args.after)
186 except Exception as e: 161 except Exception as e:
@@ -189,7 +164,7 @@ def process(market_config, market_id, user_id, args, pg_config, redis_config):
189def main(argv): 164def main(argv):
190 args = parse_args(argv) 165 args = parse_args(argv)
191 166
192 pg_config, redis_config = parse_config(args) 167 parse_config(args)
193 168
194 market.Portfolio.report.set_verbose(not args.quiet) 169 market.Portfolio.report.set_verbose(not args.quiet)
195 170
@@ -205,8 +180,8 @@ def main(argv):
205 else: 180 else:
206 process_ = process 181 process_ = process
207 182
208 for market_id, market_config, user_id in fetch_markets(pg_config, args.user): 183 for market_id, market_config, user_id in fetch_markets(args.user):
209 process_(market_config, market_id, user_id, args, pg_config, redis_config) 184 process_(market_config, market_id, user_id, args)
210 185
211 if args.parallel: 186 if args.parallel:
212 for thread in threads: 187 for thread in threads:
diff --git a/market.py b/market.py
index fc6f9f6..5876071 100644
--- a/market.py
+++ b/market.py
@@ -1,8 +1,7 @@
1from ccxt import ExchangeError, NotSupported, RequestTimeout, InvalidNonce 1from ccxt import ExchangeError, NotSupported, RequestTimeout, InvalidNonce
2import ccxt_wrapper as ccxt 2import ccxt_wrapper as ccxt
3import time 3import time
4import psycopg2 4import dbs
5import redis
6from store import * 5from store import *
7from cachetools.func import ttl_cache 6from cachetools.func import ttl_cache
8from datetime import datetime 7from datetime import datetime
@@ -27,7 +26,7 @@ class Market:
27 self.balances = BalanceStore(self) 26 self.balances = BalanceStore(self)
28 self.processor = Processor(self) 27 self.processor = Processor(self)
29 28
30 for key in ["user_id", "market_id", "pg_config", "redis_config"]: 29 for key in ["user_id", "market_id"]:
31 setattr(self, key, kwargs.get(key, None)) 30 setattr(self, key, kwargs.get(key, None))
32 31
33 self.report.log_market(self.args) 32 self.report.log_market(self.args)
@@ -45,9 +44,9 @@ class Market:
45 date = datetime.datetime.now() 44 date = datetime.datetime.now()
46 if self.args.report_path is not None: 45 if self.args.report_path is not None:
47 self.store_file_report(date) 46 self.store_file_report(date)
48 if self.pg_config is not None and self.args.report_db: 47 if dbs.psql_connected() and self.args.report_db:
49 self.store_database_report(date) 48 self.store_database_report(date)
50 if self.redis_config is not None and self.args.report_redis: 49 if dbs.redis_connected() and self.args.report_redis:
51 self.store_redis_report(date) 50 self.store_redis_report(date)
52 51
53 def store_file_report(self, date): 52 def store_file_report(self, date):
@@ -64,29 +63,26 @@ class Market:
64 try: 63 try:
65 report_query = 'INSERT INTO reports("date", "market_config_id", "debug") VALUES (%s, %s, %s) RETURNING id;' 64 report_query = 'INSERT INTO reports("date", "market_config_id", "debug") VALUES (%s, %s, %s) RETURNING id;'
66 line_query = 'INSERT INTO report_lines("date", "report_id", "type", "payload") VALUES (%s, %s, %s, %s);' 65 line_query = 'INSERT INTO report_lines("date", "report_id", "type", "payload") VALUES (%s, %s, %s, %s);'
67 connection = psycopg2.connect(**self.pg_config) 66 cursor = dbs.psql.cursor()
68 cursor = connection.cursor()
69 cursor.execute(report_query, (date, self.market_id, self.debug)) 67 cursor.execute(report_query, (date, self.market_id, self.debug))
70 report_id = cursor.fetchone()[0] 68 report_id = cursor.fetchone()[0]
71 for date, type_, payload in self.report.to_json_array(): 69 for date, type_, payload in self.report.to_json_array():
72 cursor.execute(line_query, (date, report_id, type_, payload)) 70 cursor.execute(line_query, (date, report_id, type_, payload))
73 71
74 connection.commit() 72 dbs.psql.commit()
75 cursor.close() 73 cursor.close()
76 connection.close()
77 except Exception as e: 74 except Exception as e:
78 print("impossible to store report to database: {}; {}".format(e.__class__.__name__, e)) 75 print("impossible to store report to database: {}; {}".format(e.__class__.__name__, e))
79 76
80 def store_redis_report(self, date): 77 def store_redis_report(self, date):
81 try: 78 try:
82 conn = redis.Redis(**self.redis_config)
83 for type_, log in self.report.to_json_redis(): 79 for type_, log in self.report.to_json_redis():
84 key = "/cryptoportfolio/{}/{}/{}".format(self.market_id, date.isoformat(), type_) 80 key = "/cryptoportfolio/{}/{}/{}".format(self.market_id, date.isoformat(), type_)
85 conn.set(key, log, ex=31*24*60*60) 81 dbs.redis.set(key, log, ex=31*24*60*60)
86 key = "/cryptoportfolio/{}/latest/{}".format(self.market_id, type_) 82 key = "/cryptoportfolio/{}/latest/{}".format(self.market_id, type_)
87 conn.set(key, log) 83 dbs.redis.set(key, log)
88 key = "/cryptoportfolio/{}/latest/date".format(self.market_id) 84 key = "/cryptoportfolio/{}/latest/date".format(self.market_id)
89 conn.set(key, date.isoformat()) 85 dbs.redis.set(key, date.isoformat())
90 except Exception as e: 86 except Exception as e:
91 print("impossible to store report to redis: {}; {}".format(e.__class__.__name__, e)) 87 print("impossible to store report to redis: {}; {}".format(e.__class__.__name__, e))
92 88
@@ -259,6 +255,7 @@ class Processor:
259 "name": "print_balances", 255 "name": "print_balances",
260 "number": 1, 256 "number": 1,
261 "fetch_balances": ["begin"], 257 "fetch_balances": ["begin"],
258 "fetch_balances_args": { "add_portfolio": True },
262 "print_tickers": { "base_currency": "BTC" }, 259 "print_tickers": { "base_currency": "BTC" },
263 } 260 }
264 ], 261 ],
@@ -390,15 +387,19 @@ class Processor:
390 def process_step(self, scenario_name, step, kwargs): 387 def process_step(self, scenario_name, step, kwargs):
391 process_name = "process_{}__{}_{}".format(scenario_name, step["number"], step["name"]) 388 process_name = "process_{}__{}_{}".format(scenario_name, step["number"], step["name"])
392 self.market.report.log_stage("{}_begin".format(process_name)) 389 self.market.report.log_stage("{}_begin".format(process_name))
390
391 fetch_args = step.get("fetch_balances_args", {})
393 if "begin" in step.get("fetch_balances", []): 392 if "begin" in step.get("fetch_balances", []):
394 self.market.balances.fetch_balances(tag="{}_begin".format(process_name), log_tickers=True) 393 self.market.balances.fetch_balances(tag="{}_begin".format(process_name),
394 log_tickers=True, **fetch_args)
395 395
396 for action in self.ordered_actions: 396 for action in self.ordered_actions:
397 if action in step: 397 if action in step:
398 self.run_action(action, step[action], kwargs) 398 self.run_action(action, step[action], kwargs)
399 399
400 if "end" in step.get("fetch_balances", []): 400 if "end" in step.get("fetch_balances", []):
401 self.market.balances.fetch_balances(tag="{}_end".format(process_name), log_tickers=True) 401 self.market.balances.fetch_balances(tag="{}_end".format(process_name),
402 log_tickers=True, **fetch_args)
402 self.market.report.log_stage("{}_end".format(process_name)) 403 self.market.report.log_stage("{}_end".format(process_name))
403 404
404 def method_arguments(self, action): 405 def method_arguments(self, action):
diff --git a/store.py b/store.py
index cd0bf7b..76cfec8 100644
--- a/store.py
+++ b/store.py
@@ -7,6 +7,7 @@ import datetime
7import inspect 7import inspect
8from json import JSONDecodeError 8from json import JSONDecodeError
9from simplejson.errors import JSONDecodeError as SimpleJSONDecodeError 9from simplejson.errors import JSONDecodeError as SimpleJSONDecodeError
10import dbs
10 11
11__all__ = ["Portfolio", "BalanceStore", "ReportStore", "TradeStore"] 12__all__ = ["Portfolio", "BalanceStore", "ReportStore", "TradeStore"]
12 13
@@ -302,13 +303,16 @@ class BalanceStore:
302 compute_value, type) 303 compute_value, type)
303 return amounts 304 return amounts
304 305
305 def fetch_balances(self, tag=None, log_tickers=False, 306 def fetch_balances(self, tag=None, add_portfolio=False, log_tickers=False,
306 ticker_currency="BTC", ticker_compute_value="average", ticker_type="total"): 307 ticker_currency="BTC", ticker_compute_value="average", ticker_type="total"):
307 all_balances = self.market.ccxt.fetch_all_balances() 308 all_balances = self.market.ccxt.fetch_all_balances()
308 for currency, balance in all_balances.items(): 309 for currency, balance in all_balances.items():
309 if balance["exchange_total"] != 0 or balance["margin_total"] != 0 or \ 310 if balance["exchange_total"] != 0 or balance["margin_total"] != 0 or \
310 currency in self.all: 311 currency in self.all:
311 self.all[currency] = portfolio.Balance(currency, balance) 312 self.all[currency] = portfolio.Balance(currency, balance)
313 if add_portfolio:
314 for currency in Portfolio.repartition(from_cache=True):
315 self.all.setdefault(currency, portfolio.Balance(currency, {}))
312 if log_tickers: 316 if log_tickers:
313 tickers = self.in_currency(ticker_currency, compute_value=ticker_compute_value, type=ticker_type) 317 tickers = self.in_currency(ticker_currency, compute_value=ticker_compute_value, type=ticker_type)
314 self.market.report.log_balances(tag=tag, 318 self.market.report.log_balances(tag=tag,
@@ -508,7 +512,9 @@ class Portfolio:
508 cls.get_cryptoportfolio(refetch=True) 512 cls.get_cryptoportfolio(refetch=True)
509 513
510 @classmethod 514 @classmethod
511 def repartition(cls, liquidity="medium"): 515 def repartition(cls, liquidity="medium", from_cache=False):
516 if from_cache:
517 cls.retrieve_cryptoportfolio()
512 cls.get_cryptoportfolio() 518 cls.get_cryptoportfolio()
513 liquidities = cls.liquidities.get(liquidity) 519 liquidities = cls.liquidities.get(liquidity)
514 return liquidities[cls.last_date.get()] 520 return liquidities[cls.last_date.get()]
@@ -530,12 +536,39 @@ class Portfolio:
530 try: 536 try:
531 cls.data.set(r.json(parse_int=D, parse_float=D)) 537 cls.data.set(r.json(parse_int=D, parse_float=D))
532 cls.parse_cryptoportfolio() 538 cls.parse_cryptoportfolio()
539 cls.store_cryptoportfolio()
533 except (JSONDecodeError, SimpleJSONDecodeError): 540 except (JSONDecodeError, SimpleJSONDecodeError):
534 cls.data.set(None) 541 cls.data.set(None)
535 cls.last_date.set(None) 542 cls.last_date.set(None)
536 cls.liquidities.set({}) 543 cls.liquidities.set({})
537 544
538 @classmethod 545 @classmethod
546 def retrieve_cryptoportfolio(cls):
547 if dbs.redis_connected():
548 repartition = dbs.redis.get("/cryptoportfolio/repartition/latest")
549 date = dbs.redis.get("/cryptoportfolio/repartition/date")
550 if date is not None and repartition is not None:
551 date = datetime.datetime.strptime(date.decode(), "%Y-%m-%d")
552 repartition = json.loads(repartition, parse_int=D, parse_float=D)
553 repartition = { k: { date: v } for k, v in repartition.items() }
554
555 cls.data.set("")
556 cls.last_date.set(date)
557 cls.liquidities.set(repartition)
558
559 @classmethod
560 def store_cryptoportfolio(cls):
561 if dbs.redis_connected():
562 hash_ = {}
563 for liquidity, repartitions in cls.liquidities.items():
564 hash_[liquidity] = repartitions[cls.last_date.get()]
565 dump = json.dumps(hash_)
566 key = "/cryptoportfolio/repartition/latest"
567 dbs.redis.set(key, dump)
568 key = "/cryptoportfolio/repartition/date"
569 dbs.redis.set(key, cls.last_date.date().isoformat())
570
571 @classmethod
539 def parse_cryptoportfolio(cls): 572 def parse_cryptoportfolio(cls):
540 def filter_weights(weight_hash): 573 def filter_weights(weight_hash):
541 if weight_hash[1][0] == 0: 574 if weight_hash[1][0] == 0:
diff --git a/test.py b/test.py
index d7743b2..ed89434 100644
--- a/test.py
+++ b/test.py
@@ -9,6 +9,7 @@ if "unit" in limits:
9 from tests.test_market import * 9 from tests.test_market import *
10 from tests.test_store import * 10 from tests.test_store import *
11 from tests.test_portfolio import * 11 from tests.test_portfolio import *
12 from tests.test_dbs import *
12 13
13if "acceptance" in limits: 14if "acceptance" in limits:
14 from tests.test_acceptance import * 15 from tests.test_acceptance import *
diff --git a/tests/helper.py b/tests/helper.py
index b85bf3a..935e060 100644
--- a/tests/helper.py
+++ b/tests/helper.py
@@ -4,7 +4,7 @@ from decimal import Decimal as D
4from unittest import mock 4from unittest import mock
5import requests_mock 5import requests_mock
6from io import StringIO 6from io import StringIO
7import portfolio, market, main, store 7import portfolio, market, main, store, dbs
8 8
9__all__ = ["limits", "unittest", "WebMockTestCase", "mock", "D", 9__all__ = ["limits", "unittest", "WebMockTestCase", "mock", "D",
10 "StringIO"] 10 "StringIO"]
@@ -48,6 +48,8 @@ class WebMockTestCase(unittest.TestCase):
48 callback=None), 48 callback=None),
49 mock.patch.multiple(portfolio.Computation, 49 mock.patch.multiple(portfolio.Computation,
50 computations=portfolio.Computation.computations), 50 computations=portfolio.Computation.computations),
51 mock.patch.multiple(dbs,
52 redis=None, psql=None)
51 ] 53 ]
52 for patcher in self.patchers: 54 for patcher in self.patchers:
53 patcher.start() 55 patcher.start()
diff --git a/tests/test_dbs.py b/tests/test_dbs.py
new file mode 100644
index 0000000..157c423
--- /dev/null
+++ b/tests/test_dbs.py
@@ -0,0 +1,108 @@
1from .helper import *
2import dbs, main
3
4@unittest.skipUnless("unit" in limits, "Unit skipped")
5class DbsTest(WebMockTestCase):
6 @mock.patch.object(dbs, "psycopg2")
7 def test_connect_psql(self, psycopg2):
8 args = main.configargparse.Namespace(**{
9 "db_host": "host",
10 "db_port": "port",
11 "db_user": "user",
12 "db_password": "password",
13 "db_database": "database",
14 })
15 psycopg2.connect.return_value = "pg_connection"
16 dbs.connect_psql(args)
17
18 psycopg2.connect.assert_called_once_with(host="host",
19 port="port", user="user", password="password",
20 database="database")
21 self.assertEqual("pg_connection", dbs.psql)
22 with self.assertRaises(AttributeError):
23 args.db_password
24
25 psycopg2.connect.reset_mock()
26 args = main.configargparse.Namespace(**{
27 "db_host": "host",
28 "db_port": "port",
29 "db_user": "user",
30 "db_password": "password",
31 "db_database": "database",
32 })
33 dbs.connect_psql(args)
34 psycopg2.connect.assert_not_called()
35
36 @mock.patch.object(dbs, "_redis")
37 def test_connect_redis(self, redis):
38 with self.subTest(redis_host="tcp"):
39 args = main.configargparse.Namespace(**{
40 "redis_host": "host",
41 "redis_port": "port",
42 "redis_database": "database",
43 })
44 redis.Redis.return_value = "redis_connection"
45 dbs.connect_redis(args)
46
47 redis.Redis.assert_called_once_with(host="host",
48 port="port", db="database")
49 self.assertEqual("redis_connection", dbs.redis)
50 with self.assertRaises(AttributeError):
51 args.redis_database
52
53 redis.Redis.reset_mock()
54 args = main.configargparse.Namespace(**{
55 "redis_host": "host",
56 "redis_port": "port",
57 "redis_database": "database",
58 })
59 dbs.connect_redis(args)
60 redis.Redis.assert_not_called()
61
62 dbs.redis = None
63 with self.subTest(redis_host="socket"):
64 args = main.configargparse.Namespace(**{
65 "redis_host": "/run/foo",
66 "redis_port": "port",
67 "redis_database": "database",
68 })
69 redis.Redis.return_value = "redis_socket"
70 dbs.connect_redis(args)
71
72 redis.Redis.assert_called_once_with(unix_socket_path="/run/foo", db="database")
73 self.assertEqual("redis_socket", dbs.redis)
74
75 def test_redis_connected(self):
76 with self.subTest(redis=None):
77 dbs.redis = None
78 self.assertFalse(dbs.redis_connected())
79
80 with self.subTest(redis="mocked_true"):
81 dbs.redis = mock.Mock()
82 dbs.redis.ping.return_value = True
83 self.assertTrue(dbs.redis_connected())
84
85 with self.subTest(redis="mocked_false"):
86 dbs.redis = mock.Mock()
87 dbs.redis.ping.return_value = False
88 self.assertFalse(dbs.redis_connected())
89
90 with self.subTest(redis="mocked_raise"):
91 dbs.redis = mock.Mock()
92 dbs.redis.ping.side_effect = Exception("bouh")
93 self.assertFalse(dbs.redis_connected())
94
95 def test_psql_connected(self):
96 with self.subTest(psql=None):
97 dbs.psql = None
98 self.assertFalse(dbs.psql_connected())
99
100 with self.subTest(psql="connected"):
101 dbs.psql = mock.Mock()
102 dbs.psql.closed = 0
103 self.assertTrue(dbs.psql_connected())
104
105 with self.subTest(psql="not connected"):
106 dbs.psql = mock.Mock()
107 dbs.psql.closed = 3
108 self.assertFalse(dbs.psql_connected())
diff --git a/tests/test_main.py b/tests/test_main.py
index 55b1382..1864c06 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -103,7 +103,6 @@ class MainTest(WebMockTestCase):
103 mock.patch("main.parse_config") as main_parse_config: 103 mock.patch("main.parse_config") as main_parse_config:
104 with self.subTest(debug=False): 104 with self.subTest(debug=False):
105 main_parse_args.return_value = self.market_args() 105 main_parse_args.return_value = self.market_args()
106 main_parse_config.return_value = ["pg_config", "redis_config"]
107 main_fetch_markets.return_value = [(1, {"key": "market_config"}, 3)] 106 main_fetch_markets.return_value = [(1, {"key": "market_config"}, 3)]
108 m = main.get_user_market("config_path.ini", 1) 107 m = main.get_user_market("config_path.ini", 1)
109 108
@@ -114,7 +113,6 @@ class MainTest(WebMockTestCase):
114 main_parse_args.reset_mock() 113 main_parse_args.reset_mock()
115 with self.subTest(debug=True): 114 with self.subTest(debug=True):
116 main_parse_args.return_value = self.market_args(debug=True) 115 main_parse_args.return_value = self.market_args(debug=True)
117 main_parse_config.return_value = ["pg_config", "redis_config"]
118 main_fetch_markets.return_value = [(1, {"key": "market_config"}, 3)] 116 main_fetch_markets.return_value = [(1, {"key": "market_config"}, 3)]
119 m = main.get_user_market("config_path.ini", 1, debug=True) 117 m = main.get_user_market("config_path.ini", 1, debug=True)
120 118
@@ -135,16 +133,16 @@ class MainTest(WebMockTestCase):
135 args_mock.after = "after" 133 args_mock.after = "after"
136 self.assertEqual("", stdout_mock.getvalue()) 134 self.assertEqual("", stdout_mock.getvalue())
137 135
138 main.process("config", 3, 1, args_mock, "pg_config", "redis_config") 136 main.process("config", 3, 1, args_mock)
139 137
140 market_mock.from_config.assert_has_calls([ 138 market_mock.from_config.assert_has_calls([
141 mock.call("config", args_mock, pg_config="pg_config", redis_config="redis_config", market_id=3, user_id=1), 139 mock.call("config", args_mock, market_id=3, user_id=1),
142 mock.call().process("action", before="before", after="after"), 140 mock.call().process("action", before="before", after="after"),
143 ]) 141 ])
144 142
145 with self.subTest(exception=True): 143 with self.subTest(exception=True):
146 market_mock.from_config.side_effect = Exception("boo") 144 market_mock.from_config.side_effect = Exception("boo")
147 main.process(3, "config", 1, args_mock, "pg_config", "redis_config") 145 main.process(3, "config", 1, args_mock)
148 self.assertEqual("Exception: boo\n", stdout_mock.getvalue()) 146 self.assertEqual("Exception: boo\n", stdout_mock.getvalue())
149 147
150 def test_main(self): 148 def test_main(self):
@@ -159,20 +157,18 @@ class MainTest(WebMockTestCase):
159 args_mock.user = "user" 157 args_mock.user = "user"
160 parse_args.return_value = args_mock 158 parse_args.return_value = args_mock
161 159
162 parse_config.return_value = ["pg_config", "redis_config"]
163
164 fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]] 160 fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]]
165 161
166 main.main(["Foo", "Bar"]) 162 main.main(["Foo", "Bar"])
167 163
168 parse_args.assert_called_with(["Foo", "Bar"]) 164 parse_args.assert_called_with(["Foo", "Bar"])
169 parse_config.assert_called_with(args_mock) 165 parse_config.assert_called_with(args_mock)
170 fetch_markets.assert_called_with("pg_config", "user") 166 fetch_markets.assert_called_with("user")
171 167
172 self.assertEqual(2, process.call_count) 168 self.assertEqual(2, process.call_count)
173 process.assert_has_calls([ 169 process.assert_has_calls([
174 mock.call("config1", 3, 1, args_mock, "pg_config", "redis_config"), 170 mock.call("config1", 3, 1, args_mock),
175 mock.call("config2", 1, 2, args_mock, "pg_config", "redis_config"), 171 mock.call("config2", 1, 2, args_mock),
176 ]) 172 ])
177 with self.subTest(parallel=True): 173 with self.subTest(parallel=True):
178 with mock.patch("main.parse_args") as parse_args,\ 174 with mock.patch("main.parse_args") as parse_args,\
@@ -187,24 +183,22 @@ class MainTest(WebMockTestCase):
187 args_mock.user = "user" 183 args_mock.user = "user"
188 parse_args.return_value = args_mock 184 parse_args.return_value = args_mock
189 185
190 parse_config.return_value = ["pg_config", "redis_config"]
191
192 fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]] 186 fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]]
193 187
194 main.main(["Foo", "Bar"]) 188 main.main(["Foo", "Bar"])
195 189
196 parse_args.assert_called_with(["Foo", "Bar"]) 190 parse_args.assert_called_with(["Foo", "Bar"])
197 parse_config.assert_called_with(args_mock) 191 parse_config.assert_called_with(args_mock)
198 fetch_markets.assert_called_with("pg_config", "user") 192 fetch_markets.assert_called_with("user")
199 193
200 stop.assert_called_once_with() 194 stop.assert_called_once_with()
201 start.assert_called_once_with() 195 start.assert_called_once_with()
202 self.assertEqual(2, process.call_count) 196 self.assertEqual(2, process.call_count)
203 process.assert_has_calls([ 197 process.assert_has_calls([
204 mock.call.__bool__(), 198 mock.call.__bool__(),
205 mock.call("config1", 3, 1, args_mock, "pg_config", "redis_config"), 199 mock.call("config1", 3, 1, args_mock),
206 mock.call.__bool__(), 200 mock.call.__bool__(),
207 mock.call("config2", 1, 2, args_mock, "pg_config", "redis_config"), 201 mock.call("config2", 1, 2, args_mock),
208 ]) 202 ])
209 with self.subTest(quiet=True): 203 with self.subTest(quiet=True):
210 with mock.patch("main.parse_args") as parse_args,\ 204 with mock.patch("main.parse_args") as parse_args,\
@@ -219,8 +213,6 @@ class MainTest(WebMockTestCase):
219 args_mock.user = "user" 213 args_mock.user = "user"
220 parse_args.return_value = args_mock 214 parse_args.return_value = args_mock
221 215
222 parse_config.return_value = ["pg_config", "redis_config"]
223
224 fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]] 216 fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]]
225 217
226 main.main(["Foo", "Bar"]) 218 main.main(["Foo", "Bar"])
@@ -240,8 +232,6 @@ class MainTest(WebMockTestCase):
240 args_mock.user = "user" 232 args_mock.user = "user"
241 parse_args.return_value = args_mock 233 parse_args.return_value = args_mock
242 234
243 parse_config.return_value = ["pg_config", "redis_config"]
244
245 fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]] 235 fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]]
246 236
247 main.main(["Foo", "Bar"]) 237 main.main(["Foo", "Bar"])
@@ -252,63 +242,57 @@ class MainTest(WebMockTestCase):
252 @mock.patch.object(main.sys, "exit") 242 @mock.patch.object(main.sys, "exit")
253 @mock.patch("main.os") 243 @mock.patch("main.os")
254 def test_parse_config(self, os, exit): 244 def test_parse_config(self, os, exit):
255 with self.subTest(report_path=None): 245 with self.subTest(report_path=None),\
246 mock.patch.object(main.dbs, "connect_psql") as psql,\
247 mock.patch.object(main.dbs, "connect_redis") as redis:
256 args = main.configargparse.Namespace(**{ 248 args = main.configargparse.Namespace(**{
257 "db_host": "host", 249 "db_host": "host",
258 "db_port": "port",
259 "db_user": "user",
260 "db_password": "password",
261 "db_database": "database",
262 "redis_host": "rhost", 250 "redis_host": "rhost",
263 "redis_port": "rport",
264 "redis_database": "rdb",
265 "report_path": None, 251 "report_path": None,
266 }) 252 })
267 253
268 db_config, redis_config = main.parse_config(args) 254 main.parse_config(args)
269 self.assertEqual({ "host": "host", "port": "port", "user": 255 psql.assert_called_once_with(args)
270 "user", "password": "password", "database": "database" 256 redis.assert_called_once_with(args)
271 }, db_config) 257
272 self.assertEqual({ "host": "rhost", "port": "rport", "db": 258 with self.subTest(report_path=None, db=None),\
273 "rdb"}, redis_config) 259 mock.patch.object(main.dbs, "connect_psql") as psql,\
260 mock.patch.object(main.dbs, "connect_redis") as redis:
261 args = main.configargparse.Namespace(**{
262 "db_host": None,
263 "redis_host": "rhost",
264 "report_path": None,
265 })
274 266
275 with self.assertRaises(AttributeError): 267 main.parse_config(args)
276 args.db_password 268 psql.assert_not_called()
277 with self.assertRaises(AttributeError): 269 redis.assert_called_once_with(args)
278 args.redis_host
279 270
280 with self.subTest(redis_host="socket"): 271 with self.subTest(report_path=None, redis=None),\
272 mock.patch.object(main.dbs, "connect_psql") as psql,\
273 mock.patch.object(main.dbs, "connect_redis") as redis:
281 args = main.configargparse.Namespace(**{ 274 args = main.configargparse.Namespace(**{
282 "db_host": "host", 275 "db_host": "host",
283 "db_port": "port", 276 "redis_host": None,
284 "db_user": "user",
285 "db_password": "password",
286 "db_database": "database",
287 "redis_host": "/run/foo",
288 "redis_port": "rport",
289 "redis_database": "rdb",
290 "report_path": None, 277 "report_path": None,
291 }) 278 })
292 279
293 db_config, redis_config = main.parse_config(args) 280 main.parse_config(args)
294 self.assertEqual({ "unix_socket_path": "/run/foo", "db": "rdb"}, redis_config) 281 redis.assert_not_called()
282 psql.assert_called_once_with(args)
295 283
296 with self.subTest(report_path="present"): 284 with self.subTest(report_path="present"),\
285 mock.patch.object(main.dbs, "connect_psql") as psql,\
286 mock.patch.object(main.dbs, "connect_redis") as redis:
297 args = main.configargparse.Namespace(**{ 287 args = main.configargparse.Namespace(**{
298 "db_host": "host", 288 "db_host": "host",
299 "db_port": "port",
300 "db_user": "user",
301 "db_password": "password",
302 "db_database": "database",
303 "redis_host": "rhost", 289 "redis_host": "rhost",
304 "redis_port": "rport",
305 "redis_database": "rdb",
306 "report_path": "report_path", 290 "report_path": "report_path",
307 }) 291 })
308 292
309 os.path.exists.return_value = False 293 os.path.exists.return_value = False
310 294
311 result = main.parse_config(args) 295 main.parse_config(args)
312 296
313 os.path.exists.assert_called_once_with("report_path") 297 os.path.exists.assert_called_once_with("report_path")
314 os.makedirs.assert_called_once_with("report_path") 298 os.makedirs.assert_called_once_with("report_path")
@@ -331,29 +315,24 @@ class MainTest(WebMockTestCase):
331 mock.patch('sys.stderr', new_callable=StringIO) as stdout_mock: 315 mock.patch('sys.stderr', new_callable=StringIO) as stdout_mock:
332 args = main.parse_args(["--config", "foo.bar"]) 316 args = main.parse_args(["--config", "foo.bar"])
333 317
334 @mock.patch.object(main, "psycopg2") 318 @mock.patch.object(main.dbs, "psql")
335 def test_fetch_markets(self, psycopg2): 319 def test_fetch_markets(self, psql):
336 connect_mock = mock.Mock()
337 cursor_mock = mock.MagicMock() 320 cursor_mock = mock.MagicMock()
338 cursor_mock.__iter__.return_value = ["row_1", "row_2"] 321 cursor_mock.__iter__.return_value = ["row_1", "row_2"]
339 322
340 connect_mock.cursor.return_value = cursor_mock 323 psql.cursor.return_value = cursor_mock
341 psycopg2.connect.return_value = connect_mock
342 324
343 with self.subTest(user=None): 325 with self.subTest(user=None):
344 rows = list(main.fetch_markets({"foo": "bar"}, None)) 326 rows = list(main.fetch_markets(None))
345 327
346 psycopg2.connect.assert_called_once_with(foo="bar")
347 cursor_mock.execute.assert_called_once_with("SELECT id,config,user_id FROM market_configs") 328 cursor_mock.execute.assert_called_once_with("SELECT id,config,user_id FROM market_configs")
348 329
349 self.assertEqual(["row_1", "row_2"], rows) 330 self.assertEqual(["row_1", "row_2"], rows)
350 331
351 psycopg2.connect.reset_mock()
352 cursor_mock.execute.reset_mock() 332 cursor_mock.execute.reset_mock()
353 with self.subTest(user=1): 333 with self.subTest(user=1):
354 rows = list(main.fetch_markets({"foo": "bar"}, 1)) 334 rows = list(main.fetch_markets(1))
355 335
356 psycopg2.connect.assert_called_once_with(foo="bar")
357 cursor_mock.execute.assert_called_once_with("SELECT id,config,user_id FROM market_configs WHERE user_id = %s", 1) 336 cursor_mock.execute.assert_called_once_with("SELECT id,config,user_id FROM market_configs WHERE user_id = %s", 1)
358 337
359 self.assertEqual(["row_1", "row_2"], rows) 338 self.assertEqual(["row_1", "row_2"], rows)
diff --git a/tests/test_market.py b/tests/test_market.py
index 6a3322c..46fad53 100644
--- a/tests/test_market.py
+++ b/tests/test_market.py
@@ -1,5 +1,5 @@
1from .helper import * 1from .helper import *
2import market, store, portfolio 2import market, store, portfolio, dbs
3import datetime 3import datetime
4 4
5@unittest.skipUnless("unit" in limits, "Unit skipped") 5@unittest.skipUnless("unit" in limits, "Unit skipped")
@@ -595,13 +595,11 @@ class MarketTest(WebMockTestCase):
595 595
596 self.assertRegex(stdout_mock.getvalue(), "impossible to store report file: FileNotFoundError;") 596 self.assertRegex(stdout_mock.getvalue(), "impossible to store report file: FileNotFoundError;")
597 597
598 @mock.patch.object(market, "psycopg2") 598 @mock.patch.object(dbs, "psql")
599 def test_store_database_report(self, psycopg2): 599 def test_store_database_report(self, psql):
600 connect_mock = mock.Mock()
601 cursor_mock = mock.MagicMock() 600 cursor_mock = mock.MagicMock()
602 601
603 connect_mock.cursor.return_value = cursor_mock 602 psql.cursor.return_value = cursor_mock
604 psycopg2.connect.return_value = connect_mock
605 m = market.Market(self.ccxt, self.market_args(), 603 m = market.Market(self.ccxt, self.market_args(),
606 pg_config={"config": "pg_config"}, user_id=1) 604 pg_config={"config": "pg_config"}, user_id=1)
607 cursor_mock.fetchone.return_value = [42] 605 cursor_mock.fetchone.return_value = [42]
@@ -613,7 +611,7 @@ class MarketTest(WebMockTestCase):
613 ("date2", "type2", "payload2"), 611 ("date2", "type2", "payload2"),
614 ] 612 ]
615 m.store_database_report(datetime.datetime(2018, 3, 24)) 613 m.store_database_report(datetime.datetime(2018, 3, 24))
616 connect_mock.assert_has_calls([ 614 psql.assert_has_calls([
617 mock.call.cursor(), 615 mock.call.cursor(),
618 mock.call.cursor().execute('INSERT INTO reports("date", "market_config_id", "debug") VALUES (%s, %s, %s) RETURNING id;', (datetime.datetime(2018, 3, 24), None, False)), 616 mock.call.cursor().execute('INSERT INTO reports("date", "market_config_id", "debug") VALUES (%s, %s, %s) RETURNING id;', (datetime.datetime(2018, 3, 24), None, False)),
619 mock.call.cursor().fetchone(), 617 mock.call.cursor().fetchone(),
@@ -621,21 +619,16 @@ class MarketTest(WebMockTestCase):
621 mock.call.cursor().execute('INSERT INTO report_lines("date", "report_id", "type", "payload") VALUES (%s, %s, %s, %s);', ('date2', 42, 'type2', 'payload2')), 619 mock.call.cursor().execute('INSERT INTO report_lines("date", "report_id", "type", "payload") VALUES (%s, %s, %s, %s);', ('date2', 42, 'type2', 'payload2')),
622 mock.call.commit(), 620 mock.call.commit(),
623 mock.call.cursor().close(), 621 mock.call.cursor().close(),
624 mock.call.close()
625 ]) 622 ])
626 623
627 connect_mock.reset_mock()
628 with self.subTest(error=True),\ 624 with self.subTest(error=True),\
629 mock.patch('sys.stdout', new_callable=StringIO) as stdout_mock: 625 mock.patch('sys.stdout', new_callable=StringIO) as stdout_mock:
630 psycopg2.connect.side_effect = Exception("Bouh") 626 psql.cursor.side_effect = Exception("Bouh")
631 m.store_database_report(datetime.datetime(2018, 3, 24)) 627 m.store_database_report(datetime.datetime(2018, 3, 24))
632 self.assertEqual(stdout_mock.getvalue(), "impossible to store report to database: Exception; Bouh\n") 628 self.assertEqual(stdout_mock.getvalue(), "impossible to store report to database: Exception; Bouh\n")
633 629
634 @mock.patch.object(market, "redis") 630 @mock.patch.object(dbs, "redis")
635 def test_store_redis_report(self, redis): 631 def test_store_redis_report(self, redis):
636 connect_mock = mock.Mock()
637 redis.Redis.return_value = connect_mock
638
639 m = market.Market(self.ccxt, self.market_args(), 632 m = market.Market(self.ccxt, self.market_args(),
640 redis_config={"config": "redis_config"}, market_id=1) 633 redis_config={"config": "redis_config"}, market_id=1)
641 634
@@ -646,7 +639,7 @@ class MarketTest(WebMockTestCase):
646 ("type2", "payload2"), 639 ("type2", "payload2"),
647 ] 640 ]
648 m.store_redis_report(datetime.datetime(2018, 3, 24)) 641 m.store_redis_report(datetime.datetime(2018, 3, 24))
649 connect_mock.assert_has_calls([ 642 redis.assert_has_calls([
650 mock.call.set("/cryptoportfolio/1/2018-03-24T00:00:00/type1", "payload1", ex=31*24*60*60), 643 mock.call.set("/cryptoportfolio/1/2018-03-24T00:00:00/type1", "payload1", ex=31*24*60*60),
651 mock.call.set("/cryptoportfolio/1/latest/type1", "payload1"), 644 mock.call.set("/cryptoportfolio/1/latest/type1", "payload1"),
652 mock.call.set("/cryptoportfolio/1/2018-03-24T00:00:00/type2", "payload2", ex=31*24*60*60), 645 mock.call.set("/cryptoportfolio/1/2018-03-24T00:00:00/type2", "payload2", ex=31*24*60*60),
@@ -654,20 +647,24 @@ class MarketTest(WebMockTestCase):
654 mock.call.set("/cryptoportfolio/1/latest/date", "2018-03-24T00:00:00"), 647 mock.call.set("/cryptoportfolio/1/latest/date", "2018-03-24T00:00:00"),
655 ]) 648 ])
656 649
657 connect_mock.reset_mock() 650 redis.reset_mock()
658 with self.subTest(error=True),\ 651 with self.subTest(error=True),\
659 mock.patch('sys.stdout', new_callable=StringIO) as stdout_mock: 652 mock.patch('sys.stdout', new_callable=StringIO) as stdout_mock:
660 redis.Redis.side_effect = Exception("Bouh") 653 redis.set.side_effect = Exception("Bouh")
661 m.store_redis_report(datetime.datetime(2018, 3, 24)) 654 m.store_redis_report(datetime.datetime(2018, 3, 24))
662 self.assertEqual(stdout_mock.getvalue(), "impossible to store report to redis: Exception; Bouh\n") 655 self.assertEqual(stdout_mock.getvalue(), "impossible to store report to redis: Exception; Bouh\n")
663 656
664 def test_store_report(self): 657 def test_store_report(self):
665 m = market.Market(self.ccxt, self.market_args(report_db=False), user_id=1) 658 m = market.Market(self.ccxt, self.market_args(report_db=False), user_id=1)
666 with self.subTest(file=None, pg_config=None),\ 659 with self.subTest(file=None, pg_connected=None),\
660 mock.patch.object(dbs, "psql_connected") as psql,\
661 mock.patch.object(dbs, "redis_connected") as redis,\
667 mock.patch.object(m, "report") as report,\ 662 mock.patch.object(m, "report") as report,\
668 mock.patch.object(m, "store_database_report") as db_report,\ 663 mock.patch.object(m, "store_database_report") as db_report,\
669 mock.patch.object(m, "store_redis_report") as redis_report,\ 664 mock.patch.object(m, "store_redis_report") as redis_report,\
670 mock.patch.object(m, "store_file_report") as file_report: 665 mock.patch.object(m, "store_file_report") as file_report:
666 psql.return_value = False
667 redis.return_value = False
671 m.store_report() 668 m.store_report()
672 report.merge.assert_called_with(store.Portfolio.report) 669 report.merge.assert_called_with(store.Portfolio.report)
673 670
@@ -677,13 +674,16 @@ class MarketTest(WebMockTestCase):
677 674
678 report.reset_mock() 675 report.reset_mock()
679 m = market.Market(self.ccxt, self.market_args(report_db=False, report_path="present"), user_id=1) 676 m = market.Market(self.ccxt, self.market_args(report_db=False, report_path="present"), user_id=1)
680 with self.subTest(file="present", pg_config=None),\ 677 with self.subTest(file="present", pg_connected=None),\
678 mock.patch.object(dbs, "psql_connected") as psql,\
679 mock.patch.object(dbs, "redis_connected") as redis,\
681 mock.patch.object(m, "report") as report,\ 680 mock.patch.object(m, "report") as report,\
682 mock.patch.object(m, "store_file_report") as file_report,\ 681 mock.patch.object(m, "store_file_report") as file_report,\
683 mock.patch.object(m, "store_redis_report") as redis_report,\ 682 mock.patch.object(m, "store_redis_report") as redis_report,\
684 mock.patch.object(m, "store_database_report") as db_report,\ 683 mock.patch.object(m, "store_database_report") as db_report,\
685 mock.patch.object(market.datetime, "datetime") as time_mock: 684 mock.patch.object(market.datetime, "datetime") as time_mock:
686 685 psql.return_value = False
686 redis.return_value = False
687 time_mock.now.return_value = datetime.datetime(2018, 2, 25) 687 time_mock.now.return_value = datetime.datetime(2018, 2, 25)
688 688
689 m.store_report() 689 m.store_report()
@@ -695,13 +695,16 @@ class MarketTest(WebMockTestCase):
695 695
696 report.reset_mock() 696 report.reset_mock()
697 m = market.Market(self.ccxt, self.market_args(report_db=True, report_path="present"), user_id=1) 697 m = market.Market(self.ccxt, self.market_args(report_db=True, report_path="present"), user_id=1)
698 with self.subTest(file="present", pg_config=None, report_db=True),\ 698 with self.subTest(file="present", pg_connected=None, report_db=True),\
699 mock.patch.object(dbs, "psql_connected") as psql,\
700 mock.patch.object(dbs, "redis_connected") as redis,\
699 mock.patch.object(m, "report") as report,\ 701 mock.patch.object(m, "report") as report,\
700 mock.patch.object(m, "store_file_report") as file_report,\ 702 mock.patch.object(m, "store_file_report") as file_report,\
701 mock.patch.object(m, "store_redis_report") as redis_report,\ 703 mock.patch.object(m, "store_redis_report") as redis_report,\
702 mock.patch.object(m, "store_database_report") as db_report,\ 704 mock.patch.object(m, "store_database_report") as db_report,\
703 mock.patch.object(market.datetime, "datetime") as time_mock: 705 mock.patch.object(market.datetime, "datetime") as time_mock:
704 706 psql.return_value = False
707 redis.return_value = False
705 time_mock.now.return_value = datetime.datetime(2018, 2, 25) 708 time_mock.now.return_value = datetime.datetime(2018, 2, 25)
706 709
707 m.store_report() 710 m.store_report()
@@ -712,14 +715,17 @@ class MarketTest(WebMockTestCase):
712 redis_report.assert_not_called() 715 redis_report.assert_not_called()
713 716
714 report.reset_mock() 717 report.reset_mock()
715 m = market.Market(self.ccxt, self.market_args(report_db=True), pg_config="present", user_id=1) 718 m = market.Market(self.ccxt, self.market_args(report_db=True), user_id=1)
716 with self.subTest(file=None, pg_config="present"),\ 719 with self.subTest(file=None, pg_connected=True),\
720 mock.patch.object(dbs, "psql_connected") as psql,\
721 mock.patch.object(dbs, "redis_connected") as redis,\
717 mock.patch.object(m, "report") as report,\ 722 mock.patch.object(m, "report") as report,\
718 mock.patch.object(m, "store_file_report") as file_report,\ 723 mock.patch.object(m, "store_file_report") as file_report,\
719 mock.patch.object(m, "store_redis_report") as redis_report,\ 724 mock.patch.object(m, "store_redis_report") as redis_report,\
720 mock.patch.object(m, "store_database_report") as db_report,\ 725 mock.patch.object(m, "store_database_report") as db_report,\
721 mock.patch.object(market.datetime, "datetime") as time_mock: 726 mock.patch.object(market.datetime, "datetime") as time_mock:
722 727 psql.return_value = True
728 redis.return_value = False
723 time_mock.now.return_value = datetime.datetime(2018, 2, 25) 729 time_mock.now.return_value = datetime.datetime(2018, 2, 25)
724 730
725 m.store_report() 731 m.store_report()
@@ -731,14 +737,17 @@ class MarketTest(WebMockTestCase):
731 737
732 report.reset_mock() 738 report.reset_mock()
733 m = market.Market(self.ccxt, self.market_args(report_db=True, report_path="present"), 739 m = market.Market(self.ccxt, self.market_args(report_db=True, report_path="present"),
734 pg_config="pg_config", user_id=1) 740 user_id=1)
735 with self.subTest(file="present", pg_config="present"),\ 741 with self.subTest(file="present", pg_connected=True),\
742 mock.patch.object(dbs, "psql_connected") as psql,\
743 mock.patch.object(dbs, "redis_connected") as redis,\
736 mock.patch.object(m, "report") as report,\ 744 mock.patch.object(m, "report") as report,\
737 mock.patch.object(m, "store_file_report") as file_report,\ 745 mock.patch.object(m, "store_file_report") as file_report,\
738 mock.patch.object(m, "store_redis_report") as redis_report,\ 746 mock.patch.object(m, "store_redis_report") as redis_report,\
739 mock.patch.object(m, "store_database_report") as db_report,\ 747 mock.patch.object(m, "store_database_report") as db_report,\
740 mock.patch.object(market.datetime, "datetime") as time_mock: 748 mock.patch.object(market.datetime, "datetime") as time_mock:
741 749 psql.return_value = True
750 redis.return_value = False
742 time_mock.now.return_value = datetime.datetime(2018, 2, 25) 751 time_mock.now.return_value = datetime.datetime(2018, 2, 25)
743 752
744 m.store_report() 753 m.store_report()
@@ -750,14 +759,17 @@ class MarketTest(WebMockTestCase):
750 759
751 report.reset_mock() 760 report.reset_mock()
752 m = market.Market(self.ccxt, self.market_args(report_redis=False), 761 m = market.Market(self.ccxt, self.market_args(report_redis=False),
753 redis_config="redis_config", user_id=1) 762 user_id=1)
754 with self.subTest(redis_config="present", report_redis=False),\ 763 with self.subTest(redis_connected=True, report_redis=False),\
764 mock.patch.object(dbs, "psql_connected") as psql,\
765 mock.patch.object(dbs, "redis_connected") as redis,\
755 mock.patch.object(m, "report") as report,\ 766 mock.patch.object(m, "report") as report,\
756 mock.patch.object(m, "store_file_report") as file_report,\ 767 mock.patch.object(m, "store_file_report") as file_report,\
757 mock.patch.object(m, "store_redis_report") as redis_report,\ 768 mock.patch.object(m, "store_redis_report") as redis_report,\
758 mock.patch.object(m, "store_database_report") as db_report,\ 769 mock.patch.object(m, "store_database_report") as db_report,\
759 mock.patch.object(market.datetime, "datetime") as time_mock: 770 mock.patch.object(market.datetime, "datetime") as time_mock:
760 771 psql.return_value = False
772 redis.return_value = True
761 time_mock.now.return_value = datetime.datetime(2018, 2, 25) 773 time_mock.now.return_value = datetime.datetime(2018, 2, 25)
762 774
763 m.store_report() 775 m.store_report()
@@ -766,13 +778,16 @@ class MarketTest(WebMockTestCase):
766 report.reset_mock() 778 report.reset_mock()
767 m = market.Market(self.ccxt, self.market_args(report_redis=True), 779 m = market.Market(self.ccxt, self.market_args(report_redis=True),
768 user_id=1) 780 user_id=1)
769 with self.subTest(redis_config="absent", report_redis=True),\ 781 with self.subTest(redis_connected=False, report_redis=True),\
782 mock.patch.object(dbs, "psql_connected") as psql,\
783 mock.patch.object(dbs, "redis_connected") as redis,\
770 mock.patch.object(m, "report") as report,\ 784 mock.patch.object(m, "report") as report,\
771 mock.patch.object(m, "store_file_report") as file_report,\ 785 mock.patch.object(m, "store_file_report") as file_report,\
772 mock.patch.object(m, "store_redis_report") as redis_report,\ 786 mock.patch.object(m, "store_redis_report") as redis_report,\
773 mock.patch.object(m, "store_database_report") as db_report,\ 787 mock.patch.object(m, "store_database_report") as db_report,\
774 mock.patch.object(market.datetime, "datetime") as time_mock: 788 mock.patch.object(market.datetime, "datetime") as time_mock:
775 789 psql.return_value = False
790 redis.return_value = False
776 time_mock.now.return_value = datetime.datetime(2018, 2, 25) 791 time_mock.now.return_value = datetime.datetime(2018, 2, 25)
777 792
778 m.store_report() 793 m.store_report()
@@ -780,14 +795,17 @@ class MarketTest(WebMockTestCase):
780 795
781 report.reset_mock() 796 report.reset_mock()
782 m = market.Market(self.ccxt, self.market_args(report_redis=True), 797 m = market.Market(self.ccxt, self.market_args(report_redis=True),
783 redis_config="redis_config", user_id=1) 798 user_id=1)
784 with self.subTest(redis_config="present", report_redis=True),\ 799 with self.subTest(redis_connected=True, report_redis=True),\
800 mock.patch.object(dbs, "psql_connected") as psql,\
801 mock.patch.object(dbs, "redis_connected") as redis,\
785 mock.patch.object(m, "report") as report,\ 802 mock.patch.object(m, "report") as report,\
786 mock.patch.object(m, "store_file_report") as file_report,\ 803 mock.patch.object(m, "store_file_report") as file_report,\
787 mock.patch.object(m, "store_redis_report") as redis_report,\ 804 mock.patch.object(m, "store_redis_report") as redis_report,\
788 mock.patch.object(m, "store_database_report") as db_report,\ 805 mock.patch.object(m, "store_database_report") as db_report,\
789 mock.patch.object(market.datetime, "datetime") as time_mock: 806 mock.patch.object(market.datetime, "datetime") as time_mock:
790 807 psql.return_value = False
808 redis.return_value = True
791 time_mock.now.return_value = datetime.datetime(2018, 2, 25) 809 time_mock.now.return_value = datetime.datetime(2018, 2, 25)
792 810
793 m.store_report() 811 m.store_report()
@@ -1014,6 +1032,15 @@ class ProcessorTest(WebMockTestCase):
1014 processor.process_step("foo", step, {"foo":"bar"}) 1032 processor.process_step("foo", step, {"foo":"bar"})
1015 self.m.balances.fetch_balances.assert_not_called() 1033 self.m.balances.fetch_balances.assert_not_called()
1016 1034
1035 self.m.reset_mock()
1036 with mock.patch.object(processor, "run_action") as run_action:
1037 step = processor.scenarios["print_balances"][0]
1038
1039 processor.process_step("foo", step, {"foo":"bar"})
1040 self.m.balances.fetch_balances.assert_called_once_with(
1041 add_portfolio=True, log_tickers=True,
1042 tag='process_foo__1_print_balances_begin')
1043
1017 def test_parse_args(self): 1044 def test_parse_args(self):
1018 processor = market.Processor(self.m) 1045 processor = market.Processor(self.m)
1019 1046
diff --git a/tests/test_store.py b/tests/test_store.py
index 12999d3..ee7e063 100644
--- a/tests/test_store.py
+++ b/tests/test_store.py
@@ -391,6 +391,18 @@ class BalanceStoreTest(WebMockTestCase):
391 tag=None, ticker_currency='FOO', tickers='tickers', 391 tag=None, ticker_currency='FOO', tickers='tickers',
392 type='type') 392 type='type')
393 393
394 balance_store = market.BalanceStore(self.m)
395
396 with self.subTest(add_portfolio=True),\
397 mock.patch.object(market.Portfolio, "repartition") as repartition:
398 repartition.return_value = {
399 "DOGE": D("0.5"),
400 "USDT": D("0.5"),
401 }
402 balance_store.fetch_balances(add_portfolio=True)
403 self.assertListEqual(["USDT", "XVG", "XMR", "DOGE"], list(balance_store.currencies()))
404
405
394 @mock.patch.object(market.Portfolio, "repartition") 406 @mock.patch.object(market.Portfolio, "repartition")
395 def test_dispatch_assets(self, repartition): 407 def test_dispatch_assets(self, repartition):
396 self.m.ccxt.fetch_all_balances.return_value = self.fetch_balance 408 self.m.ccxt.fetch_all_balances.return_value = self.fetch_balance
@@ -1101,7 +1113,8 @@ class PortfolioTest(WebMockTestCase):
1101 self.wm.get(market.Portfolio.URL, text=self.json_response) 1113 self.wm.get(market.Portfolio.URL, text=self.json_response)
1102 1114
1103 @mock.patch.object(market.Portfolio, "parse_cryptoportfolio") 1115 @mock.patch.object(market.Portfolio, "parse_cryptoportfolio")
1104 def test_get_cryptoportfolio(self, parse_cryptoportfolio): 1116 @mock.patch.object(market.Portfolio, "store_cryptoportfolio")
1117 def test_get_cryptoportfolio(self, store_cryptoportfolio, parse_cryptoportfolio):
1105 with self.subTest(parallel=False): 1118 with self.subTest(parallel=False):
1106 self.wm.get(market.Portfolio.URL, [ 1119 self.wm.get(market.Portfolio.URL, [
1107 {"text":'{ "foo": "bar" }', "status_code": 200}, 1120 {"text":'{ "foo": "bar" }', "status_code": 200},
@@ -1116,23 +1129,28 @@ class PortfolioTest(WebMockTestCase):
1116 market.Portfolio.report.log_error.assert_not_called() 1129 market.Portfolio.report.log_error.assert_not_called()
1117 market.Portfolio.report.log_http_request.assert_called_once() 1130 market.Portfolio.report.log_http_request.assert_called_once()
1118 parse_cryptoportfolio.assert_called_once_with() 1131 parse_cryptoportfolio.assert_called_once_with()
1132 store_cryptoportfolio.assert_called_once_with()
1119 market.Portfolio.report.log_http_request.reset_mock() 1133 market.Portfolio.report.log_http_request.reset_mock()
1120 parse_cryptoportfolio.reset_mock() 1134 parse_cryptoportfolio.reset_mock()
1135 store_cryptoportfolio.reset_mock()
1121 market.Portfolio.data = store.LockedVar(None) 1136 market.Portfolio.data = store.LockedVar(None)
1122 1137
1123 market.Portfolio.get_cryptoportfolio() 1138 market.Portfolio.get_cryptoportfolio()
1124 self.assertIsNone(market.Portfolio.data.get()) 1139 self.assertIsNone(market.Portfolio.data.get())
1125 self.assertEqual(2, self.wm.call_count) 1140 self.assertEqual(2, self.wm.call_count)
1126 parse_cryptoportfolio.assert_not_called() 1141 parse_cryptoportfolio.assert_not_called()
1142 store_cryptoportfolio.assert_not_called()
1127 market.Portfolio.report.log_error.assert_not_called() 1143 market.Portfolio.report.log_error.assert_not_called()
1128 market.Portfolio.report.log_http_request.assert_called_once() 1144 market.Portfolio.report.log_http_request.assert_called_once()
1129 market.Portfolio.report.log_http_request.reset_mock() 1145 market.Portfolio.report.log_http_request.reset_mock()
1130 parse_cryptoportfolio.reset_mock() 1146 parse_cryptoportfolio.reset_mock()
1147 store_cryptoportfolio.reset_mock()
1131 1148
1132 market.Portfolio.data = store.LockedVar("Foo") 1149 market.Portfolio.data = store.LockedVar("Foo")
1133 market.Portfolio.get_cryptoportfolio() 1150 market.Portfolio.get_cryptoportfolio()
1134 self.assertEqual(2, self.wm.call_count) 1151 self.assertEqual(2, self.wm.call_count)
1135 parse_cryptoportfolio.assert_not_called() 1152 parse_cryptoportfolio.assert_not_called()
1153 store_cryptoportfolio.assert_not_called()
1136 1154
1137 market.Portfolio.get_cryptoportfolio(refetch=True) 1155 market.Portfolio.get_cryptoportfolio(refetch=True)
1138 self.assertEqual("Foo", market.Portfolio.data.get()) 1156 self.assertEqual("Foo", market.Portfolio.data.get())
@@ -1153,6 +1171,7 @@ class PortfolioTest(WebMockTestCase):
1153 market.Portfolio.get_cryptoportfolio() 1171 market.Portfolio.get_cryptoportfolio()
1154 self.assertIn("foo", market.Portfolio.data.get()) 1172 self.assertIn("foo", market.Portfolio.data.get())
1155 parse_cryptoportfolio.reset_mock() 1173 parse_cryptoportfolio.reset_mock()
1174 store_cryptoportfolio.reset_mock()
1156 with self.subTest(worker=False): 1175 with self.subTest(worker=False):
1157 market.Portfolio.data = store.LockedVar(None) 1176 market.Portfolio.data = store.LockedVar(None)
1158 market.Portfolio.worker = mock.Mock() 1177 market.Portfolio.worker = mock.Mock()
@@ -1160,6 +1179,7 @@ class PortfolioTest(WebMockTestCase):
1160 market.Portfolio.get_cryptoportfolio() 1179 market.Portfolio.get_cryptoportfolio()
1161 notify.assert_called_once_with() 1180 notify.assert_called_once_with()
1162 parse_cryptoportfolio.assert_not_called() 1181 parse_cryptoportfolio.assert_not_called()
1182 store_cryptoportfolio.assert_not_called()
1163 1183
1164 def test_parse_cryptoportfolio(self): 1184 def test_parse_cryptoportfolio(self):
1165 with self.subTest(description="Normal case"): 1185 with self.subTest(description="Normal case"):
@@ -1223,25 +1243,95 @@ class PortfolioTest(WebMockTestCase):
1223 self.assertEqual({}, market.Portfolio.liquidities.get("high")) 1243 self.assertEqual({}, market.Portfolio.liquidities.get("high"))
1224 self.assertEqual(datetime.datetime(1,1,1), market.Portfolio.last_date.get()) 1244 self.assertEqual(datetime.datetime(1,1,1), market.Portfolio.last_date.get())
1225 1245
1226 1246 @mock.patch.object(store.dbs, "redis_connected")
1227 @mock.patch.object(market.Portfolio, "get_cryptoportfolio") 1247 @mock.patch.object(store.dbs, "redis")
1228 def test_repartition(self, get_cryptoportfolio): 1248 def test_store_cryptoportfolio(self, redis, redis_connected):
1229 market.Portfolio.liquidities = store.LockedVar({ 1249 store.Portfolio.liquidities = store.LockedVar({
1230 "medium": { 1250 "medium": {
1231 "2018-03-01": "medium_2018-03-01", 1251 datetime.datetime(2018,3,1): "medium_2018-03-01",
1232 "2018-03-08": "medium_2018-03-08", 1252 datetime.datetime(2018,3,8): "medium_2018-03-08",
1233 }, 1253 },
1234 "high": { 1254 "high": {
1235 "2018-03-01": "high_2018-03-01", 1255 datetime.datetime(2018,3,1): "high_2018-03-01",
1236 "2018-03-08": "high_2018-03-08", 1256 datetime.datetime(2018,3,8): "high_2018-03-08",
1237 } 1257 }
1238 }) 1258 })
1239 market.Portfolio.last_date = store.LockedVar("2018-03-08") 1259 store.Portfolio.last_date = store.LockedVar(datetime.datetime(2018,3,8))
1260
1261 with self.subTest(redis_connected=False):
1262 redis_connected.return_value = False
1263 store.Portfolio.store_cryptoportfolio()
1264 redis.set.assert_not_called()
1265
1266 with self.subTest(redis_connected=True):
1267 redis_connected.return_value = True
1268 store.Portfolio.store_cryptoportfolio()
1269 redis.set.assert_has_calls([
1270 mock.call("/cryptoportfolio/repartition/latest", '{"medium": "medium_2018-03-08", "high": "high_2018-03-08"}'),
1271 mock.call("/cryptoportfolio/repartition/date", "2018-03-08"),
1272 ])
1273
1274 @mock.patch.object(store.dbs, "redis_connected")
1275 @mock.patch.object(store.dbs, "redis")
1276 def test_retrieve_cryptoportfolio(self, redis, redis_connected):
1277 with self.subTest(redis_connected=False):
1278 redis_connected.return_value = False
1279 store.Portfolio.retrieve_cryptoportfolio()
1280 redis.get.assert_not_called()
1281 self.assertIsNone(store.Portfolio.data.get())
1282
1283 with self.subTest(redis_connected=True, value=None):
1284 redis_connected.return_value = True
1285 redis.get.return_value = None
1286 store.Portfolio.retrieve_cryptoportfolio()
1287 self.assertEqual(2, redis.get.call_count)
1288
1289 redis.reset_mock()
1290 with self.subTest(redis_connected=True, value="present"):
1291 redis_connected.return_value = True
1292 redis.get.side_effect = [
1293 b'{ "medium": "medium_repartition", "high": "high_repartition" }',
1294 b"2018-03-08"
1295 ]
1296 store.Portfolio.retrieve_cryptoportfolio()
1297 self.assertEqual(2, redis.get.call_count)
1298 self.assertEqual(datetime.datetime(2018,3,8), store.Portfolio.last_date.get())
1299 self.assertEqual("", store.Portfolio.data.get())
1300 expected_liquidities = {
1301 'high': { datetime.datetime(2018, 3, 8): 'high_repartition' },
1302 'medium': { datetime.datetime(2018, 3, 8): 'medium_repartition' },
1303 }
1304 self.assertEqual(expected_liquidities, store.Portfolio.liquidities.get())
1305
1306 @mock.patch.object(market.Portfolio, "get_cryptoportfolio")
1307 @mock.patch.object(market.Portfolio, "retrieve_cryptoportfolio")
1308 def test_repartition(self, retrieve_cryptoportfolio, get_cryptoportfolio):
1309 with self.subTest(from_cache=False):
1310 market.Portfolio.liquidities = store.LockedVar({
1311 "medium": {
1312 "2018-03-01": "medium_2018-03-01",
1313 "2018-03-08": "medium_2018-03-08",
1314 },
1315 "high": {
1316 "2018-03-01": "high_2018-03-01",
1317 "2018-03-08": "high_2018-03-08",
1318 }
1319 })
1320 market.Portfolio.last_date = store.LockedVar("2018-03-08")
1321
1322 self.assertEqual("medium_2018-03-08", market.Portfolio.repartition())
1323 get_cryptoportfolio.assert_called_once_with()
1324 retrieve_cryptoportfolio.assert_not_called()
1325 self.assertEqual("medium_2018-03-08", market.Portfolio.repartition(liquidity="medium"))
1326 self.assertEqual("high_2018-03-08", market.Portfolio.repartition(liquidity="high"))
1327
1328 retrieve_cryptoportfolio.reset_mock()
1329 get_cryptoportfolio.reset_mock()
1240 1330
1241 self.assertEqual("medium_2018-03-08", market.Portfolio.repartition()) 1331 with self.subTest(from_cache=True):
1242 get_cryptoportfolio.assert_called_once_with() 1332 self.assertEqual("medium_2018-03-08", market.Portfolio.repartition(from_cache=True))
1243 self.assertEqual("medium_2018-03-08", market.Portfolio.repartition(liquidity="medium")) 1333 get_cryptoportfolio.assert_called_once_with()
1244 self.assertEqual("high_2018-03-08", market.Portfolio.repartition(liquidity="high")) 1334 retrieve_cryptoportfolio.assert_called_once_with()
1245 1335
1246 @mock.patch.object(market.time, "sleep") 1336 @mock.patch.object(market.time, "sleep")
1247 @mock.patch.object(market.Portfolio, "get_cryptoportfolio") 1337 @mock.patch.object(market.Portfolio, "get_cryptoportfolio")