From e7d7c0e5645da35adcbfec9e51deb68f012c422f Mon Sep 17 00:00:00 2001 From: =?utf8?q?Isma=C3=ABl=20Bouya?= Date: Sat, 7 Apr 2018 17:39:29 +0200 Subject: [PATCH] Acceptance test preparation Save some headers for http requests Wait for all threads after the end of main Simplify library imports for mocking --- ccxt_wrapper.py | 2 + main.py | 11 ++- market.py | 5 +- portfolio.py | 6 +- store.py | 72 +++++++++++----- tests/test_ccxt_wrapper.py | 12 ++- tests/test_main.py | 4 +- tests/test_market.py | 8 +- tests/test_portfolio.py | 4 +- tests/test_store.py | 167 ++++++++++++++++++++++++++----------- 10 files changed, 204 insertions(+), 87 deletions(-) diff --git a/ccxt_wrapper.py b/ccxt_wrapper.py index bedf84b..366586c 100644 --- a/ccxt_wrapper.py +++ b/ccxt_wrapper.py @@ -47,6 +47,8 @@ class poloniexE(poloniex): self.session._parent = self def request_wrap(self, *args, **kwargs): + kwargs["headers"]["X-market-id"] = str(self._parent._market.market_id) + kwargs["headers"]["X-user-id"] = str(self._parent._market.user_id) try: r = self.origin_request(*args, **kwargs) self._parent._market.report.log_http_request(args[0], diff --git a/main.py b/main.py index 6383ed1..2cfb01d 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,3 @@ -from datetime import datetime import configargparse import psycopg2 import os @@ -170,13 +169,21 @@ def main(argv): import threading market.Portfolio.start_worker() + threads = [] def process_(*args): - threading.Thread(target=process, args=args).start() + thread = threading.Thread(target=process, args=args) + thread.start() + threads.append(thread) else: process_ = process for market_id, market_config, user_id in fetch_markets(pg_config, args.user): process_(market_config, market_id, user_id, args, pg_config) + if args.parallel: + for thread in threads: + thread.join() + market.Portfolio.stop_worker() + if __name__ == '__main__': # pragma: no cover main(sys.argv[1:]) diff --git a/market.py b/market.py index e16641c..7a37cf6 100644 --- a/market.py +++ b/market.py @@ -5,6 +5,7 @@ import psycopg2 from store import * from cachetools.func import ttl_cache from datetime import datetime +import datetime from retry import retry import portfolio @@ -28,7 +29,7 @@ class Market: for key in ["user_id", "market_id", "pg_config"]: setattr(self, key, kwargs.get(key, None)) - self.report.log_market(self.args, self.user_id, self.market_id) + self.report.log_market(self.args) @classmethod def from_config(cls, config, args, **kwargs): @@ -40,7 +41,7 @@ class Market: def store_report(self): self.report.merge(Portfolio.report) - date = datetime.now() + date = datetime.datetime.now() if self.args.report_path is not None: self.store_file_report(date) if self.pg_config is not None and self.args.report_db: diff --git a/portfolio.py b/portfolio.py index 535aaa8..146ee79 100644 --- a/portfolio.py +++ b/portfolio.py @@ -1,4 +1,4 @@ -from datetime import datetime +import datetime from retry import retry from decimal import Decimal as D, ROUND_DOWN from ccxt import ExchangeError, InsufficientFunds, ExchangeNotAvailable, InvalidOrder, OrderNotCached, OrderNotFound, RequestTimeout, InvalidNonce @@ -492,7 +492,7 @@ class Order: self.market.report.log_debug_action(action) self.results.append({"debug": True, "id": -1}) else: - self.start_date = datetime.now() + self.start_date = datetime.datetime.now() try: self.results.append(self.market.ccxt.create_order(symbol, 'limit', self.action, amount, price=self.rate, account=self.account)) except InvalidOrder: @@ -677,7 +677,7 @@ class Mouvement: self.action = hash_.get("type") self.fee_rate = D(hash_.get("fee", -1)) try: - self.date = datetime.strptime(hash_.get("date", ""), '%Y-%m-%d %H:%M:%S') + self.date = datetime.datetime.strptime(hash_.get("date", ""), '%Y-%m-%d %H:%M:%S') except ValueError: self.date = None self.rate = D(hash_.get("rate", 0)) diff --git a/store.py b/store.py index 67e8a8f..467dd4b 100644 --- a/store.py +++ b/store.py @@ -3,7 +3,7 @@ import requests import portfolio import simplejson as json from decimal import Decimal as D, ROUND_DOWN -from datetime import date, datetime, timedelta +import datetime import inspect from json import JSONDecodeError from simplejson.errors import JSONDecodeError as SimpleJSONDecodeError @@ -11,13 +11,16 @@ from simplejson.errors import JSONDecodeError as SimpleJSONDecodeError __all__ = ["Portfolio", "BalanceStore", "ReportStore", "TradeStore"] class ReportStore: - def __init__(self, market, verbose_print=True): + def __init__(self, market, verbose_print=True, no_http_dup=False): self.market = market self.verbose_print = verbose_print self.print_logs = [] self.logs = [] + self.no_http_dup = no_http_dup + self.last_http = None + def merge(self, other_report): self.logs += other_report.logs self.logs.sort(key=lambda x: x["date"]) @@ -26,19 +29,26 @@ class ReportStore: self.print_logs.sort(key=lambda x: x[0]) def print_log(self, message): - now = datetime.now() + now = datetime.datetime.now() message = "{:%Y-%m-%d %H:%M:%S}: {}".format(now, str(message)) self.print_logs.append([now, message]) if self.verbose_print: print(message) def add_log(self, hash_): - hash_["date"] = datetime.now() + hash_["date"] = datetime.datetime.now() + if self.market is not None: + hash_["user_id"] = self.market.user_id + hash_["market_id"] = self.market.market_id + else: + hash_["user_id"] = None + hash_["market_id"] = None self.logs.append(hash_) + return hash_ @staticmethod def default_json_serial(obj): - if isinstance(obj, (datetime, date)): + if isinstance(obj, (datetime.datetime, datetime.date)): return obj.isoformat() return str(obj) @@ -188,7 +198,12 @@ class ReportStore: "error": response.__class__.__name__, "error_message": str(response), }) - else: + self.last_http = None + elif self.no_http_dup and \ + self.last_http is not None and \ + self.last_http["url"] == url and \ + self.last_http["method"] == method and \ + self.last_http["response"] == response.text: self.add_log({ "type": "http_request", "method": method, @@ -196,7 +211,19 @@ class ReportStore: "body": body, "headers": headers, "status": response.status_code, - "response": response.text + "response": None, + "response_same_as": self.last_http["date"] + }) + else: + self.last_http = self.add_log({ + "type": "http_request", + "method": method, + "url": url, + "body": body, + "headers": headers, + "status": response.status_code, + "response": response.text, + "response_same_as": None, }) def log_error(self, action, message=None, exception=None): @@ -222,13 +249,11 @@ class ReportStore: "action": action, }) - def log_market(self, args, user_id, market_id): + def log_market(self, args): self.add_log({ "type": "market", "commit": "$Format:%H$", "args": vars(args), - "user_id": user_id, - "market_id": market_id, }) class BalanceStore: @@ -382,7 +407,7 @@ class Portfolio: data = LockedVar(None) liquidities = LockedVar({}) last_date = LockedVar(None) - report = LockedVar(ReportStore(None)) + report = LockedVar(ReportStore(None, no_http_dup=True)) worker = None worker_started = False worker_notify = None @@ -418,11 +443,17 @@ class Portfolio: raise RuntimeError("This method needs to be ran with the worker") while cls.worker_started: cls.worker_notify.wait() - cls.worker_notify.clear() - cls.report.print_log("Fetching cryptoportfolio") - cls.get_cryptoportfolio(refetch=True) - cls.callback.set() - time.sleep(poll) + if cls.worker_started: + cls.worker_notify.clear() + cls.report.print_log("Fetching cryptoportfolio") + cls.get_cryptoportfolio(refetch=True) + cls.callback.set() + time.sleep(poll) + + @classmethod + def stop_worker(cls): + cls.worker_started = False + cls.worker_notify.set() @classmethod def notify_and_wait(cls): @@ -433,7 +464,7 @@ class Portfolio: @classmethod def wait_for_recent(cls, delta=4, poll=30): cls.get_cryptoportfolio() - while cls.last_date.get() is None or datetime.now() - cls.last_date.get() > timedelta(delta): + while cls.last_date.get() is None or datetime.datetime.now() - cls.last_date.get() > datetime.timedelta(delta): if cls.worker is None: time.sleep(poll) cls.report.print_log("Attempt to fetch up-to-date cryptoportfolio") @@ -490,7 +521,7 @@ class Portfolio: weights_hash = portfolio_hash["weights"] weights = {} for i in range(len(weights_hash["_row"])): - date = datetime.strptime(weights_hash["_row"][i], "%Y-%m-%d") + date = datetime.datetime.strptime(weights_hash["_row"][i], "%Y-%m-%d") weights[date] = dict(filter( filter_weights, map(clean_weights(i), weights_hash.items()))) @@ -504,8 +535,7 @@ class Portfolio: "high": high_liquidity, }) cls.last_date.set(max( - max(medium_liquidity.keys(), default=datetime(1, 1, 1)), - max(high_liquidity.keys(), default=datetime(1, 1, 1)) + max(medium_liquidity.keys(), default=datetime.datetime(1, 1, 1)), + max(high_liquidity.keys(), default=datetime.datetime(1, 1, 1)) )) - diff --git a/tests/test_ccxt_wrapper.py b/tests/test_ccxt_wrapper.py index d32469a..597fe5c 100644 --- a/tests/test_ccxt_wrapper.py +++ b/tests/test_ccxt_wrapper.py @@ -22,11 +22,13 @@ class poloniexETest(unittest.TestCase): ccxt = market.ccxt.poloniexE() ccxt._market = mock.Mock ccxt._market.report = mock.Mock() + ccxt._market.market_id = 3 + ccxt._market.user_id = 3 ccxt.session.request("GET", "URL", data="data", - headers="headers") + headers={}) ccxt._market.report.log_http_request.assert_called_with('GET', 'URL', 'data', - 'headers', 'response') + {'X-market-id': '3', 'X-user-id': '3'}, 'response') with self.subTest("Raising"),\ mock.patch("market.ccxt.poloniexE.session") as session: @@ -35,12 +37,14 @@ class poloniexETest(unittest.TestCase): ccxt = market.ccxt.poloniexE() ccxt._market = mock.Mock ccxt._market.report = mock.Mock() + ccxt._market.market_id = 3 + ccxt._market.user_id = 3 with self.assertRaises(market.ccxt.RequestException, msg="Boo") as cm: ccxt.session.request("GET", "URL", data="data", - headers="headers") + headers={}) ccxt._market.report.log_http_request.assert_called_with('GET', 'URL', 'data', - 'headers', cm.exception) + {'X-market-id': '3', 'X-user-id': '3'}, cm.exception) def test_nanoseconds(self): diff --git a/tests/test_main.py b/tests/test_main.py index 6396c07..e3a5677 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -179,7 +179,8 @@ class MainTest(WebMockTestCase): mock.patch("main.parse_config") as parse_config,\ mock.patch("main.fetch_markets") as fetch_markets,\ mock.patch("main.process") as process,\ - mock.patch("store.Portfolio.start_worker") as start: + mock.patch("store.Portfolio.start_worker") as start,\ + mock.patch("store.Portfolio.stop_worker") as stop: args_mock = mock.Mock() args_mock.parallel = True @@ -196,6 +197,7 @@ class MainTest(WebMockTestCase): parse_config.assert_called_with(args_mock) fetch_markets.assert_called_with("pg_config", "user") + stop.assert_called_once_with() start.assert_called_once_with() self.assertEqual(2, process.call_count) process.assert_has_calls([ diff --git a/tests/test_market.py b/tests/test_market.py index 82eeea8..14b23b5 100644 --- a/tests/test_market.py +++ b/tests/test_market.py @@ -548,7 +548,7 @@ class MarketTest(WebMockTestCase): mock.patch.object(m, "report") as report,\ mock.patch.object(m, "store_file_report") as file_report,\ mock.patch.object(m, "store_database_report") as db_report,\ - mock.patch.object(market, "datetime") as time_mock: + mock.patch.object(market.datetime, "datetime") as time_mock: time_mock.now.return_value = datetime.datetime(2018, 2, 25) @@ -564,7 +564,7 @@ class MarketTest(WebMockTestCase): mock.patch.object(m, "report") as report,\ mock.patch.object(m, "store_file_report") as file_report,\ mock.patch.object(m, "store_database_report") as db_report,\ - mock.patch.object(market, "datetime") as time_mock: + mock.patch.object(market.datetime, "datetime") as time_mock: time_mock.now.return_value = datetime.datetime(2018, 2, 25) @@ -580,7 +580,7 @@ class MarketTest(WebMockTestCase): mock.patch.object(m, "report") as report,\ mock.patch.object(m, "store_file_report") as file_report,\ mock.patch.object(m, "store_database_report") as db_report,\ - mock.patch.object(market, "datetime") as time_mock: + mock.patch.object(market.datetime, "datetime") as time_mock: time_mock.now.return_value = datetime.datetime(2018, 2, 25) @@ -597,7 +597,7 @@ class MarketTest(WebMockTestCase): mock.patch.object(m, "report") as report,\ mock.patch.object(m, "store_file_report") as file_report,\ mock.patch.object(m, "store_database_report") as db_report,\ - mock.patch.object(market, "datetime") as time_mock: + mock.patch.object(market.datetime, "datetime") as time_mock: time_mock.now.return_value = datetime.datetime(2018, 2, 25) diff --git a/tests/test_portfolio.py b/tests/test_portfolio.py index a1b95bf..98048ac 100644 --- a/tests/test_portfolio.py +++ b/tests/test_portfolio.py @@ -1742,7 +1742,7 @@ class MouvementTest(WebMockTestCase): self.assertEqual(42, mouvement.id) self.assertEqual("buy", mouvement.action) self.assertEqual(D("0.0015"), mouvement.fee_rate) - self.assertEqual(portfolio.datetime(2017, 12, 30, 12, 0, 12), mouvement.date) + self.assertEqual(portfolio.datetime.datetime(2017, 12, 30, 12, 0, 12), mouvement.date) self.assertEqual(D("0.1"), mouvement.rate) self.assertEqual(portfolio.Amount("ETH", "10"), mouvement.total) self.assertEqual(portfolio.Amount("BTC", "1"), mouvement.total_in_base) @@ -1780,7 +1780,7 @@ class MouvementTest(WebMockTestCase): as_json = mouvement.as_json() self.assertEqual(D("0.0015"), as_json["fee_rate"]) - self.assertEqual(portfolio.datetime(2017, 12, 30, 12, 0, 12), as_json["date"]) + self.assertEqual(portfolio.datetime.datetime(2017, 12, 30, 12, 0, 12), as_json["date"]) self.assertEqual("buy", as_json["action"]) self.assertEqual(D("10"), as_json["total"]) self.assertEqual(D("1"), as_json["total_in_base"]) diff --git a/tests/test_store.py b/tests/test_store.py index c0b1fb9..2b51719 100644 --- a/tests/test_store.py +++ b/tests/test_store.py @@ -444,10 +444,20 @@ class BalanceStoreTest(WebMockTestCase): @unittest.skipUnless("unit" in limits, "Unit skipped") class ReportStoreTest(WebMockTestCase): def test_add_log(self): - report_store = market.ReportStore(self.m) - report_store.add_log({"foo": "bar"}) + with self.subTest(market=self.m): + self.m.user_id = 1 + self.m.market_id = 3 + report_store = market.ReportStore(self.m) + result = report_store.add_log({"foo": "bar"}) + + self.assertEqual({"foo": "bar", "date": mock.ANY, "user_id": 1, "market_id": 3}, result) + self.assertEqual(result, report_store.logs[0]) + + with self.subTest(market=None): + report_store = market.ReportStore(None) + result = report_store.add_log({"foo": "bar"}) - self.assertEqual({"foo": "bar", "date": mock.ANY}, report_store.logs[0]) + self.assertEqual({"foo": "bar", "date": mock.ANY, "user_id": None, "market_id": None}, result) def test_set_verbose(self): report_store = market.ReportStore(self.m) @@ -460,6 +470,8 @@ class ReportStoreTest(WebMockTestCase): self.assertFalse(report_store.verbose_print) def test_merge(self): + self.m.user_id = 1 + self.m.market_id = 3 report_store1 = market.ReportStore(self.m, verbose_print=False) report_store2 = market.ReportStore(None, verbose_print=False) @@ -478,7 +490,7 @@ class ReportStoreTest(WebMockTestCase): with self.subTest(verbose=True),\ mock.patch.object(store, "datetime") as time_mock,\ mock.patch('sys.stdout', new_callable=StringIO) as stdout_mock: - time_mock.now.return_value = datetime.datetime(2018, 2, 25, 2, 20, 10) + time_mock.datetime.now.return_value = datetime.datetime(2018, 2, 25, 2, 20, 10) report_store.set_verbose(True) report_store.print_log("Coucou") report_store.print_log(portfolio.Amount("BTC", 1)) @@ -495,7 +507,7 @@ class ReportStoreTest(WebMockTestCase): report_store = market.ReportStore(self.m) self.assertEqual("2018-02-24T00:00:00", - report_store.default_json_serial(portfolio.datetime(2018, 2, 24))) + report_store.default_json_serial(portfolio.datetime.datetime(2018, 2, 24))) self.assertEqual("1.00000000 BTC", report_store.default_json_serial(portfolio.Amount("BTC", 1))) @@ -503,7 +515,7 @@ class ReportStoreTest(WebMockTestCase): report_store = market.ReportStore(self.m) report_store.logs.append({"foo": "bar"}) self.assertEqual('[\n {\n "foo": "bar"\n }\n]', report_store.to_json()) - report_store.logs.append({"date": portfolio.datetime(2018, 2, 24)}) + report_store.logs.append({"date": portfolio.datetime.datetime(2018, 2, 24)}) self.assertEqual('[\n {\n "foo": "bar"\n },\n {\n "date": "2018-02-24T00:00:00"\n }\n]', report_store.to_json()) report_store.logs.append({"amount": portfolio.Amount("BTC", 1)}) self.assertEqual('[\n {\n "foo": "bar"\n },\n {\n "date": "2018-02-24T00:00:00"\n },\n {\n "amount": "1.00000000 BTC"\n }\n]', report_store.to_json()) @@ -817,53 +829,99 @@ class ReportStoreTest(WebMockTestCase): } }) - @mock.patch.object(market.ReportStore, "print_log") - @mock.patch.object(market.ReportStore, "add_log") - def test_log_http_request(self, add_log, print_log): - report_store = market.ReportStore(self.m) - response = mock.Mock() - response.status_code = 200 - response.text = "Hey" + def test_log_http_request(self): + with mock.patch.object(market.ReportStore, "add_log") as add_log: + report_store = market.ReportStore(self.m) + response = mock.Mock() + response.status_code = 200 + response.text = "Hey" - report_store.log_http_request("method", "url", "body", - "headers", response) - print_log.assert_not_called() - add_log.assert_called_once_with({ - 'type': 'http_request', - 'method': 'method', - 'url': 'url', - 'body': 'body', - 'headers': 'headers', - 'status': 200, - 'response': 'Hey' - }) + report_store.log_http_request("method", "url", "body", + "headers", response) + add_log.assert_called_once_with({ + 'type': 'http_request', + 'method': 'method', + 'url': 'url', + 'body': 'body', + 'headers': 'headers', + 'status': 200, + 'response': 'Hey', + 'response_same_as': None, + }) - add_log.reset_mock() - report_store.log_http_request("method", "url", "body", - "headers", ValueError("Foo")) - add_log.assert_called_once_with({ - 'type': 'http_request', - 'method': 'method', - 'url': 'url', - 'body': 'body', - 'headers': 'headers', - 'status': -1, - 'response': None, - 'error': 'ValueError', - 'error_message': 'Foo', - }) + add_log.reset_mock() + report_store.log_http_request("method", "url", "body", + "headers", ValueError("Foo")) + add_log.assert_called_once_with({ + 'type': 'http_request', + 'method': 'method', + 'url': 'url', + 'body': 'body', + 'headers': 'headers', + 'status': -1, + 'response': None, + 'error': 'ValueError', + 'error_message': 'Foo', + }) + + with self.subTest(no_http_dup=True, duplicate=True): + self.m.user_id = 1 + self.m.market_id = 3 + report_store = market.ReportStore(self.m, no_http_dup=True) + original_add_log = report_store.add_log + with mock.patch.object(report_store, "add_log", side_effect=original_add_log) as add_log: + report_store.log_http_request("method", "url", "body", + "headers", response) + report_store.log_http_request("method", "url", "body", + "headers", response) + self.assertEqual(2, add_log.call_count) + self.assertIsNone(add_log.mock_calls[0][1][0]["response_same_as"]) + self.assertIsNone(add_log.mock_calls[1][1][0]["response"]) + self.assertEqual(add_log.mock_calls[0][1][0]["date"], add_log.mock_calls[1][1][0]["response_same_as"]) + with self.subTest(no_http_dup=True, duplicate=False, case="Different call"): + self.m.user_id = 1 + self.m.market_id = 3 + report_store = market.ReportStore(self.m, no_http_dup=True) + original_add_log = report_store.add_log + with mock.patch.object(report_store, "add_log", side_effect=original_add_log) as add_log: + report_store.log_http_request("method", "url", "body", + "headers", response) + report_store.log_http_request("method2", "url", "body", + "headers", response) + self.assertEqual(2, add_log.call_count) + self.assertIsNone(add_log.mock_calls[0][1][0]["response_same_as"]) + self.assertIsNone(add_log.mock_calls[1][1][0]["response_same_as"]) + with self.subTest(no_http_dup=True, duplicate=False, case="Call inbetween"): + self.m.user_id = 1 + self.m.market_id = 3 + report_store = market.ReportStore(self.m, no_http_dup=True) + original_add_log = report_store.add_log + + response2 = mock.Mock() + response2.status_code = 200 + response2.text = "Hey there!" + + with mock.patch.object(report_store, "add_log", side_effect=original_add_log) as add_log: + report_store.log_http_request("method", "url", "body", + "headers", response) + report_store.log_http_request("method", "url", "body", + "headers", response2) + report_store.log_http_request("method", "url", "body", + "headers", response) + self.assertEqual(3, add_log.call_count) + self.assertIsNone(add_log.mock_calls[0][1][0]["response_same_as"]) + self.assertIsNone(add_log.mock_calls[1][1][0]["response_same_as"]) + self.assertIsNone(add_log.mock_calls[2][1][0]["response_same_as"]) @mock.patch.object(market.ReportStore, "add_log") def test_log_market(self, add_log): report_store = market.ReportStore(self.m) - report_store.log_market(self.market_args(debug=True, quiet=False), 4, 1) + report_store.log_market(self.market_args(debug=True, quiet=False)) add_log.assert_called_once_with({ "type": "market", "commit": "$Format:%H$", "args": { "report_path": None, "debug": True, "quiet": False }, - "user_id": 4, - "market_id": 1, }) @mock.patch.object(market.ReportStore, "print_log") @@ -1034,7 +1092,7 @@ class PortfolioTest(WebMockTestCase): 'SC': (D("0.0623"), "long"), 'ZEC': (D("0.3701"), "long"), } - date = portfolio.datetime(2018, 1, 8) + date = portfolio.datetime.datetime(2018, 1, 8) self.assertDictEqual(expected, liquidities["high"][date]) expected = { @@ -1051,7 +1109,7 @@ class PortfolioTest(WebMockTestCase): 'XCP': (D("0.1"), "long"), } self.assertDictEqual(expected, liquidities["medium"][date]) - self.assertEqual(portfolio.datetime(2018, 1, 15), market.Portfolio.last_date.get()) + self.assertEqual(portfolio.datetime.datetime(2018, 1, 15), market.Portfolio.last_date.get()) with self.subTest(description="Missing weight"): data = store.json.loads(self.json_response, parse_int=D, parse_float=D) @@ -1105,9 +1163,9 @@ class PortfolioTest(WebMockTestCase): else: self.assertFalse(refetch) self.call_count += 1 - market.Portfolio.last_date = store.LockedVar(store.datetime.now()\ - - store.timedelta(10)\ - + store.timedelta(self.call_count)) + market.Portfolio.last_date = store.LockedVar(store.datetime.datetime.now()\ + - store.datetime.timedelta(10)\ + + store.datetime.timedelta(self.call_count)) get_cryptoportfolio.side_effect = _get market.Portfolio.wait_for_recent() @@ -1166,6 +1224,19 @@ class PortfolioTest(WebMockTestCase): self.assertTrue(store.Portfolio.worker_started) self.assertFalse(store.Portfolio.worker.is_alive()) + self.assertEqual(1, threading.active_count()) + + def test_stop_worker(self): + with mock.patch.object(store.Portfolio, "get_cryptoportfolio") as get,\ + mock.patch.object(store.Portfolio, "report") as report,\ + mock.patch.object(store.time, "sleep") as sleep: + store.Portfolio.start_worker(poll=3) + store.Portfolio.stop_worker() + store.Portfolio.worker.join() + get.assert_not_called() + report.assert_not_called() + sleep.assert_not_called() + self.assertFalse(store.Portfolio.worker.is_alive()) def test_wait_for_notification(self): with self.assertRaises(RuntimeError): @@ -1189,7 +1260,7 @@ class PortfolioTest(WebMockTestCase): store.Portfolio.callback.clear() store.Portfolio.worker_started = False store.Portfolio.worker_notify.set() - store.Portfolio.callback.wait() + store.Portfolio.worker.join() self.assertFalse(store.Portfolio.worker.is_alive()) def test_notify_and_wait(self): -- 2.41.0