diff options
author | Ismaël Bouya <ismael.bouya@normalesup.org> | 2018-05-03 00:23:26 +0200 |
---|---|---|
committer | Ismaël Bouya <ismael.bouya@normalesup.org> | 2018-05-03 00:23:26 +0200 |
commit | 56c3a6078a2740d43072dfe30f07b17b442e3cc2 (patch) | |
tree | a534f8d64d247d50929f78c92c74acbe4131a1d2 | |
parent | 2b1ee8f4d54fa1672510141a71a5817120ac031c (diff) | |
parent | 9b69786341d14fd4327b117a12437fd1650cd965 (diff) | |
download | Trader-56c3a6078a2740d43072dfe30f07b17b442e3cc2.tar.gz Trader-56c3a6078a2740d43072dfe30f07b17b442e3cc2.tar.zst Trader-56c3a6078a2740d43072dfe30f07b17b442e3cc2.zip |
Merge branch 'refactor_db' into dev
-rw-r--r-- | dbs.py | 55 | ||||
-rw-r--r-- | main.py | 55 | ||||
-rw-r--r-- | market.py | 31 | ||||
-rw-r--r-- | store.py | 37 | ||||
-rw-r--r-- | test.py | 1 | ||||
-rw-r--r-- | tests/helper.py | 4 | ||||
-rw-r--r-- | tests/test_dbs.py | 108 | ||||
-rw-r--r-- | tests/test_main.py | 107 | ||||
-rw-r--r-- | tests/test_market.py | 99 | ||||
-rw-r--r-- | tests/test_store.py | 118 |
10 files changed, 443 insertions, 172 deletions
@@ -0,0 +1,55 @@ | |||
1 | import psycopg2 | ||
2 | import redis as _redis | ||
3 | |||
4 | redis = None | ||
5 | psql = None | ||
6 | |||
7 | def redis_connected(): | ||
8 | global redis | ||
9 | if redis is None: | ||
10 | return False | ||
11 | else: | ||
12 | try: | ||
13 | return redis.ping() | ||
14 | except Exception: | ||
15 | return False | ||
16 | |||
17 | def psql_connected(): | ||
18 | global psql | ||
19 | return psql is not None and psql.closed == 0 | ||
20 | |||
21 | def connect_redis(args): | ||
22 | global redis | ||
23 | |||
24 | redis_config = { | ||
25 | "host": args.redis_host, | ||
26 | "port": args.redis_port, | ||
27 | "db": args.redis_database, | ||
28 | } | ||
29 | if redis_config["host"].startswith("/"): | ||
30 | redis_config["unix_socket_path"] = redis_config.pop("host") | ||
31 | del(redis_config["port"]) | ||
32 | del(args.redis_host) | ||
33 | del(args.redis_port) | ||
34 | del(args.redis_database) | ||
35 | |||
36 | if redis is None: | ||
37 | redis = _redis.Redis(**redis_config) | ||
38 | |||
39 | def connect_psql(args): | ||
40 | global psql | ||
41 | pg_config = { | ||
42 | "host": args.db_host, | ||
43 | "port": args.db_port, | ||
44 | "user": args.db_user, | ||
45 | "password": args.db_password, | ||
46 | "database": args.db_database, | ||
47 | } | ||
48 | del(args.db_host) | ||
49 | del(args.db_port) | ||
50 | del(args.db_user) | ||
51 | del(args.db_password) | ||
52 | del(args.db_database) | ||
53 | |||
54 | if psql is None: | ||
55 | psql = psycopg2.connect(**pg_config) | ||
@@ -1,5 +1,5 @@ | |||
1 | import configargparse | 1 | import configargparse |
2 | import psycopg2 | 2 | import dbs |
3 | import os | 3 | import os |
4 | import sys | 4 | import sys |
5 | 5 | ||
@@ -63,15 +63,12 @@ def get_user_market(config_path, user_id, debug=False): | |||
63 | if debug: | 63 | if debug: |
64 | args.append("--debug") | 64 | args.append("--debug") |
65 | args = parse_args(args) | 65 | args = parse_args(args) |
66 | pg_config, redis_config = parse_config(args) | 66 | parse_config(args) |
67 | market_id, market_config, user_id = list(fetch_markets(pg_config, str(user_id)))[0] | 67 | market_id, market_config, user_id = list(fetch_markets(str(user_id)))[0] |
68 | return market.Market.from_config(market_config, args, | 68 | return market.Market.from_config(market_config, args, user_id=user_id) |
69 | pg_config=pg_config, market_id=market_id, | ||
70 | user_id=user_id) | ||
71 | 69 | ||
72 | def fetch_markets(pg_config, user): | 70 | def fetch_markets(user): |
73 | connection = psycopg2.connect(**pg_config) | 71 | cursor = dbs.psql.cursor() |
74 | cursor = connection.cursor() | ||
75 | 72 | ||
76 | if user is None: | 73 | if user is None: |
77 | cursor.execute("SELECT id,config,user_id FROM market_configs") | 74 | cursor.execute("SELECT id,config,user_id FROM market_configs") |
@@ -82,30 +79,11 @@ def fetch_markets(pg_config, user): | |||
82 | yield row | 79 | yield row |
83 | 80 | ||
84 | def parse_config(args): | 81 | def parse_config(args): |
85 | pg_config = { | 82 | if args.db_host is not None: |
86 | "host": args.db_host, | 83 | dbs.connect_psql(args) |
87 | "port": args.db_port, | 84 | |
88 | "user": args.db_user, | 85 | if args.redis_host is not None: |
89 | "password": args.db_password, | 86 | dbs.connect_redis(args) |
90 | "database": args.db_database, | ||
91 | } | ||
92 | del(args.db_host) | ||
93 | del(args.db_port) | ||
94 | del(args.db_user) | ||
95 | del(args.db_password) | ||
96 | del(args.db_database) | ||
97 | |||
98 | redis_config = { | ||
99 | "host": args.redis_host, | ||
100 | "port": args.redis_port, | ||
101 | "db": args.redis_database, | ||
102 | } | ||
103 | if redis_config["host"].startswith("/"): | ||
104 | redis_config["unix_socket_path"] = redis_config.pop("host") | ||
105 | del(redis_config["port"]) | ||
106 | del(args.redis_host) | ||
107 | del(args.redis_port) | ||
108 | del(args.redis_database) | ||
109 | 87 | ||
110 | report_path = args.report_path | 88 | report_path = args.report_path |
111 | 89 | ||
@@ -113,8 +91,6 @@ def parse_config(args): | |||
113 | os.path.exists(report_path): | 91 | os.path.exists(report_path): |
114 | os.makedirs(report_path) | 92 | os.makedirs(report_path) |
115 | 93 | ||
116 | return pg_config, redis_config | ||
117 | |||
118 | def parse_args(argv): | 94 | def parse_args(argv): |
119 | parser = configargparse.ArgumentParser( | 95 | parser = configargparse.ArgumentParser( |
120 | description="Run the trade bot.") | 96 | description="Run the trade bot.") |
@@ -176,11 +152,10 @@ def parse_args(argv): | |||
176 | parsed.action = ["sell_all"] | 152 | parsed.action = ["sell_all"] |
177 | return parsed | 153 | return parsed |
178 | 154 | ||
179 | def process(market_config, market_id, user_id, args, pg_config, redis_config): | 155 | def process(market_config, market_id, user_id, args): |
180 | try: | 156 | try: |
181 | market.Market\ | 157 | market.Market\ |
182 | .from_config(market_config, args, market_id=market_id, | 158 | .from_config(market_config, args, market_id=market_id, |
183 | pg_config=pg_config, redis_config=redis_config, | ||
184 | user_id=user_id)\ | 159 | user_id=user_id)\ |
185 | .process(args.action, before=args.before, after=args.after) | 160 | .process(args.action, before=args.before, after=args.after) |
186 | except Exception as e: | 161 | except Exception as e: |
@@ -189,7 +164,7 @@ def process(market_config, market_id, user_id, args, pg_config, redis_config): | |||
189 | def main(argv): | 164 | def main(argv): |
190 | args = parse_args(argv) | 165 | args = parse_args(argv) |
191 | 166 | ||
192 | pg_config, redis_config = parse_config(args) | 167 | parse_config(args) |
193 | 168 | ||
194 | market.Portfolio.report.set_verbose(not args.quiet) | 169 | market.Portfolio.report.set_verbose(not args.quiet) |
195 | 170 | ||
@@ -205,8 +180,8 @@ def main(argv): | |||
205 | else: | 180 | else: |
206 | process_ = process | 181 | process_ = process |
207 | 182 | ||
208 | for market_id, market_config, user_id in fetch_markets(pg_config, args.user): | 183 | for market_id, market_config, user_id in fetch_markets(args.user): |
209 | process_(market_config, market_id, user_id, args, pg_config, redis_config) | 184 | process_(market_config, market_id, user_id, args) |
210 | 185 | ||
211 | if args.parallel: | 186 | if args.parallel: |
212 | for thread in threads: | 187 | for thread in threads: |
@@ -1,8 +1,7 @@ | |||
1 | from ccxt import ExchangeError, NotSupported, RequestTimeout, InvalidNonce | 1 | from ccxt import ExchangeError, NotSupported, RequestTimeout, InvalidNonce |
2 | import ccxt_wrapper as ccxt | 2 | import ccxt_wrapper as ccxt |
3 | import time | 3 | import time |
4 | import psycopg2 | 4 | import dbs |
5 | import redis | ||
6 | from store import * | 5 | from store import * |
7 | from cachetools.func import ttl_cache | 6 | from cachetools.func import ttl_cache |
8 | from datetime import datetime | 7 | from datetime import datetime |
@@ -27,7 +26,7 @@ class Market: | |||
27 | self.balances = BalanceStore(self) | 26 | self.balances = BalanceStore(self) |
28 | self.processor = Processor(self) | 27 | self.processor = Processor(self) |
29 | 28 | ||
30 | for key in ["user_id", "market_id", "pg_config", "redis_config"]: | 29 | for key in ["user_id", "market_id"]: |
31 | setattr(self, key, kwargs.get(key, None)) | 30 | setattr(self, key, kwargs.get(key, None)) |
32 | 31 | ||
33 | self.report.log_market(self.args) | 32 | self.report.log_market(self.args) |
@@ -45,9 +44,9 @@ class Market: | |||
45 | date = datetime.datetime.now() | 44 | date = datetime.datetime.now() |
46 | if self.args.report_path is not None: | 45 | if self.args.report_path is not None: |
47 | self.store_file_report(date) | 46 | self.store_file_report(date) |
48 | if self.pg_config is not None and self.args.report_db: | 47 | if dbs.psql_connected() and self.args.report_db: |
49 | self.store_database_report(date) | 48 | self.store_database_report(date) |
50 | if self.redis_config is not None and self.args.report_redis: | 49 | if dbs.redis_connected() and self.args.report_redis: |
51 | self.store_redis_report(date) | 50 | self.store_redis_report(date) |
52 | 51 | ||
53 | def store_file_report(self, date): | 52 | def store_file_report(self, date): |
@@ -64,29 +63,26 @@ class Market: | |||
64 | try: | 63 | try: |
65 | report_query = 'INSERT INTO reports("date", "market_config_id", "debug") VALUES (%s, %s, %s) RETURNING id;' | 64 | report_query = 'INSERT INTO reports("date", "market_config_id", "debug") VALUES (%s, %s, %s) RETURNING id;' |
66 | line_query = 'INSERT INTO report_lines("date", "report_id", "type", "payload") VALUES (%s, %s, %s, %s);' | 65 | line_query = 'INSERT INTO report_lines("date", "report_id", "type", "payload") VALUES (%s, %s, %s, %s);' |
67 | connection = psycopg2.connect(**self.pg_config) | 66 | cursor = dbs.psql.cursor() |
68 | cursor = connection.cursor() | ||
69 | cursor.execute(report_query, (date, self.market_id, self.debug)) | 67 | cursor.execute(report_query, (date, self.market_id, self.debug)) |
70 | report_id = cursor.fetchone()[0] | 68 | report_id = cursor.fetchone()[0] |
71 | for date, type_, payload in self.report.to_json_array(): | 69 | for date, type_, payload in self.report.to_json_array(): |
72 | cursor.execute(line_query, (date, report_id, type_, payload)) | 70 | cursor.execute(line_query, (date, report_id, type_, payload)) |
73 | 71 | ||
74 | connection.commit() | 72 | dbs.psql.commit() |
75 | cursor.close() | 73 | cursor.close() |
76 | connection.close() | ||
77 | except Exception as e: | 74 | except Exception as e: |
78 | print("impossible to store report to database: {}; {}".format(e.__class__.__name__, e)) | 75 | print("impossible to store report to database: {}; {}".format(e.__class__.__name__, e)) |
79 | 76 | ||
80 | def store_redis_report(self, date): | 77 | def store_redis_report(self, date): |
81 | try: | 78 | try: |
82 | conn = redis.Redis(**self.redis_config) | ||
83 | for type_, log in self.report.to_json_redis(): | 79 | for type_, log in self.report.to_json_redis(): |
84 | key = "/cryptoportfolio/{}/{}/{}".format(self.market_id, date.isoformat(), type_) | 80 | key = "/cryptoportfolio/{}/{}/{}".format(self.market_id, date.isoformat(), type_) |
85 | conn.set(key, log, ex=31*24*60*60) | 81 | dbs.redis.set(key, log, ex=31*24*60*60) |
86 | key = "/cryptoportfolio/{}/latest/{}".format(self.market_id, type_) | 82 | key = "/cryptoportfolio/{}/latest/{}".format(self.market_id, type_) |
87 | conn.set(key, log) | 83 | dbs.redis.set(key, log) |
88 | key = "/cryptoportfolio/{}/latest/date".format(self.market_id) | 84 | key = "/cryptoportfolio/{}/latest/date".format(self.market_id) |
89 | conn.set(key, date.isoformat()) | 85 | dbs.redis.set(key, date.isoformat()) |
90 | except Exception as e: | 86 | except Exception as e: |
91 | print("impossible to store report to redis: {}; {}".format(e.__class__.__name__, e)) | 87 | print("impossible to store report to redis: {}; {}".format(e.__class__.__name__, e)) |
92 | 88 | ||
@@ -259,6 +255,7 @@ class Processor: | |||
259 | "name": "print_balances", | 255 | "name": "print_balances", |
260 | "number": 1, | 256 | "number": 1, |
261 | "fetch_balances": ["begin"], | 257 | "fetch_balances": ["begin"], |
258 | "fetch_balances_args": { "add_portfolio": True }, | ||
262 | "print_tickers": { "base_currency": "BTC" }, | 259 | "print_tickers": { "base_currency": "BTC" }, |
263 | } | 260 | } |
264 | ], | 261 | ], |
@@ -390,15 +387,19 @@ class Processor: | |||
390 | def process_step(self, scenario_name, step, kwargs): | 387 | def process_step(self, scenario_name, step, kwargs): |
391 | process_name = "process_{}__{}_{}".format(scenario_name, step["number"], step["name"]) | 388 | process_name = "process_{}__{}_{}".format(scenario_name, step["number"], step["name"]) |
392 | self.market.report.log_stage("{}_begin".format(process_name)) | 389 | self.market.report.log_stage("{}_begin".format(process_name)) |
390 | |||
391 | fetch_args = step.get("fetch_balances_args", {}) | ||
393 | if "begin" in step.get("fetch_balances", []): | 392 | if "begin" in step.get("fetch_balances", []): |
394 | self.market.balances.fetch_balances(tag="{}_begin".format(process_name), log_tickers=True) | 393 | self.market.balances.fetch_balances(tag="{}_begin".format(process_name), |
394 | log_tickers=True, **fetch_args) | ||
395 | 395 | ||
396 | for action in self.ordered_actions: | 396 | for action in self.ordered_actions: |
397 | if action in step: | 397 | if action in step: |
398 | self.run_action(action, step[action], kwargs) | 398 | self.run_action(action, step[action], kwargs) |
399 | 399 | ||
400 | if "end" in step.get("fetch_balances", []): | 400 | if "end" in step.get("fetch_balances", []): |
401 | self.market.balances.fetch_balances(tag="{}_end".format(process_name), log_tickers=True) | 401 | self.market.balances.fetch_balances(tag="{}_end".format(process_name), |
402 | log_tickers=True, **fetch_args) | ||
402 | self.market.report.log_stage("{}_end".format(process_name)) | 403 | self.market.report.log_stage("{}_end".format(process_name)) |
403 | 404 | ||
404 | def method_arguments(self, action): | 405 | def method_arguments(self, action): |
@@ -7,6 +7,7 @@ import datetime | |||
7 | import inspect | 7 | import inspect |
8 | from json import JSONDecodeError | 8 | from json import JSONDecodeError |
9 | from simplejson.errors import JSONDecodeError as SimpleJSONDecodeError | 9 | from simplejson.errors import JSONDecodeError as SimpleJSONDecodeError |
10 | import dbs | ||
10 | 11 | ||
11 | __all__ = ["Portfolio", "BalanceStore", "ReportStore", "TradeStore"] | 12 | __all__ = ["Portfolio", "BalanceStore", "ReportStore", "TradeStore"] |
12 | 13 | ||
@@ -302,13 +303,16 @@ class BalanceStore: | |||
302 | compute_value, type) | 303 | compute_value, type) |
303 | return amounts | 304 | return amounts |
304 | 305 | ||
305 | def fetch_balances(self, tag=None, log_tickers=False, | 306 | def fetch_balances(self, tag=None, add_portfolio=False, log_tickers=False, |
306 | ticker_currency="BTC", ticker_compute_value="average", ticker_type="total"): | 307 | ticker_currency="BTC", ticker_compute_value="average", ticker_type="total"): |
307 | all_balances = self.market.ccxt.fetch_all_balances() | 308 | all_balances = self.market.ccxt.fetch_all_balances() |
308 | for currency, balance in all_balances.items(): | 309 | for currency, balance in all_balances.items(): |
309 | if balance["exchange_total"] != 0 or balance["margin_total"] != 0 or \ | 310 | if balance["exchange_total"] != 0 or balance["margin_total"] != 0 or \ |
310 | currency in self.all: | 311 | currency in self.all: |
311 | self.all[currency] = portfolio.Balance(currency, balance) | 312 | self.all[currency] = portfolio.Balance(currency, balance) |
313 | if add_portfolio: | ||
314 | for currency in Portfolio.repartition(from_cache=True): | ||
315 | self.all.setdefault(currency, portfolio.Balance(currency, {})) | ||
312 | if log_tickers: | 316 | if log_tickers: |
313 | tickers = self.in_currency(ticker_currency, compute_value=ticker_compute_value, type=ticker_type) | 317 | tickers = self.in_currency(ticker_currency, compute_value=ticker_compute_value, type=ticker_type) |
314 | self.market.report.log_balances(tag=tag, | 318 | self.market.report.log_balances(tag=tag, |
@@ -508,7 +512,9 @@ class Portfolio: | |||
508 | cls.get_cryptoportfolio(refetch=True) | 512 | cls.get_cryptoportfolio(refetch=True) |
509 | 513 | ||
510 | @classmethod | 514 | @classmethod |
511 | def repartition(cls, liquidity="medium"): | 515 | def repartition(cls, liquidity="medium", from_cache=False): |
516 | if from_cache: | ||
517 | cls.retrieve_cryptoportfolio() | ||
512 | cls.get_cryptoportfolio() | 518 | cls.get_cryptoportfolio() |
513 | liquidities = cls.liquidities.get(liquidity) | 519 | liquidities = cls.liquidities.get(liquidity) |
514 | return liquidities[cls.last_date.get()] | 520 | return liquidities[cls.last_date.get()] |
@@ -530,12 +536,39 @@ class Portfolio: | |||
530 | try: | 536 | try: |
531 | cls.data.set(r.json(parse_int=D, parse_float=D)) | 537 | cls.data.set(r.json(parse_int=D, parse_float=D)) |
532 | cls.parse_cryptoportfolio() | 538 | cls.parse_cryptoportfolio() |
539 | cls.store_cryptoportfolio() | ||
533 | except (JSONDecodeError, SimpleJSONDecodeError): | 540 | except (JSONDecodeError, SimpleJSONDecodeError): |
534 | cls.data.set(None) | 541 | cls.data.set(None) |
535 | cls.last_date.set(None) | 542 | cls.last_date.set(None) |
536 | cls.liquidities.set({}) | 543 | cls.liquidities.set({}) |
537 | 544 | ||
538 | @classmethod | 545 | @classmethod |
546 | def retrieve_cryptoportfolio(cls): | ||
547 | if dbs.redis_connected(): | ||
548 | repartition = dbs.redis.get("/cryptoportfolio/repartition/latest") | ||
549 | date = dbs.redis.get("/cryptoportfolio/repartition/date") | ||
550 | if date is not None and repartition is not None: | ||
551 | date = datetime.datetime.strptime(date.decode(), "%Y-%m-%d") | ||
552 | repartition = json.loads(repartition, parse_int=D, parse_float=D) | ||
553 | repartition = { k: { date: v } for k, v in repartition.items() } | ||
554 | |||
555 | cls.data.set("") | ||
556 | cls.last_date.set(date) | ||
557 | cls.liquidities.set(repartition) | ||
558 | |||
559 | @classmethod | ||
560 | def store_cryptoportfolio(cls): | ||
561 | if dbs.redis_connected(): | ||
562 | hash_ = {} | ||
563 | for liquidity, repartitions in cls.liquidities.items(): | ||
564 | hash_[liquidity] = repartitions[cls.last_date.get()] | ||
565 | dump = json.dumps(hash_) | ||
566 | key = "/cryptoportfolio/repartition/latest" | ||
567 | dbs.redis.set(key, dump) | ||
568 | key = "/cryptoportfolio/repartition/date" | ||
569 | dbs.redis.set(key, cls.last_date.date().isoformat()) | ||
570 | |||
571 | @classmethod | ||
539 | def parse_cryptoportfolio(cls): | 572 | def parse_cryptoportfolio(cls): |
540 | def filter_weights(weight_hash): | 573 | def filter_weights(weight_hash): |
541 | if weight_hash[1][0] == 0: | 574 | if weight_hash[1][0] == 0: |
@@ -9,6 +9,7 @@ if "unit" in limits: | |||
9 | from tests.test_market import * | 9 | from tests.test_market import * |
10 | from tests.test_store import * | 10 | from tests.test_store import * |
11 | from tests.test_portfolio import * | 11 | from tests.test_portfolio import * |
12 | from tests.test_dbs import * | ||
12 | 13 | ||
13 | if "acceptance" in limits: | 14 | if "acceptance" in limits: |
14 | from tests.test_acceptance import * | 15 | from tests.test_acceptance import * |
diff --git a/tests/helper.py b/tests/helper.py index b85bf3a..935e060 100644 --- a/tests/helper.py +++ b/tests/helper.py | |||
@@ -4,7 +4,7 @@ from decimal import Decimal as D | |||
4 | from unittest import mock | 4 | from unittest import mock |
5 | import requests_mock | 5 | import requests_mock |
6 | from io import StringIO | 6 | from io import StringIO |
7 | import portfolio, market, main, store | 7 | import portfolio, market, main, store, dbs |
8 | 8 | ||
9 | __all__ = ["limits", "unittest", "WebMockTestCase", "mock", "D", | 9 | __all__ = ["limits", "unittest", "WebMockTestCase", "mock", "D", |
10 | "StringIO"] | 10 | "StringIO"] |
@@ -48,6 +48,8 @@ class WebMockTestCase(unittest.TestCase): | |||
48 | callback=None), | 48 | callback=None), |
49 | mock.patch.multiple(portfolio.Computation, | 49 | mock.patch.multiple(portfolio.Computation, |
50 | computations=portfolio.Computation.computations), | 50 | computations=portfolio.Computation.computations), |
51 | mock.patch.multiple(dbs, | ||
52 | redis=None, psql=None) | ||
51 | ] | 53 | ] |
52 | for patcher in self.patchers: | 54 | for patcher in self.patchers: |
53 | patcher.start() | 55 | patcher.start() |
diff --git a/tests/test_dbs.py b/tests/test_dbs.py new file mode 100644 index 0000000..157c423 --- /dev/null +++ b/tests/test_dbs.py | |||
@@ -0,0 +1,108 @@ | |||
1 | from .helper import * | ||
2 | import dbs, main | ||
3 | |||
4 | @unittest.skipUnless("unit" in limits, "Unit skipped") | ||
5 | class DbsTest(WebMockTestCase): | ||
6 | @mock.patch.object(dbs, "psycopg2") | ||
7 | def test_connect_psql(self, psycopg2): | ||
8 | args = main.configargparse.Namespace(**{ | ||
9 | "db_host": "host", | ||
10 | "db_port": "port", | ||
11 | "db_user": "user", | ||
12 | "db_password": "password", | ||
13 | "db_database": "database", | ||
14 | }) | ||
15 | psycopg2.connect.return_value = "pg_connection" | ||
16 | dbs.connect_psql(args) | ||
17 | |||
18 | psycopg2.connect.assert_called_once_with(host="host", | ||
19 | port="port", user="user", password="password", | ||
20 | database="database") | ||
21 | self.assertEqual("pg_connection", dbs.psql) | ||
22 | with self.assertRaises(AttributeError): | ||
23 | args.db_password | ||
24 | |||
25 | psycopg2.connect.reset_mock() | ||
26 | args = main.configargparse.Namespace(**{ | ||
27 | "db_host": "host", | ||
28 | "db_port": "port", | ||
29 | "db_user": "user", | ||
30 | "db_password": "password", | ||
31 | "db_database": "database", | ||
32 | }) | ||
33 | dbs.connect_psql(args) | ||
34 | psycopg2.connect.assert_not_called() | ||
35 | |||
36 | @mock.patch.object(dbs, "_redis") | ||
37 | def test_connect_redis(self, redis): | ||
38 | with self.subTest(redis_host="tcp"): | ||
39 | args = main.configargparse.Namespace(**{ | ||
40 | "redis_host": "host", | ||
41 | "redis_port": "port", | ||
42 | "redis_database": "database", | ||
43 | }) | ||
44 | redis.Redis.return_value = "redis_connection" | ||
45 | dbs.connect_redis(args) | ||
46 | |||
47 | redis.Redis.assert_called_once_with(host="host", | ||
48 | port="port", db="database") | ||
49 | self.assertEqual("redis_connection", dbs.redis) | ||
50 | with self.assertRaises(AttributeError): | ||
51 | args.redis_database | ||
52 | |||
53 | redis.Redis.reset_mock() | ||
54 | args = main.configargparse.Namespace(**{ | ||
55 | "redis_host": "host", | ||
56 | "redis_port": "port", | ||
57 | "redis_database": "database", | ||
58 | }) | ||
59 | dbs.connect_redis(args) | ||
60 | redis.Redis.assert_not_called() | ||
61 | |||
62 | dbs.redis = None | ||
63 | with self.subTest(redis_host="socket"): | ||
64 | args = main.configargparse.Namespace(**{ | ||
65 | "redis_host": "/run/foo", | ||
66 | "redis_port": "port", | ||
67 | "redis_database": "database", | ||
68 | }) | ||
69 | redis.Redis.return_value = "redis_socket" | ||
70 | dbs.connect_redis(args) | ||
71 | |||
72 | redis.Redis.assert_called_once_with(unix_socket_path="/run/foo", db="database") | ||
73 | self.assertEqual("redis_socket", dbs.redis) | ||
74 | |||
75 | def test_redis_connected(self): | ||
76 | with self.subTest(redis=None): | ||
77 | dbs.redis = None | ||
78 | self.assertFalse(dbs.redis_connected()) | ||
79 | |||
80 | with self.subTest(redis="mocked_true"): | ||
81 | dbs.redis = mock.Mock() | ||
82 | dbs.redis.ping.return_value = True | ||
83 | self.assertTrue(dbs.redis_connected()) | ||
84 | |||
85 | with self.subTest(redis="mocked_false"): | ||
86 | dbs.redis = mock.Mock() | ||
87 | dbs.redis.ping.return_value = False | ||
88 | self.assertFalse(dbs.redis_connected()) | ||
89 | |||
90 | with self.subTest(redis="mocked_raise"): | ||
91 | dbs.redis = mock.Mock() | ||
92 | dbs.redis.ping.side_effect = Exception("bouh") | ||
93 | self.assertFalse(dbs.redis_connected()) | ||
94 | |||
95 | def test_psql_connected(self): | ||
96 | with self.subTest(psql=None): | ||
97 | dbs.psql = None | ||
98 | self.assertFalse(dbs.psql_connected()) | ||
99 | |||
100 | with self.subTest(psql="connected"): | ||
101 | dbs.psql = mock.Mock() | ||
102 | dbs.psql.closed = 0 | ||
103 | self.assertTrue(dbs.psql_connected()) | ||
104 | |||
105 | with self.subTest(psql="not connected"): | ||
106 | dbs.psql = mock.Mock() | ||
107 | dbs.psql.closed = 3 | ||
108 | self.assertFalse(dbs.psql_connected()) | ||
diff --git a/tests/test_main.py b/tests/test_main.py index 55b1382..1864c06 100644 --- a/tests/test_main.py +++ b/tests/test_main.py | |||
@@ -103,7 +103,6 @@ class MainTest(WebMockTestCase): | |||
103 | mock.patch("main.parse_config") as main_parse_config: | 103 | mock.patch("main.parse_config") as main_parse_config: |
104 | with self.subTest(debug=False): | 104 | with self.subTest(debug=False): |
105 | main_parse_args.return_value = self.market_args() | 105 | main_parse_args.return_value = self.market_args() |
106 | main_parse_config.return_value = ["pg_config", "redis_config"] | ||
107 | main_fetch_markets.return_value = [(1, {"key": "market_config"}, 3)] | 106 | main_fetch_markets.return_value = [(1, {"key": "market_config"}, 3)] |
108 | m = main.get_user_market("config_path.ini", 1) | 107 | m = main.get_user_market("config_path.ini", 1) |
109 | 108 | ||
@@ -114,7 +113,6 @@ class MainTest(WebMockTestCase): | |||
114 | main_parse_args.reset_mock() | 113 | main_parse_args.reset_mock() |
115 | with self.subTest(debug=True): | 114 | with self.subTest(debug=True): |
116 | main_parse_args.return_value = self.market_args(debug=True) | 115 | main_parse_args.return_value = self.market_args(debug=True) |
117 | main_parse_config.return_value = ["pg_config", "redis_config"] | ||
118 | main_fetch_markets.return_value = [(1, {"key": "market_config"}, 3)] | 116 | main_fetch_markets.return_value = [(1, {"key": "market_config"}, 3)] |
119 | m = main.get_user_market("config_path.ini", 1, debug=True) | 117 | m = main.get_user_market("config_path.ini", 1, debug=True) |
120 | 118 | ||
@@ -135,16 +133,16 @@ class MainTest(WebMockTestCase): | |||
135 | args_mock.after = "after" | 133 | args_mock.after = "after" |
136 | self.assertEqual("", stdout_mock.getvalue()) | 134 | self.assertEqual("", stdout_mock.getvalue()) |
137 | 135 | ||
138 | main.process("config", 3, 1, args_mock, "pg_config", "redis_config") | 136 | main.process("config", 3, 1, args_mock) |
139 | 137 | ||
140 | market_mock.from_config.assert_has_calls([ | 138 | market_mock.from_config.assert_has_calls([ |
141 | mock.call("config", args_mock, pg_config="pg_config", redis_config="redis_config", market_id=3, user_id=1), | 139 | mock.call("config", args_mock, market_id=3, user_id=1), |
142 | mock.call().process("action", before="before", after="after"), | 140 | mock.call().process("action", before="before", after="after"), |
143 | ]) | 141 | ]) |
144 | 142 | ||
145 | with self.subTest(exception=True): | 143 | with self.subTest(exception=True): |
146 | market_mock.from_config.side_effect = Exception("boo") | 144 | market_mock.from_config.side_effect = Exception("boo") |
147 | main.process(3, "config", 1, args_mock, "pg_config", "redis_config") | 145 | main.process(3, "config", 1, args_mock) |
148 | self.assertEqual("Exception: boo\n", stdout_mock.getvalue()) | 146 | self.assertEqual("Exception: boo\n", stdout_mock.getvalue()) |
149 | 147 | ||
150 | def test_main(self): | 148 | def test_main(self): |
@@ -159,20 +157,18 @@ class MainTest(WebMockTestCase): | |||
159 | args_mock.user = "user" | 157 | args_mock.user = "user" |
160 | parse_args.return_value = args_mock | 158 | parse_args.return_value = args_mock |
161 | 159 | ||
162 | parse_config.return_value = ["pg_config", "redis_config"] | ||
163 | |||
164 | fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]] | 160 | fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]] |
165 | 161 | ||
166 | main.main(["Foo", "Bar"]) | 162 | main.main(["Foo", "Bar"]) |
167 | 163 | ||
168 | parse_args.assert_called_with(["Foo", "Bar"]) | 164 | parse_args.assert_called_with(["Foo", "Bar"]) |
169 | parse_config.assert_called_with(args_mock) | 165 | parse_config.assert_called_with(args_mock) |
170 | fetch_markets.assert_called_with("pg_config", "user") | 166 | fetch_markets.assert_called_with("user") |
171 | 167 | ||
172 | self.assertEqual(2, process.call_count) | 168 | self.assertEqual(2, process.call_count) |
173 | process.assert_has_calls([ | 169 | process.assert_has_calls([ |
174 | mock.call("config1", 3, 1, args_mock, "pg_config", "redis_config"), | 170 | mock.call("config1", 3, 1, args_mock), |
175 | mock.call("config2", 1, 2, args_mock, "pg_config", "redis_config"), | 171 | mock.call("config2", 1, 2, args_mock), |
176 | ]) | 172 | ]) |
177 | with self.subTest(parallel=True): | 173 | with self.subTest(parallel=True): |
178 | with mock.patch("main.parse_args") as parse_args,\ | 174 | with mock.patch("main.parse_args") as parse_args,\ |
@@ -187,24 +183,22 @@ class MainTest(WebMockTestCase): | |||
187 | args_mock.user = "user" | 183 | args_mock.user = "user" |
188 | parse_args.return_value = args_mock | 184 | parse_args.return_value = args_mock |
189 | 185 | ||
190 | parse_config.return_value = ["pg_config", "redis_config"] | ||
191 | |||
192 | fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]] | 186 | fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]] |
193 | 187 | ||
194 | main.main(["Foo", "Bar"]) | 188 | main.main(["Foo", "Bar"]) |
195 | 189 | ||
196 | parse_args.assert_called_with(["Foo", "Bar"]) | 190 | parse_args.assert_called_with(["Foo", "Bar"]) |
197 | parse_config.assert_called_with(args_mock) | 191 | parse_config.assert_called_with(args_mock) |
198 | fetch_markets.assert_called_with("pg_config", "user") | 192 | fetch_markets.assert_called_with("user") |
199 | 193 | ||
200 | stop.assert_called_once_with() | 194 | stop.assert_called_once_with() |
201 | start.assert_called_once_with() | 195 | start.assert_called_once_with() |
202 | self.assertEqual(2, process.call_count) | 196 | self.assertEqual(2, process.call_count) |
203 | process.assert_has_calls([ | 197 | process.assert_has_calls([ |
204 | mock.call.__bool__(), | 198 | mock.call.__bool__(), |
205 | mock.call("config1", 3, 1, args_mock, "pg_config", "redis_config"), | 199 | mock.call("config1", 3, 1, args_mock), |
206 | mock.call.__bool__(), | 200 | mock.call.__bool__(), |
207 | mock.call("config2", 1, 2, args_mock, "pg_config", "redis_config"), | 201 | mock.call("config2", 1, 2, args_mock), |
208 | ]) | 202 | ]) |
209 | with self.subTest(quiet=True): | 203 | with self.subTest(quiet=True): |
210 | with mock.patch("main.parse_args") as parse_args,\ | 204 | with mock.patch("main.parse_args") as parse_args,\ |
@@ -219,8 +213,6 @@ class MainTest(WebMockTestCase): | |||
219 | args_mock.user = "user" | 213 | args_mock.user = "user" |
220 | parse_args.return_value = args_mock | 214 | parse_args.return_value = args_mock |
221 | 215 | ||
222 | parse_config.return_value = ["pg_config", "redis_config"] | ||
223 | |||
224 | fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]] | 216 | fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]] |
225 | 217 | ||
226 | main.main(["Foo", "Bar"]) | 218 | main.main(["Foo", "Bar"]) |
@@ -240,8 +232,6 @@ class MainTest(WebMockTestCase): | |||
240 | args_mock.user = "user" | 232 | args_mock.user = "user" |
241 | parse_args.return_value = args_mock | 233 | parse_args.return_value = args_mock |
242 | 234 | ||
243 | parse_config.return_value = ["pg_config", "redis_config"] | ||
244 | |||
245 | fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]] | 235 | fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]] |
246 | 236 | ||
247 | main.main(["Foo", "Bar"]) | 237 | main.main(["Foo", "Bar"]) |
@@ -252,63 +242,57 @@ class MainTest(WebMockTestCase): | |||
252 | @mock.patch.object(main.sys, "exit") | 242 | @mock.patch.object(main.sys, "exit") |
253 | @mock.patch("main.os") | 243 | @mock.patch("main.os") |
254 | def test_parse_config(self, os, exit): | 244 | def test_parse_config(self, os, exit): |
255 | with self.subTest(report_path=None): | 245 | with self.subTest(report_path=None),\ |
246 | mock.patch.object(main.dbs, "connect_psql") as psql,\ | ||
247 | mock.patch.object(main.dbs, "connect_redis") as redis: | ||
256 | args = main.configargparse.Namespace(**{ | 248 | args = main.configargparse.Namespace(**{ |
257 | "db_host": "host", | 249 | "db_host": "host", |
258 | "db_port": "port", | ||
259 | "db_user": "user", | ||
260 | "db_password": "password", | ||
261 | "db_database": "database", | ||
262 | "redis_host": "rhost", | 250 | "redis_host": "rhost", |
263 | "redis_port": "rport", | ||
264 | "redis_database": "rdb", | ||
265 | "report_path": None, | 251 | "report_path": None, |
266 | }) | 252 | }) |
267 | 253 | ||
268 | db_config, redis_config = main.parse_config(args) | 254 | main.parse_config(args) |
269 | self.assertEqual({ "host": "host", "port": "port", "user": | 255 | psql.assert_called_once_with(args) |
270 | "user", "password": "password", "database": "database" | 256 | redis.assert_called_once_with(args) |
271 | }, db_config) | 257 | |
272 | self.assertEqual({ "host": "rhost", "port": "rport", "db": | 258 | with self.subTest(report_path=None, db=None),\ |
273 | "rdb"}, redis_config) | 259 | mock.patch.object(main.dbs, "connect_psql") as psql,\ |
260 | mock.patch.object(main.dbs, "connect_redis") as redis: | ||
261 | args = main.configargparse.Namespace(**{ | ||
262 | "db_host": None, | ||
263 | "redis_host": "rhost", | ||
264 | "report_path": None, | ||
265 | }) | ||
274 | 266 | ||
275 | with self.assertRaises(AttributeError): | 267 | main.parse_config(args) |
276 | args.db_password | 268 | psql.assert_not_called() |
277 | with self.assertRaises(AttributeError): | 269 | redis.assert_called_once_with(args) |
278 | args.redis_host | ||
279 | 270 | ||
280 | with self.subTest(redis_host="socket"): | 271 | with self.subTest(report_path=None, redis=None),\ |
272 | mock.patch.object(main.dbs, "connect_psql") as psql,\ | ||
273 | mock.patch.object(main.dbs, "connect_redis") as redis: | ||
281 | args = main.configargparse.Namespace(**{ | 274 | args = main.configargparse.Namespace(**{ |
282 | "db_host": "host", | 275 | "db_host": "host", |
283 | "db_port": "port", | 276 | "redis_host": None, |
284 | "db_user": "user", | ||
285 | "db_password": "password", | ||
286 | "db_database": "database", | ||
287 | "redis_host": "/run/foo", | ||
288 | "redis_port": "rport", | ||
289 | "redis_database": "rdb", | ||
290 | "report_path": None, | 277 | "report_path": None, |
291 | }) | 278 | }) |
292 | 279 | ||
293 | db_config, redis_config = main.parse_config(args) | 280 | main.parse_config(args) |
294 | self.assertEqual({ "unix_socket_path": "/run/foo", "db": "rdb"}, redis_config) | 281 | redis.assert_not_called() |
282 | psql.assert_called_once_with(args) | ||
295 | 283 | ||
296 | with self.subTest(report_path="present"): | 284 | with self.subTest(report_path="present"),\ |
285 | mock.patch.object(main.dbs, "connect_psql") as psql,\ | ||
286 | mock.patch.object(main.dbs, "connect_redis") as redis: | ||
297 | args = main.configargparse.Namespace(**{ | 287 | args = main.configargparse.Namespace(**{ |
298 | "db_host": "host", | 288 | "db_host": "host", |
299 | "db_port": "port", | ||
300 | "db_user": "user", | ||
301 | "db_password": "password", | ||
302 | "db_database": "database", | ||
303 | "redis_host": "rhost", | 289 | "redis_host": "rhost", |
304 | "redis_port": "rport", | ||
305 | "redis_database": "rdb", | ||
306 | "report_path": "report_path", | 290 | "report_path": "report_path", |
307 | }) | 291 | }) |
308 | 292 | ||
309 | os.path.exists.return_value = False | 293 | os.path.exists.return_value = False |
310 | 294 | ||
311 | result = main.parse_config(args) | 295 | main.parse_config(args) |
312 | 296 | ||
313 | os.path.exists.assert_called_once_with("report_path") | 297 | os.path.exists.assert_called_once_with("report_path") |
314 | os.makedirs.assert_called_once_with("report_path") | 298 | os.makedirs.assert_called_once_with("report_path") |
@@ -331,29 +315,24 @@ class MainTest(WebMockTestCase): | |||
331 | mock.patch('sys.stderr', new_callable=StringIO) as stdout_mock: | 315 | mock.patch('sys.stderr', new_callable=StringIO) as stdout_mock: |
332 | args = main.parse_args(["--config", "foo.bar"]) | 316 | args = main.parse_args(["--config", "foo.bar"]) |
333 | 317 | ||
334 | @mock.patch.object(main, "psycopg2") | 318 | @mock.patch.object(main.dbs, "psql") |
335 | def test_fetch_markets(self, psycopg2): | 319 | def test_fetch_markets(self, psql): |
336 | connect_mock = mock.Mock() | ||
337 | cursor_mock = mock.MagicMock() | 320 | cursor_mock = mock.MagicMock() |
338 | cursor_mock.__iter__.return_value = ["row_1", "row_2"] | 321 | cursor_mock.__iter__.return_value = ["row_1", "row_2"] |
339 | 322 | ||
340 | connect_mock.cursor.return_value = cursor_mock | 323 | psql.cursor.return_value = cursor_mock |
341 | psycopg2.connect.return_value = connect_mock | ||
342 | 324 | ||
343 | with self.subTest(user=None): | 325 | with self.subTest(user=None): |
344 | rows = list(main.fetch_markets({"foo": "bar"}, None)) | 326 | rows = list(main.fetch_markets(None)) |
345 | 327 | ||
346 | psycopg2.connect.assert_called_once_with(foo="bar") | ||
347 | cursor_mock.execute.assert_called_once_with("SELECT id,config,user_id FROM market_configs") | 328 | cursor_mock.execute.assert_called_once_with("SELECT id,config,user_id FROM market_configs") |
348 | 329 | ||
349 | self.assertEqual(["row_1", "row_2"], rows) | 330 | self.assertEqual(["row_1", "row_2"], rows) |
350 | 331 | ||
351 | psycopg2.connect.reset_mock() | ||
352 | cursor_mock.execute.reset_mock() | 332 | cursor_mock.execute.reset_mock() |
353 | with self.subTest(user=1): | 333 | with self.subTest(user=1): |
354 | rows = list(main.fetch_markets({"foo": "bar"}, 1)) | 334 | rows = list(main.fetch_markets(1)) |
355 | 335 | ||
356 | psycopg2.connect.assert_called_once_with(foo="bar") | ||
357 | cursor_mock.execute.assert_called_once_with("SELECT id,config,user_id FROM market_configs WHERE user_id = %s", 1) | 336 | cursor_mock.execute.assert_called_once_with("SELECT id,config,user_id FROM market_configs WHERE user_id = %s", 1) |
358 | 337 | ||
359 | self.assertEqual(["row_1", "row_2"], rows) | 338 | self.assertEqual(["row_1", "row_2"], rows) |
diff --git a/tests/test_market.py b/tests/test_market.py index 6a3322c..46fad53 100644 --- a/tests/test_market.py +++ b/tests/test_market.py | |||
@@ -1,5 +1,5 @@ | |||
1 | from .helper import * | 1 | from .helper import * |
2 | import market, store, portfolio | 2 | import market, store, portfolio, dbs |
3 | import datetime | 3 | import datetime |
4 | 4 | ||
5 | @unittest.skipUnless("unit" in limits, "Unit skipped") | 5 | @unittest.skipUnless("unit" in limits, "Unit skipped") |
@@ -595,13 +595,11 @@ class MarketTest(WebMockTestCase): | |||
595 | 595 | ||
596 | self.assertRegex(stdout_mock.getvalue(), "impossible to store report file: FileNotFoundError;") | 596 | self.assertRegex(stdout_mock.getvalue(), "impossible to store report file: FileNotFoundError;") |
597 | 597 | ||
598 | @mock.patch.object(market, "psycopg2") | 598 | @mock.patch.object(dbs, "psql") |
599 | def test_store_database_report(self, psycopg2): | 599 | def test_store_database_report(self, psql): |
600 | connect_mock = mock.Mock() | ||
601 | cursor_mock = mock.MagicMock() | 600 | cursor_mock = mock.MagicMock() |
602 | 601 | ||
603 | connect_mock.cursor.return_value = cursor_mock | 602 | psql.cursor.return_value = cursor_mock |
604 | psycopg2.connect.return_value = connect_mock | ||
605 | m = market.Market(self.ccxt, self.market_args(), | 603 | m = market.Market(self.ccxt, self.market_args(), |
606 | pg_config={"config": "pg_config"}, user_id=1) | 604 | pg_config={"config": "pg_config"}, user_id=1) |
607 | cursor_mock.fetchone.return_value = [42] | 605 | cursor_mock.fetchone.return_value = [42] |
@@ -613,7 +611,7 @@ class MarketTest(WebMockTestCase): | |||
613 | ("date2", "type2", "payload2"), | 611 | ("date2", "type2", "payload2"), |
614 | ] | 612 | ] |
615 | m.store_database_report(datetime.datetime(2018, 3, 24)) | 613 | m.store_database_report(datetime.datetime(2018, 3, 24)) |
616 | connect_mock.assert_has_calls([ | 614 | psql.assert_has_calls([ |
617 | mock.call.cursor(), | 615 | mock.call.cursor(), |
618 | mock.call.cursor().execute('INSERT INTO reports("date", "market_config_id", "debug") VALUES (%s, %s, %s) RETURNING id;', (datetime.datetime(2018, 3, 24), None, False)), | 616 | mock.call.cursor().execute('INSERT INTO reports("date", "market_config_id", "debug") VALUES (%s, %s, %s) RETURNING id;', (datetime.datetime(2018, 3, 24), None, False)), |
619 | mock.call.cursor().fetchone(), | 617 | mock.call.cursor().fetchone(), |
@@ -621,21 +619,16 @@ class MarketTest(WebMockTestCase): | |||
621 | mock.call.cursor().execute('INSERT INTO report_lines("date", "report_id", "type", "payload") VALUES (%s, %s, %s, %s);', ('date2', 42, 'type2', 'payload2')), | 619 | mock.call.cursor().execute('INSERT INTO report_lines("date", "report_id", "type", "payload") VALUES (%s, %s, %s, %s);', ('date2', 42, 'type2', 'payload2')), |
622 | mock.call.commit(), | 620 | mock.call.commit(), |
623 | mock.call.cursor().close(), | 621 | mock.call.cursor().close(), |
624 | mock.call.close() | ||
625 | ]) | 622 | ]) |
626 | 623 | ||
627 | connect_mock.reset_mock() | ||
628 | with self.subTest(error=True),\ | 624 | with self.subTest(error=True),\ |
629 | mock.patch('sys.stdout', new_callable=StringIO) as stdout_mock: | 625 | mock.patch('sys.stdout', new_callable=StringIO) as stdout_mock: |
630 | psycopg2.connect.side_effect = Exception("Bouh") | 626 | psql.cursor.side_effect = Exception("Bouh") |
631 | m.store_database_report(datetime.datetime(2018, 3, 24)) | 627 | m.store_database_report(datetime.datetime(2018, 3, 24)) |
632 | self.assertEqual(stdout_mock.getvalue(), "impossible to store report to database: Exception; Bouh\n") | 628 | self.assertEqual(stdout_mock.getvalue(), "impossible to store report to database: Exception; Bouh\n") |
633 | 629 | ||
634 | @mock.patch.object(market, "redis") | 630 | @mock.patch.object(dbs, "redis") |
635 | def test_store_redis_report(self, redis): | 631 | def test_store_redis_report(self, redis): |
636 | connect_mock = mock.Mock() | ||
637 | redis.Redis.return_value = connect_mock | ||
638 | |||
639 | m = market.Market(self.ccxt, self.market_args(), | 632 | m = market.Market(self.ccxt, self.market_args(), |
640 | redis_config={"config": "redis_config"}, market_id=1) | 633 | redis_config={"config": "redis_config"}, market_id=1) |
641 | 634 | ||
@@ -646,7 +639,7 @@ class MarketTest(WebMockTestCase): | |||
646 | ("type2", "payload2"), | 639 | ("type2", "payload2"), |
647 | ] | 640 | ] |
648 | m.store_redis_report(datetime.datetime(2018, 3, 24)) | 641 | m.store_redis_report(datetime.datetime(2018, 3, 24)) |
649 | connect_mock.assert_has_calls([ | 642 | redis.assert_has_calls([ |
650 | mock.call.set("/cryptoportfolio/1/2018-03-24T00:00:00/type1", "payload1", ex=31*24*60*60), | 643 | mock.call.set("/cryptoportfolio/1/2018-03-24T00:00:00/type1", "payload1", ex=31*24*60*60), |
651 | mock.call.set("/cryptoportfolio/1/latest/type1", "payload1"), | 644 | mock.call.set("/cryptoportfolio/1/latest/type1", "payload1"), |
652 | mock.call.set("/cryptoportfolio/1/2018-03-24T00:00:00/type2", "payload2", ex=31*24*60*60), | 645 | mock.call.set("/cryptoportfolio/1/2018-03-24T00:00:00/type2", "payload2", ex=31*24*60*60), |
@@ -654,20 +647,24 @@ class MarketTest(WebMockTestCase): | |||
654 | mock.call.set("/cryptoportfolio/1/latest/date", "2018-03-24T00:00:00"), | 647 | mock.call.set("/cryptoportfolio/1/latest/date", "2018-03-24T00:00:00"), |
655 | ]) | 648 | ]) |
656 | 649 | ||
657 | connect_mock.reset_mock() | 650 | redis.reset_mock() |
658 | with self.subTest(error=True),\ | 651 | with self.subTest(error=True),\ |
659 | mock.patch('sys.stdout', new_callable=StringIO) as stdout_mock: | 652 | mock.patch('sys.stdout', new_callable=StringIO) as stdout_mock: |
660 | redis.Redis.side_effect = Exception("Bouh") | 653 | redis.set.side_effect = Exception("Bouh") |
661 | m.store_redis_report(datetime.datetime(2018, 3, 24)) | 654 | m.store_redis_report(datetime.datetime(2018, 3, 24)) |
662 | self.assertEqual(stdout_mock.getvalue(), "impossible to store report to redis: Exception; Bouh\n") | 655 | self.assertEqual(stdout_mock.getvalue(), "impossible to store report to redis: Exception; Bouh\n") |
663 | 656 | ||
664 | def test_store_report(self): | 657 | def test_store_report(self): |
665 | m = market.Market(self.ccxt, self.market_args(report_db=False), user_id=1) | 658 | m = market.Market(self.ccxt, self.market_args(report_db=False), user_id=1) |
666 | with self.subTest(file=None, pg_config=None),\ | 659 | with self.subTest(file=None, pg_connected=None),\ |
660 | mock.patch.object(dbs, "psql_connected") as psql,\ | ||
661 | mock.patch.object(dbs, "redis_connected") as redis,\ | ||
667 | mock.patch.object(m, "report") as report,\ | 662 | mock.patch.object(m, "report") as report,\ |
668 | mock.patch.object(m, "store_database_report") as db_report,\ | 663 | mock.patch.object(m, "store_database_report") as db_report,\ |
669 | mock.patch.object(m, "store_redis_report") as redis_report,\ | 664 | mock.patch.object(m, "store_redis_report") as redis_report,\ |
670 | mock.patch.object(m, "store_file_report") as file_report: | 665 | mock.patch.object(m, "store_file_report") as file_report: |
666 | psql.return_value = False | ||
667 | redis.return_value = False | ||
671 | m.store_report() | 668 | m.store_report() |
672 | report.merge.assert_called_with(store.Portfolio.report) | 669 | report.merge.assert_called_with(store.Portfolio.report) |
673 | 670 | ||
@@ -677,13 +674,16 @@ class MarketTest(WebMockTestCase): | |||
677 | 674 | ||
678 | report.reset_mock() | 675 | report.reset_mock() |
679 | m = market.Market(self.ccxt, self.market_args(report_db=False, report_path="present"), user_id=1) | 676 | m = market.Market(self.ccxt, self.market_args(report_db=False, report_path="present"), user_id=1) |
680 | with self.subTest(file="present", pg_config=None),\ | 677 | with self.subTest(file="present", pg_connected=None),\ |
678 | mock.patch.object(dbs, "psql_connected") as psql,\ | ||
679 | mock.patch.object(dbs, "redis_connected") as redis,\ | ||
681 | mock.patch.object(m, "report") as report,\ | 680 | mock.patch.object(m, "report") as report,\ |
682 | mock.patch.object(m, "store_file_report") as file_report,\ | 681 | mock.patch.object(m, "store_file_report") as file_report,\ |
683 | mock.patch.object(m, "store_redis_report") as redis_report,\ | 682 | mock.patch.object(m, "store_redis_report") as redis_report,\ |
684 | mock.patch.object(m, "store_database_report") as db_report,\ | 683 | mock.patch.object(m, "store_database_report") as db_report,\ |
685 | mock.patch.object(market.datetime, "datetime") as time_mock: | 684 | mock.patch.object(market.datetime, "datetime") as time_mock: |
686 | 685 | psql.return_value = False | |
686 | redis.return_value = False | ||
687 | time_mock.now.return_value = datetime.datetime(2018, 2, 25) | 687 | time_mock.now.return_value = datetime.datetime(2018, 2, 25) |
688 | 688 | ||
689 | m.store_report() | 689 | m.store_report() |
@@ -695,13 +695,16 @@ class MarketTest(WebMockTestCase): | |||
695 | 695 | ||
696 | report.reset_mock() | 696 | report.reset_mock() |
697 | m = market.Market(self.ccxt, self.market_args(report_db=True, report_path="present"), user_id=1) | 697 | m = market.Market(self.ccxt, self.market_args(report_db=True, report_path="present"), user_id=1) |
698 | with self.subTest(file="present", pg_config=None, report_db=True),\ | 698 | with self.subTest(file="present", pg_connected=None, report_db=True),\ |
699 | mock.patch.object(dbs, "psql_connected") as psql,\ | ||
700 | mock.patch.object(dbs, "redis_connected") as redis,\ | ||
699 | mock.patch.object(m, "report") as report,\ | 701 | mock.patch.object(m, "report") as report,\ |
700 | mock.patch.object(m, "store_file_report") as file_report,\ | 702 | mock.patch.object(m, "store_file_report") as file_report,\ |
701 | mock.patch.object(m, "store_redis_report") as redis_report,\ | 703 | mock.patch.object(m, "store_redis_report") as redis_report,\ |
702 | mock.patch.object(m, "store_database_report") as db_report,\ | 704 | mock.patch.object(m, "store_database_report") as db_report,\ |
703 | mock.patch.object(market.datetime, "datetime") as time_mock: | 705 | mock.patch.object(market.datetime, "datetime") as time_mock: |
704 | 706 | psql.return_value = False | |
707 | redis.return_value = False | ||
705 | time_mock.now.return_value = datetime.datetime(2018, 2, 25) | 708 | time_mock.now.return_value = datetime.datetime(2018, 2, 25) |
706 | 709 | ||
707 | m.store_report() | 710 | m.store_report() |
@@ -712,14 +715,17 @@ class MarketTest(WebMockTestCase): | |||
712 | redis_report.assert_not_called() | 715 | redis_report.assert_not_called() |
713 | 716 | ||
714 | report.reset_mock() | 717 | report.reset_mock() |
715 | m = market.Market(self.ccxt, self.market_args(report_db=True), pg_config="present", user_id=1) | 718 | m = market.Market(self.ccxt, self.market_args(report_db=True), user_id=1) |
716 | with self.subTest(file=None, pg_config="present"),\ | 719 | with self.subTest(file=None, pg_connected=True),\ |
720 | mock.patch.object(dbs, "psql_connected") as psql,\ | ||
721 | mock.patch.object(dbs, "redis_connected") as redis,\ | ||
717 | mock.patch.object(m, "report") as report,\ | 722 | mock.patch.object(m, "report") as report,\ |
718 | mock.patch.object(m, "store_file_report") as file_report,\ | 723 | mock.patch.object(m, "store_file_report") as file_report,\ |
719 | mock.patch.object(m, "store_redis_report") as redis_report,\ | 724 | mock.patch.object(m, "store_redis_report") as redis_report,\ |
720 | mock.patch.object(m, "store_database_report") as db_report,\ | 725 | mock.patch.object(m, "store_database_report") as db_report,\ |
721 | mock.patch.object(market.datetime, "datetime") as time_mock: | 726 | mock.patch.object(market.datetime, "datetime") as time_mock: |
722 | 727 | psql.return_value = True | |
728 | redis.return_value = False | ||
723 | time_mock.now.return_value = datetime.datetime(2018, 2, 25) | 729 | time_mock.now.return_value = datetime.datetime(2018, 2, 25) |
724 | 730 | ||
725 | m.store_report() | 731 | m.store_report() |
@@ -731,14 +737,17 @@ class MarketTest(WebMockTestCase): | |||
731 | 737 | ||
732 | report.reset_mock() | 738 | report.reset_mock() |
733 | m = market.Market(self.ccxt, self.market_args(report_db=True, report_path="present"), | 739 | m = market.Market(self.ccxt, self.market_args(report_db=True, report_path="present"), |
734 | pg_config="pg_config", user_id=1) | 740 | user_id=1) |
735 | with self.subTest(file="present", pg_config="present"),\ | 741 | with self.subTest(file="present", pg_connected=True),\ |
742 | mock.patch.object(dbs, "psql_connected") as psql,\ | ||
743 | mock.patch.object(dbs, "redis_connected") as redis,\ | ||
736 | mock.patch.object(m, "report") as report,\ | 744 | mock.patch.object(m, "report") as report,\ |
737 | mock.patch.object(m, "store_file_report") as file_report,\ | 745 | mock.patch.object(m, "store_file_report") as file_report,\ |
738 | mock.patch.object(m, "store_redis_report") as redis_report,\ | 746 | mock.patch.object(m, "store_redis_report") as redis_report,\ |
739 | mock.patch.object(m, "store_database_report") as db_report,\ | 747 | mock.patch.object(m, "store_database_report") as db_report,\ |
740 | mock.patch.object(market.datetime, "datetime") as time_mock: | 748 | mock.patch.object(market.datetime, "datetime") as time_mock: |
741 | 749 | psql.return_value = True | |
750 | redis.return_value = False | ||
742 | time_mock.now.return_value = datetime.datetime(2018, 2, 25) | 751 | time_mock.now.return_value = datetime.datetime(2018, 2, 25) |
743 | 752 | ||
744 | m.store_report() | 753 | m.store_report() |
@@ -750,14 +759,17 @@ class MarketTest(WebMockTestCase): | |||
750 | 759 | ||
751 | report.reset_mock() | 760 | report.reset_mock() |
752 | m = market.Market(self.ccxt, self.market_args(report_redis=False), | 761 | m = market.Market(self.ccxt, self.market_args(report_redis=False), |
753 | redis_config="redis_config", user_id=1) | 762 | user_id=1) |
754 | with self.subTest(redis_config="present", report_redis=False),\ | 763 | with self.subTest(redis_connected=True, report_redis=False),\ |
764 | mock.patch.object(dbs, "psql_connected") as psql,\ | ||
765 | mock.patch.object(dbs, "redis_connected") as redis,\ | ||
755 | mock.patch.object(m, "report") as report,\ | 766 | mock.patch.object(m, "report") as report,\ |
756 | mock.patch.object(m, "store_file_report") as file_report,\ | 767 | mock.patch.object(m, "store_file_report") as file_report,\ |
757 | mock.patch.object(m, "store_redis_report") as redis_report,\ | 768 | mock.patch.object(m, "store_redis_report") as redis_report,\ |
758 | mock.patch.object(m, "store_database_report") as db_report,\ | 769 | mock.patch.object(m, "store_database_report") as db_report,\ |
759 | mock.patch.object(market.datetime, "datetime") as time_mock: | 770 | mock.patch.object(market.datetime, "datetime") as time_mock: |
760 | 771 | psql.return_value = False | |
772 | redis.return_value = True | ||
761 | time_mock.now.return_value = datetime.datetime(2018, 2, 25) | 773 | time_mock.now.return_value = datetime.datetime(2018, 2, 25) |
762 | 774 | ||
763 | m.store_report() | 775 | m.store_report() |
@@ -766,13 +778,16 @@ class MarketTest(WebMockTestCase): | |||
766 | report.reset_mock() | 778 | report.reset_mock() |
767 | m = market.Market(self.ccxt, self.market_args(report_redis=True), | 779 | m = market.Market(self.ccxt, self.market_args(report_redis=True), |
768 | user_id=1) | 780 | user_id=1) |
769 | with self.subTest(redis_config="absent", report_redis=True),\ | 781 | with self.subTest(redis_connected=False, report_redis=True),\ |
782 | mock.patch.object(dbs, "psql_connected") as psql,\ | ||
783 | mock.patch.object(dbs, "redis_connected") as redis,\ | ||
770 | mock.patch.object(m, "report") as report,\ | 784 | mock.patch.object(m, "report") as report,\ |
771 | mock.patch.object(m, "store_file_report") as file_report,\ | 785 | mock.patch.object(m, "store_file_report") as file_report,\ |
772 | mock.patch.object(m, "store_redis_report") as redis_report,\ | 786 | mock.patch.object(m, "store_redis_report") as redis_report,\ |
773 | mock.patch.object(m, "store_database_report") as db_report,\ | 787 | mock.patch.object(m, "store_database_report") as db_report,\ |
774 | mock.patch.object(market.datetime, "datetime") as time_mock: | 788 | mock.patch.object(market.datetime, "datetime") as time_mock: |
775 | 789 | psql.return_value = False | |
790 | redis.return_value = False | ||
776 | time_mock.now.return_value = datetime.datetime(2018, 2, 25) | 791 | time_mock.now.return_value = datetime.datetime(2018, 2, 25) |
777 | 792 | ||
778 | m.store_report() | 793 | m.store_report() |
@@ -780,14 +795,17 @@ class MarketTest(WebMockTestCase): | |||
780 | 795 | ||
781 | report.reset_mock() | 796 | report.reset_mock() |
782 | m = market.Market(self.ccxt, self.market_args(report_redis=True), | 797 | m = market.Market(self.ccxt, self.market_args(report_redis=True), |
783 | redis_config="redis_config", user_id=1) | 798 | user_id=1) |
784 | with self.subTest(redis_config="present", report_redis=True),\ | 799 | with self.subTest(redis_connected=True, report_redis=True),\ |
800 | mock.patch.object(dbs, "psql_connected") as psql,\ | ||
801 | mock.patch.object(dbs, "redis_connected") as redis,\ | ||
785 | mock.patch.object(m, "report") as report,\ | 802 | mock.patch.object(m, "report") as report,\ |
786 | mock.patch.object(m, "store_file_report") as file_report,\ | 803 | mock.patch.object(m, "store_file_report") as file_report,\ |
787 | mock.patch.object(m, "store_redis_report") as redis_report,\ | 804 | mock.patch.object(m, "store_redis_report") as redis_report,\ |
788 | mock.patch.object(m, "store_database_report") as db_report,\ | 805 | mock.patch.object(m, "store_database_report") as db_report,\ |
789 | mock.patch.object(market.datetime, "datetime") as time_mock: | 806 | mock.patch.object(market.datetime, "datetime") as time_mock: |
790 | 807 | psql.return_value = False | |
808 | redis.return_value = True | ||
791 | time_mock.now.return_value = datetime.datetime(2018, 2, 25) | 809 | time_mock.now.return_value = datetime.datetime(2018, 2, 25) |
792 | 810 | ||
793 | m.store_report() | 811 | m.store_report() |
@@ -1014,6 +1032,15 @@ class ProcessorTest(WebMockTestCase): | |||
1014 | processor.process_step("foo", step, {"foo":"bar"}) | 1032 | processor.process_step("foo", step, {"foo":"bar"}) |
1015 | self.m.balances.fetch_balances.assert_not_called() | 1033 | self.m.balances.fetch_balances.assert_not_called() |
1016 | 1034 | ||
1035 | self.m.reset_mock() | ||
1036 | with mock.patch.object(processor, "run_action") as run_action: | ||
1037 | step = processor.scenarios["print_balances"][0] | ||
1038 | |||
1039 | processor.process_step("foo", step, {"foo":"bar"}) | ||
1040 | self.m.balances.fetch_balances.assert_called_once_with( | ||
1041 | add_portfolio=True, log_tickers=True, | ||
1042 | tag='process_foo__1_print_balances_begin') | ||
1043 | |||
1017 | def test_parse_args(self): | 1044 | def test_parse_args(self): |
1018 | processor = market.Processor(self.m) | 1045 | processor = market.Processor(self.m) |
1019 | 1046 | ||
diff --git a/tests/test_store.py b/tests/test_store.py index 12999d3..ee7e063 100644 --- a/tests/test_store.py +++ b/tests/test_store.py | |||
@@ -391,6 +391,18 @@ class BalanceStoreTest(WebMockTestCase): | |||
391 | tag=None, ticker_currency='FOO', tickers='tickers', | 391 | tag=None, ticker_currency='FOO', tickers='tickers', |
392 | type='type') | 392 | type='type') |
393 | 393 | ||
394 | balance_store = market.BalanceStore(self.m) | ||
395 | |||
396 | with self.subTest(add_portfolio=True),\ | ||
397 | mock.patch.object(market.Portfolio, "repartition") as repartition: | ||
398 | repartition.return_value = { | ||
399 | "DOGE": D("0.5"), | ||
400 | "USDT": D("0.5"), | ||
401 | } | ||
402 | balance_store.fetch_balances(add_portfolio=True) | ||
403 | self.assertListEqual(["USDT", "XVG", "XMR", "DOGE"], list(balance_store.currencies())) | ||
404 | |||
405 | |||
394 | @mock.patch.object(market.Portfolio, "repartition") | 406 | @mock.patch.object(market.Portfolio, "repartition") |
395 | def test_dispatch_assets(self, repartition): | 407 | def test_dispatch_assets(self, repartition): |
396 | self.m.ccxt.fetch_all_balances.return_value = self.fetch_balance | 408 | self.m.ccxt.fetch_all_balances.return_value = self.fetch_balance |
@@ -1101,7 +1113,8 @@ class PortfolioTest(WebMockTestCase): | |||
1101 | self.wm.get(market.Portfolio.URL, text=self.json_response) | 1113 | self.wm.get(market.Portfolio.URL, text=self.json_response) |
1102 | 1114 | ||
1103 | @mock.patch.object(market.Portfolio, "parse_cryptoportfolio") | 1115 | @mock.patch.object(market.Portfolio, "parse_cryptoportfolio") |
1104 | def test_get_cryptoportfolio(self, parse_cryptoportfolio): | 1116 | @mock.patch.object(market.Portfolio, "store_cryptoportfolio") |
1117 | def test_get_cryptoportfolio(self, store_cryptoportfolio, parse_cryptoportfolio): | ||
1105 | with self.subTest(parallel=False): | 1118 | with self.subTest(parallel=False): |
1106 | self.wm.get(market.Portfolio.URL, [ | 1119 | self.wm.get(market.Portfolio.URL, [ |
1107 | {"text":'{ "foo": "bar" }', "status_code": 200}, | 1120 | {"text":'{ "foo": "bar" }', "status_code": 200}, |
@@ -1116,23 +1129,28 @@ class PortfolioTest(WebMockTestCase): | |||
1116 | market.Portfolio.report.log_error.assert_not_called() | 1129 | market.Portfolio.report.log_error.assert_not_called() |
1117 | market.Portfolio.report.log_http_request.assert_called_once() | 1130 | market.Portfolio.report.log_http_request.assert_called_once() |
1118 | parse_cryptoportfolio.assert_called_once_with() | 1131 | parse_cryptoportfolio.assert_called_once_with() |
1132 | store_cryptoportfolio.assert_called_once_with() | ||
1119 | market.Portfolio.report.log_http_request.reset_mock() | 1133 | market.Portfolio.report.log_http_request.reset_mock() |
1120 | parse_cryptoportfolio.reset_mock() | 1134 | parse_cryptoportfolio.reset_mock() |
1135 | store_cryptoportfolio.reset_mock() | ||
1121 | market.Portfolio.data = store.LockedVar(None) | 1136 | market.Portfolio.data = store.LockedVar(None) |
1122 | 1137 | ||
1123 | market.Portfolio.get_cryptoportfolio() | 1138 | market.Portfolio.get_cryptoportfolio() |
1124 | self.assertIsNone(market.Portfolio.data.get()) | 1139 | self.assertIsNone(market.Portfolio.data.get()) |
1125 | self.assertEqual(2, self.wm.call_count) | 1140 | self.assertEqual(2, self.wm.call_count) |
1126 | parse_cryptoportfolio.assert_not_called() | 1141 | parse_cryptoportfolio.assert_not_called() |
1142 | store_cryptoportfolio.assert_not_called() | ||
1127 | market.Portfolio.report.log_error.assert_not_called() | 1143 | market.Portfolio.report.log_error.assert_not_called() |
1128 | market.Portfolio.report.log_http_request.assert_called_once() | 1144 | market.Portfolio.report.log_http_request.assert_called_once() |
1129 | market.Portfolio.report.log_http_request.reset_mock() | 1145 | market.Portfolio.report.log_http_request.reset_mock() |
1130 | parse_cryptoportfolio.reset_mock() | 1146 | parse_cryptoportfolio.reset_mock() |
1147 | store_cryptoportfolio.reset_mock() | ||
1131 | 1148 | ||
1132 | market.Portfolio.data = store.LockedVar("Foo") | 1149 | market.Portfolio.data = store.LockedVar("Foo") |
1133 | market.Portfolio.get_cryptoportfolio() | 1150 | market.Portfolio.get_cryptoportfolio() |
1134 | self.assertEqual(2, self.wm.call_count) | 1151 | self.assertEqual(2, self.wm.call_count) |
1135 | parse_cryptoportfolio.assert_not_called() | 1152 | parse_cryptoportfolio.assert_not_called() |
1153 | store_cryptoportfolio.assert_not_called() | ||
1136 | 1154 | ||
1137 | market.Portfolio.get_cryptoportfolio(refetch=True) | 1155 | market.Portfolio.get_cryptoportfolio(refetch=True) |
1138 | self.assertEqual("Foo", market.Portfolio.data.get()) | 1156 | self.assertEqual("Foo", market.Portfolio.data.get()) |
@@ -1153,6 +1171,7 @@ class PortfolioTest(WebMockTestCase): | |||
1153 | market.Portfolio.get_cryptoportfolio() | 1171 | market.Portfolio.get_cryptoportfolio() |
1154 | self.assertIn("foo", market.Portfolio.data.get()) | 1172 | self.assertIn("foo", market.Portfolio.data.get()) |
1155 | parse_cryptoportfolio.reset_mock() | 1173 | parse_cryptoportfolio.reset_mock() |
1174 | store_cryptoportfolio.reset_mock() | ||
1156 | with self.subTest(worker=False): | 1175 | with self.subTest(worker=False): |
1157 | market.Portfolio.data = store.LockedVar(None) | 1176 | market.Portfolio.data = store.LockedVar(None) |
1158 | market.Portfolio.worker = mock.Mock() | 1177 | market.Portfolio.worker = mock.Mock() |
@@ -1160,6 +1179,7 @@ class PortfolioTest(WebMockTestCase): | |||
1160 | market.Portfolio.get_cryptoportfolio() | 1179 | market.Portfolio.get_cryptoportfolio() |
1161 | notify.assert_called_once_with() | 1180 | notify.assert_called_once_with() |
1162 | parse_cryptoportfolio.assert_not_called() | 1181 | parse_cryptoportfolio.assert_not_called() |
1182 | store_cryptoportfolio.assert_not_called() | ||
1163 | 1183 | ||
1164 | def test_parse_cryptoportfolio(self): | 1184 | def test_parse_cryptoportfolio(self): |
1165 | with self.subTest(description="Normal case"): | 1185 | with self.subTest(description="Normal case"): |
@@ -1223,25 +1243,95 @@ class PortfolioTest(WebMockTestCase): | |||
1223 | self.assertEqual({}, market.Portfolio.liquidities.get("high")) | 1243 | self.assertEqual({}, market.Portfolio.liquidities.get("high")) |
1224 | self.assertEqual(datetime.datetime(1,1,1), market.Portfolio.last_date.get()) | 1244 | self.assertEqual(datetime.datetime(1,1,1), market.Portfolio.last_date.get()) |
1225 | 1245 | ||
1226 | 1246 | @mock.patch.object(store.dbs, "redis_connected") | |
1227 | @mock.patch.object(market.Portfolio, "get_cryptoportfolio") | 1247 | @mock.patch.object(store.dbs, "redis") |
1228 | def test_repartition(self, get_cryptoportfolio): | 1248 | def test_store_cryptoportfolio(self, redis, redis_connected): |
1229 | market.Portfolio.liquidities = store.LockedVar({ | 1249 | store.Portfolio.liquidities = store.LockedVar({ |
1230 | "medium": { | 1250 | "medium": { |
1231 | "2018-03-01": "medium_2018-03-01", | 1251 | datetime.datetime(2018,3,1): "medium_2018-03-01", |
1232 | "2018-03-08": "medium_2018-03-08", | 1252 | datetime.datetime(2018,3,8): "medium_2018-03-08", |
1233 | }, | 1253 | }, |
1234 | "high": { | 1254 | "high": { |
1235 | "2018-03-01": "high_2018-03-01", | 1255 | datetime.datetime(2018,3,1): "high_2018-03-01", |
1236 | "2018-03-08": "high_2018-03-08", | 1256 | datetime.datetime(2018,3,8): "high_2018-03-08", |
1237 | } | 1257 | } |
1238 | }) | 1258 | }) |
1239 | market.Portfolio.last_date = store.LockedVar("2018-03-08") | 1259 | store.Portfolio.last_date = store.LockedVar(datetime.datetime(2018,3,8)) |
1260 | |||
1261 | with self.subTest(redis_connected=False): | ||
1262 | redis_connected.return_value = False | ||
1263 | store.Portfolio.store_cryptoportfolio() | ||
1264 | redis.set.assert_not_called() | ||
1265 | |||
1266 | with self.subTest(redis_connected=True): | ||
1267 | redis_connected.return_value = True | ||
1268 | store.Portfolio.store_cryptoportfolio() | ||
1269 | redis.set.assert_has_calls([ | ||
1270 | mock.call("/cryptoportfolio/repartition/latest", '{"medium": "medium_2018-03-08", "high": "high_2018-03-08"}'), | ||
1271 | mock.call("/cryptoportfolio/repartition/date", "2018-03-08"), | ||
1272 | ]) | ||
1273 | |||
1274 | @mock.patch.object(store.dbs, "redis_connected") | ||
1275 | @mock.patch.object(store.dbs, "redis") | ||
1276 | def test_retrieve_cryptoportfolio(self, redis, redis_connected): | ||
1277 | with self.subTest(redis_connected=False): | ||
1278 | redis_connected.return_value = False | ||
1279 | store.Portfolio.retrieve_cryptoportfolio() | ||
1280 | redis.get.assert_not_called() | ||
1281 | self.assertIsNone(store.Portfolio.data.get()) | ||
1282 | |||
1283 | with self.subTest(redis_connected=True, value=None): | ||
1284 | redis_connected.return_value = True | ||
1285 | redis.get.return_value = None | ||
1286 | store.Portfolio.retrieve_cryptoportfolio() | ||
1287 | self.assertEqual(2, redis.get.call_count) | ||
1288 | |||
1289 | redis.reset_mock() | ||
1290 | with self.subTest(redis_connected=True, value="present"): | ||
1291 | redis_connected.return_value = True | ||
1292 | redis.get.side_effect = [ | ||
1293 | b'{ "medium": "medium_repartition", "high": "high_repartition" }', | ||
1294 | b"2018-03-08" | ||
1295 | ] | ||
1296 | store.Portfolio.retrieve_cryptoportfolio() | ||
1297 | self.assertEqual(2, redis.get.call_count) | ||
1298 | self.assertEqual(datetime.datetime(2018,3,8), store.Portfolio.last_date.get()) | ||
1299 | self.assertEqual("", store.Portfolio.data.get()) | ||
1300 | expected_liquidities = { | ||
1301 | 'high': { datetime.datetime(2018, 3, 8): 'high_repartition' }, | ||
1302 | 'medium': { datetime.datetime(2018, 3, 8): 'medium_repartition' }, | ||
1303 | } | ||
1304 | self.assertEqual(expected_liquidities, store.Portfolio.liquidities.get()) | ||
1305 | |||
1306 | @mock.patch.object(market.Portfolio, "get_cryptoportfolio") | ||
1307 | @mock.patch.object(market.Portfolio, "retrieve_cryptoportfolio") | ||
1308 | def test_repartition(self, retrieve_cryptoportfolio, get_cryptoportfolio): | ||
1309 | with self.subTest(from_cache=False): | ||
1310 | market.Portfolio.liquidities = store.LockedVar({ | ||
1311 | "medium": { | ||
1312 | "2018-03-01": "medium_2018-03-01", | ||
1313 | "2018-03-08": "medium_2018-03-08", | ||
1314 | }, | ||
1315 | "high": { | ||
1316 | "2018-03-01": "high_2018-03-01", | ||
1317 | "2018-03-08": "high_2018-03-08", | ||
1318 | } | ||
1319 | }) | ||
1320 | market.Portfolio.last_date = store.LockedVar("2018-03-08") | ||
1321 | |||
1322 | self.assertEqual("medium_2018-03-08", market.Portfolio.repartition()) | ||
1323 | get_cryptoportfolio.assert_called_once_with() | ||
1324 | retrieve_cryptoportfolio.assert_not_called() | ||
1325 | self.assertEqual("medium_2018-03-08", market.Portfolio.repartition(liquidity="medium")) | ||
1326 | self.assertEqual("high_2018-03-08", market.Portfolio.repartition(liquidity="high")) | ||
1327 | |||
1328 | retrieve_cryptoportfolio.reset_mock() | ||
1329 | get_cryptoportfolio.reset_mock() | ||
1240 | 1330 | ||
1241 | self.assertEqual("medium_2018-03-08", market.Portfolio.repartition()) | 1331 | with self.subTest(from_cache=True): |
1242 | get_cryptoportfolio.assert_called_once_with() | 1332 | self.assertEqual("medium_2018-03-08", market.Portfolio.repartition(from_cache=True)) |
1243 | self.assertEqual("medium_2018-03-08", market.Portfolio.repartition(liquidity="medium")) | 1333 | get_cryptoportfolio.assert_called_once_with() |
1244 | self.assertEqual("high_2018-03-08", market.Portfolio.repartition(liquidity="high")) | 1334 | retrieve_cryptoportfolio.assert_called_once_with() |
1245 | 1335 | ||
1246 | @mock.patch.object(market.time, "sleep") | 1336 | @mock.patch.object(market.time, "sleep") |
1247 | @mock.patch.object(market.Portfolio, "get_cryptoportfolio") | 1337 | @mock.patch.object(market.Portfolio, "get_cryptoportfolio") |