From 3080f31d1ee74104640dcff451922cd0ae88ee22 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Isma=C3=ABl=20Bouya?= Date: Fri, 20 Apr 2018 20:09:13 +0200 Subject: Move acceptance tests to common directory --- test.py | 17 ++- test_acceptance.py | 361 --------------------------------------------- tests/acceptance.py | 358 ++++++++++++++++++++++++++++++++++++++++++++ tests/helper.py | 12 +- tests/test_acceptance.py | 25 ++++ tests/test_ccxt_wrapper.py | 3 +- tests/test_main.py | 1 + tests/test_market.py | 2 + tests/test_portfolio.py | 6 + tests/test_store.py | 6 + 10 files changed, 423 insertions(+), 368 deletions(-) delete mode 100644 test_acceptance.py create mode 100644 tests/acceptance.py create mode 100644 tests/test_acceptance.py diff --git a/test.py b/test.py index 8b9d35b..d7743b2 100644 --- a/test.py +++ b/test.py @@ -1,10 +1,17 @@ import unittest +from tests.acceptance import TimeMock -from tests.test_ccxt_wrapper import * -from tests.test_main import * -from tests.test_market import * -from tests.test_store import * -from tests.test_portfolio import * +from tests.helper import limits + +if "unit" in limits: + from tests.test_ccxt_wrapper import * + from tests.test_main import * + from tests.test_market import * + from tests.test_store import * + from tests.test_portfolio import * + +if "acceptance" in limits: + from tests.test_acceptance import * if __name__ == '__main__': unittest.main() diff --git a/test_acceptance.py b/test_acceptance.py deleted file mode 100644 index 3633928..0000000 --- a/test_acceptance.py +++ /dev/null @@ -1,361 +0,0 @@ -import requests -import requests_mock -import sys, os -import time, datetime -import unittest -from unittest import mock -from ssl import SSLError -from decimal import Decimal -import simplejson as json -import psycopg2 -from io import StringIO -import re -import functools -import glob - -import main - -class FileMock: - def __init__(self, log_files, quiet, tester): - self.tester = tester - self.log_files = [] - if log_files is not None and len(log_files) > 0: - self.read_log_files(log_files) - self.quiet = quiet - self.patches = [ - mock.patch("market.open"), - mock.patch("os.makedirs"), - mock.patch("sys.stdout", new_callable=StringIO), - ] - self.mocks = [] - - def start(self): - for patch in self.patches: - self.mocks.append(patch.start()) - self.stdout = self.mocks[-1] - - def check_calls(self): - stdout = self.stdout.getvalue() - if self.quiet: - self.tester.assertEqual("", stdout) - else: - log = self.strip_log(stdout) - if len(self.log_files) != 0: - split_logs = log.split("\n") - self.tester.assertEqual(sum(len(f) for f in self.log_files), len(split_logs)) - try: - for log_file in self.log_files: - for line in log_file: - split_logs.pop(split_logs.index(line)) - except ValueError: - if not line.startswith("[Worker] "): - self.tester.fail("« {} » not found in log file {}".format(line, split_logs)) - # Le fichier de log est écrit - # Le fichier de log est printed uniquement si non quiet - # Le rapport est écrit si pertinent - # Le rapport contient le bon nombre de lignes - - def stop(self): - for patch in self.patches[::-1]: - patch.stop() - self.mocks.pop() - - def strip_log(self, log): - log = log.replace("\n\n", "\n") - return re.sub(r"^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}: ", "", log, flags=re.MULTILINE) - - def read_log_files(self, log_files): - for log_file in log_files: - with open(log_file, "r") as f: - log = self.strip_log(f.read()).split("\n") - if len(log[-1]) == 0: - log.pop() - self.log_files.append(log) - -class DatabaseMock: - def __init__(self, tester, reports, report_db): - self.tester = tester - self.reports = reports - self.report_db = report_db - self.rows = [] - self.total_report_lines= 0 - self.requests = [] - for user_id, market_id, http_requests, report_lines in self.reports.values(): - self.rows.append( (market_id, { "key": "key", "secret": "secret" }, user_id) ) - self.total_report_lines += len(report_lines) - - def start(self): - connect_mock = mock.Mock() - self.cursor = mock.MagicMock() - connect_mock.cursor.return_value = self.cursor - def _execute(request, *args): - self.requests.append(request) - self.cursor.execute.side_effect = _execute - self.cursor.__iter__.return_value = self.rows - - self.db_patch = mock.patch("psycopg2.connect") - self.db_mock = self.db_patch.start() - self.db_mock.return_value = connect_mock - - def check_calls(self): - if self.report_db: - self.tester.assertEqual(1 + len(self.rows), self.db_mock.call_count) - self.tester.assertEqual(1 + len(self.rows) + self.total_report_lines, self.cursor.execute.call_count) - else: - self.tester.assertEqual(1, self.db_mock.call_count) - self.tester.assertEqual(1, self.cursor.execute.call_count) - - def stop(self): - self.db_patch.stop() - -class RequestsMock: - def __init__(self, tester): - self.tester = tester - self.reports = tester.requests_by_market() - - self.last_https = {} - self.error_calls = [] - self.mocker = requests_mock.Mocker() - def not_stubbed(*args): - self.error_calls.append([args[0].method, args[0].url]) - raise requests_mock.exceptions.MockException("Not stubbed URL") - self.mocker.register_uri(requests_mock.ANY, requests_mock.ANY, - text=not_stubbed) - - self.mocks = {} - - for market_id, elements in self.reports.items(): - for element in elements: - method = element["method"] - url = element["url"] - self.mocks \ - .setdefault((method, url), {}) \ - .setdefault(market_id, []) \ - .append(element) - - for ((method, url), elements) in self.mocks.items(): - self.mocker.register_uri(method, url, text=functools.partial(callback, self, elements), complete_qs=True) - - def start(self): - self.mocker.start() - - def stop(self): - self.mocker.stop() - - def check_calls(self): - self.tester.assertEqual([], self.error_calls) - for (method, url), elements in self.mocks.items(): - for market_id, element in elements.items(): - self.tester.assertEqual(0, len(element), "Missing calls to {} {}, market_id {}".format(method, url, market_id)) - - def clean_body(self, body): - if body is None: - return None - if isinstance(body, bytes): - body = body.decode() - body = re.sub(r"&nonce=\d*$", "", body) - body = re.sub(r"nonce=\d*&?", "", body) - return body - -def callback(self, elements, request, context): - try: - element = elements[request.headers.get("X-market-id")].pop(0) - except (IndexError, KeyError): - self.error_calls.append([request.method, request.url, request.headers.get("X-market-id")]) - raise RuntimeError("Unexpected call") - if element["response"] is None and element["response_same_as"] is not None: - element["response"] = self.last_https[element["response_same_as"]] - elif element["response"] is not None: - self.last_https[element["date"]] = element["response"] - - assert self.clean_body(request.body) == \ - self.clean_body(element["body"]), "Body does not match" - context.status_code = element["status"] - if "error" in element: - if element["error"] == "SSLError": - raise SSLError(element["error_message"]) - else: - raise getattr(requests.exceptions, element["error"])(element["error_message"]) - return element["response"] - -class GlobalVariablesMock: - def start(self): - import market - import store - - self.patchers = [ - mock.patch.multiple(market.Portfolio, - data=store.LockedVar(None), - liquidities=store.LockedVar({}), - last_date=store.LockedVar(None), - report=store.LockedVar(store.ReportStore(None, no_http_dup=True)), - worker=None, - worker_tag="", - worker_notify=None, - worker_started=False, - callback=None) - ] - for patcher in self.patchers: - patcher.start() - - def stop(self): - pass - - -class TimeMock: - delta = 0 - true_time = time.time - true_sleep = time.sleep - time_patch = None - datetime_patch = None - - @classmethod - def start(cls, start_date): - cls.delta = (datetime.datetime.now() - start_date).total_seconds() - - class fake_datetime(datetime.datetime): - @classmethod - def now(cls, tz=None): - if tz is None: - return cls.fromtimestamp(time.time()) - else: - return tz.fromutc(cls.utcfromtimestamp(time.time()).replace(tzinfo=tz)) - - cls.time_patch = mock.patch.multiple(time, time=cls.fake_time, sleep=cls.fake_sleep) - cls.datetime_patch = mock.patch.multiple(datetime, datetime=fake_datetime) - cls.time_patch.start() - cls.datetime_patch.start() - - @classmethod - def stop(cls): - cls.delta = 0 - cls.datetime_patch.stop() - cls.time_patch.stop() - - @classmethod - def fake_time(cls): - return cls.true_time() - cls.delta - - @classmethod - def fake_sleep(cls, duration): - cls.delta -= duration - cls.true_sleep(0.2) - -class AcceptanceTestCase(): - def parse_file(self, report_file): - with open(report_file, "rb") as f: - json_content = json.load(f, parse_float=Decimal) - config, user, date, market_id = self.parse_config(json_content) - http_requests = self.parse_requests(json_content) - - return config, user, date, market_id, http_requests, json_content - - def parse_requests(self, json_content): - http_requests = [] - for element in json_content: - if element["type"] != "http_request": - continue - http_requests.append(element) - return http_requests - - def parse_config(self, json_content): - market_info = None - for element in json_content: - if element["type"] != "market": - continue - market_info = element - assert market_info is not None, "Couldn't find market element" - - args = market_info["args"] - config = [] - for arg in ["before", "after", "quiet", "debug"]: - if args.get(arg, False): - config.append("--{}".format(arg)) - for arg in ["parallel", "report_db"]: - if not args.get(arg, False): - config.append("--no-{}".format(arg.replace("_", "-"))) - for action in (args.get("action", []) or []): - config.extend(["--action", action]) - if args.get("report_path") is not None: - config.extend(["--report-path", args.get("report_path")]) - if args.get("user") is not None: - config.extend(["--user", args.get("user")]) - config.extend(["--config", ""]) - - user = market_info["user_id"] - date = datetime.datetime.strptime(market_info["date"], "%Y-%m-%dT%H:%M:%S.%f") - market_id = market_info["market_id"] - return config, user, date, market_id - - def requests_by_market(self): - r = { - None: [] - } - got_common = False - for user_id, market_id, http_requests, report_lines in self.reports.values(): - r[str(market_id)] = [] - for http_request in http_requests: - if http_request["market_id"] is None: - if not got_common: - r[None].append(http_request) - else: - r[str(market_id)].append(http_request) - got_common = True - return r - - def setUp(self): - if not hasattr(self, "files"): - raise "This class expects to be inherited with a class defining 'files' variable" - if not hasattr(self, "log_files"): - self.log_files = [] - - self.reports = {} - self.start_date = datetime.datetime.now() - self.config = [] - for f in self.files: - self.config, user_id, date, market_id, http_requests, report_lines = self.parse_file(f) - if date < self.start_date: - self.start_date = date - self.reports[f] = [user_id, market_id, http_requests, report_lines] - - self.database_mock = DatabaseMock(self, self.reports, "--no-report-db" not in self.config) - self.requests_mock = RequestsMock(self) - self.file_mock = FileMock(self.log_files, "--quiet" in self.config, self) - self.global_variables_mock = GlobalVariablesMock() - - self.database_mock.start() - self.requests_mock.start() - self.file_mock.start() - self.global_variables_mock.start() - TimeMock.start(self.start_date) - - def base_test(self): - main.main(self.config) - self.requests_mock.check_calls() - self.database_mock.check_calls() - self.file_mock.check_calls() - - def tearDown(self): - TimeMock.stop() - self.global_variables_mock.stop() - self.file_mock.stop() - self.requests_mock.stop() - self.database_mock.stop() - -for dirfile in glob.glob("tests/acceptance/**/*/", recursive=True): - json_files = glob.glob("{}/*.json".format(dirfile)) - log_files = glob.glob("{}/*.log".format(dirfile)) - if len(json_files) > 0: - name = dirfile.replace("tests/acceptance/", "").replace("/", "_")[0:-1] - cname = "".join(list(map(lambda x: x.capitalize(), name.split("_")))) - - globals()[cname] = type(cname, - (AcceptanceTestCase,unittest.TestCase), { - "log_files": log_files, - "files": json_files, - "test_{}".format(name): AcceptanceTestCase.base_test - }) - -if __name__ == '__main__': - unittest.main() - diff --git a/tests/acceptance.py b/tests/acceptance.py new file mode 100644 index 0000000..66014ca --- /dev/null +++ b/tests/acceptance.py @@ -0,0 +1,358 @@ +import requests +import requests_mock +import sys, os +import time, datetime +from unittest import mock +from ssl import SSLError +from decimal import Decimal +import simplejson as json +import psycopg2 +from io import StringIO +import re +import functools +import threading + +class TimeMock: + delta = {} + delta_init = 0 + true_time = time.time + true_sleep = time.sleep + time_patch = None + datetime_patch = None + + @classmethod + def travel(cls, start_date): + cls.delta = {} + cls.delta_init = (datetime.datetime.now() - start_date).total_seconds() + + @classmethod + def start(cls): + cls.delta = {} + cls.delta_init = 0 + + class fake_datetime(datetime.datetime): + @classmethod + def now(cls, tz=None): + if tz is None: + return cls.fromtimestamp(time.time()) + else: + return tz.fromutc(cls.utcfromtimestamp(time.time()).replace(tzinfo=tz)) + + cls.time_patch = mock.patch.multiple(time, time=cls.fake_time, sleep=cls.fake_sleep) + cls.datetime_patch = mock.patch.multiple(datetime, datetime=fake_datetime) + cls.time_patch.start() + cls.datetime_patch.start() + + @classmethod + def stop(cls): + cls.delta = {} + cls.delta_init = 0 + + @classmethod + def fake_time(cls): + cls.delta.setdefault(threading.current_thread(), cls.delta_init) + return cls.true_time() - cls.delta[threading.current_thread()] + + @classmethod + def fake_sleep(cls, duration): + cls.delta.setdefault(threading.current_thread(), cls.delta_init) + cls.delta[threading.current_thread()] -= float(duration) + cls.true_sleep(min(float(duration), 0.1)) + +TimeMock.start() +import main + +class FileMock: + def __init__(self, log_files, quiet, tester): + self.tester = tester + self.log_files = [] + if log_files is not None and len(log_files) > 0: + self.read_log_files(log_files) + self.quiet = quiet + self.patches = [ + mock.patch("market.open"), + mock.patch("os.makedirs"), + mock.patch("sys.stdout", new_callable=StringIO), + ] + self.mocks = [] + + def start(self): + for patch in self.patches: + self.mocks.append(patch.start()) + self.stdout = self.mocks[-1] + + def check_calls(self): + stdout = self.stdout.getvalue() + if self.quiet: + self.tester.assertEqual("", stdout) + else: + log = self.strip_log(stdout) + if len(self.log_files) != 0: + split_logs = log.split("\n") + self.tester.assertEqual(sum(len(f) for f in self.log_files), len(split_logs)) + try: + for log_file in self.log_files: + for line in log_file: + split_logs.pop(split_logs.index(line)) + except ValueError: + if not line.startswith("[Worker] "): + self.tester.fail("« {} » not found in log file {}".format(line, split_logs)) + # Le fichier de log est écrit + # Le rapport est écrit si pertinent + # Le rapport contient le bon nombre de lignes + + def stop(self): + for patch in self.patches[::-1]: + patch.stop() + self.mocks.pop() + + def strip_log(self, log): + log = log.replace("\n\n", "\n") + return re.sub(r"^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}: ", "", log, flags=re.MULTILINE) + + def read_log_files(self, log_files): + for log_file in log_files: + with open(log_file, "r") as f: + log = self.strip_log(f.read()).split("\n") + if len(log[-1]) == 0: + log.pop() + self.log_files.append(log) + +class DatabaseMock: + def __init__(self, tester, reports, report_db): + self.tester = tester + self.reports = reports + self.report_db = report_db + self.rows = [] + self.total_report_lines= 0 + self.requests = [] + for user_id, market_id, http_requests, report_lines in self.reports.values(): + self.rows.append( (market_id, { "key": "key", "secret": "secret" }, user_id) ) + self.total_report_lines += len(report_lines) + + def start(self): + connect_mock = mock.Mock() + self.cursor = mock.MagicMock() + connect_mock.cursor.return_value = self.cursor + def _execute(request, *args): + self.requests.append(request) + self.cursor.execute.side_effect = _execute + self.cursor.__iter__.return_value = self.rows + + self.db_patch = mock.patch("psycopg2.connect") + self.db_mock = self.db_patch.start() + self.db_mock.return_value = connect_mock + + def check_calls(self): + if self.report_db: + self.tester.assertEqual(1 + len(self.rows), self.db_mock.call_count) + self.tester.assertEqual(1 + len(self.rows) + self.total_report_lines, self.cursor.execute.call_count) + else: + self.tester.assertEqual(1, self.db_mock.call_count) + self.tester.assertEqual(1, self.cursor.execute.call_count) + + def stop(self): + self.db_patch.stop() + +class RequestsMock: + def __init__(self, tester): + self.tester = tester + self.reports = tester.requests_by_market() + + self.last_https = {} + self.error_calls = [] + self.mocker = requests_mock.Mocker() + def not_stubbed(*args): + self.error_calls.append([args[0].method, args[0].url]) + raise requests_mock.exceptions.MockException("Not stubbed URL") + self.mocker.register_uri(requests_mock.ANY, requests_mock.ANY, + text=not_stubbed) + + self.mocks = {} + + for market_id, elements in self.reports.items(): + for element in elements: + method = element["method"] + url = element["url"] + self.mocks \ + .setdefault((method, url), {}) \ + .setdefault(market_id, []) \ + .append(element) + + for ((method, url), elements) in self.mocks.items(): + self.mocker.register_uri(method, url, text=functools.partial(callback, self, elements), complete_qs=True) + + def start(self): + self.mocker.start() + + def stop(self): + self.mocker.stop() + + lazy_calls = [ + "https://cryptoportfolio.io/wp-content/uploads/portfolio/json/cryptoportfolio.json", + "https://poloniex.com/public?command=returnTicker", + ] + def check_calls(self): + self.tester.assertEqual([], self.error_calls) + for (method, url), elements in self.mocks.items(): + for market_id, element in elements.items(): + if url not in self.lazy_calls: + self.tester.assertEqual(0, len(element), "Missing calls to {} {}, market_id {}".format(method, url, market_id)) + + def clean_body(self, body): + if body is None: + return None + if isinstance(body, bytes): + body = body.decode() + body = re.sub(r"&nonce=\d*$", "", body) + body = re.sub(r"nonce=\d*&?", "", body) + return body + +def callback(self, elements, request, context): + try: + element = elements[request.headers.get("X-market-id")].pop(0) + except (IndexError, KeyError): + self.error_calls.append([request.method, request.url, request.headers.get("X-market-id")]) + raise RuntimeError("Unexpected call") + if element["response"] is None and element["response_same_as"] is not None: + element["response"] = self.last_https[element["response_same_as"]] + elif element["response"] is not None: + self.last_https[element["date"]] = element["response"] + + time.sleep(element.get("duration", 0)) + + assert self.clean_body(request.body) == \ + self.clean_body(element["body"]), "Body does not match" + context.status_code = element["status"] + if "error" in element: + if element["error"] == "SSLError": + raise SSLError(element["error_message"]) + else: + raise getattr(requests.exceptions, element["error"])(element["error_message"]) + return element["response"] + +class GlobalVariablesMock: + def start(self): + import market + import store + + self.patchers = [ + mock.patch.multiple(market.Portfolio, + data=store.LockedVar(None), + liquidities=store.LockedVar({}), + last_date=store.LockedVar(None), + report=store.LockedVar(store.ReportStore(None, no_http_dup=True)), + worker=None, + worker_tag="", + worker_notify=None, + worker_started=False, + callback=None) + ] + for patcher in self.patchers: + patcher.start() + + def stop(self): + pass + + +class AcceptanceTestCase(): + def parse_file(self, report_file): + with open(report_file, "rb") as f: + json_content = json.load(f, parse_float=Decimal) + config, user, date, market_id = self.parse_config(json_content) + http_requests = self.parse_requests(json_content) + + return config, user, date, market_id, http_requests, json_content + + def parse_requests(self, json_content): + http_requests = [] + for element in json_content: + if element["type"] != "http_request": + continue + http_requests.append(element) + return http_requests + + def parse_config(self, json_content): + market_info = None + for element in json_content: + if element["type"] != "market": + continue + market_info = element + assert market_info is not None, "Couldn't find market element" + + args = market_info["args"] + config = [] + for arg in ["before", "after", "quiet", "debug"]: + if args.get(arg, False): + config.append("--{}".format(arg)) + for arg in ["parallel", "report_db"]: + if not args.get(arg, False): + config.append("--no-{}".format(arg.replace("_", "-"))) + for action in (args.get("action", []) or []): + config.extend(["--action", action]) + if args.get("report_path") is not None: + config.extend(["--report-path", args.get("report_path")]) + if args.get("user") is not None: + config.extend(["--user", args.get("user")]) + config.extend(["--config", ""]) + + user = market_info["user_id"] + date = datetime.datetime.strptime(market_info["date"], "%Y-%m-%dT%H:%M:%S.%f") + market_id = market_info["market_id"] + return config, user, date, market_id + + def requests_by_market(self): + r = { + None: [] + } + got_common = False + for user_id, market_id, http_requests, report_lines in self.reports.values(): + r[str(market_id)] = [] + for http_request in http_requests: + if http_request["market_id"] is None: + if not got_common: + r[None].append(http_request) + else: + r[str(market_id)].append(http_request) + got_common = True + return r + + def setUp(self): + if not hasattr(self, "files"): + raise "This class expects to be inherited with a class defining 'files' variable" + if not hasattr(self, "log_files"): + self.log_files = [] + + self.reports = {} + self.start_date = datetime.datetime.now() + self.config = [] + for f in self.files: + self.config, user_id, date, market_id, http_requests, report_lines = self.parse_file(f) + if date < self.start_date: + self.start_date = date + self.reports[f] = [user_id, market_id, http_requests, report_lines] + + self.database_mock = DatabaseMock(self, self.reports, "--no-report-db" not in self.config) + self.requests_mock = RequestsMock(self) + self.file_mock = FileMock(self.log_files, "--quiet" in self.config, self) + self.global_variables_mock = GlobalVariablesMock() + + self.database_mock.start() + self.requests_mock.start() + self.file_mock.start() + self.global_variables_mock.start() + TimeMock.travel(self.start_date) + + def base_test(self): + main.main(self.config) + self.requests_mock.check_calls() + self.database_mock.check_calls() + self.file_mock.check_calls() + + def tearDown(self): + TimeMock.stop() + self.global_variables_mock.stop() + self.file_mock.stop() + self.requests_mock.stop() + self.database_mock.stop() + diff --git a/tests/helper.py b/tests/helper.py index 4548b16..b85bf3a 100644 --- a/tests/helper.py +++ b/tests/helper.py @@ -6,9 +6,19 @@ import requests_mock from io import StringIO import portfolio, market, main, store -__all__ = ["unittest", "WebMockTestCase", "mock", "D", +__all__ = ["limits", "unittest", "WebMockTestCase", "mock", "D", "StringIO"] +limits = ["acceptance", "unit"] +for test_type in limits: + if "--no{}".format(test_type) in sys.argv: + sys.argv.remove("--no{}".format(test_type)) + limits.remove(test_type) + if "--only{}".format(test_type) in sys.argv: + sys.argv.remove("--only{}".format(test_type)) + limits = [test_type] + break + class WebMockTestCase(unittest.TestCase): import time diff --git a/tests/test_acceptance.py b/tests/test_acceptance.py new file mode 100644 index 0000000..77a6cca --- /dev/null +++ b/tests/test_acceptance.py @@ -0,0 +1,25 @@ +from .helper import limits +from tests.acceptance import AcceptanceTestCase + +import unittest +import glob + +__all__ = [] + +for dirfile in glob.glob("tests/acceptance/**/*/", recursive=True): + json_files = glob.glob("{}/*.json".format(dirfile)) + log_files = glob.glob("{}/*.log".format(dirfile)) + if len(json_files) > 0: + name = dirfile.replace("tests/acceptance/", "").replace("/", "_")[0:-1] + cname = "".join(list(map(lambda x: x.capitalize(), name.split("_")))) + + globals()[cname] = unittest.skipUnless("acceptance" in limits, "Acceptance skipped")( + type(cname, (AcceptanceTestCase, unittest.TestCase), { + "log_files": log_files, + "files": json_files, + "test_{}".format(name): AcceptanceTestCase.base_test + }) + ) + __all__.append(cname) + + diff --git a/tests/test_ccxt_wrapper.py b/tests/test_ccxt_wrapper.py index 18feab3..10e334d 100644 --- a/tests/test_ccxt_wrapper.py +++ b/tests/test_ccxt_wrapper.py @@ -1,7 +1,8 @@ -from .helper import unittest, mock, D +from .helper import limits, unittest, mock, D import requests_mock import market +@unittest.skipUnless("unit" in limits, "Unit skipped") class poloniexETest(unittest.TestCase): def setUp(self): super().setUp() diff --git a/tests/test_main.py b/tests/test_main.py index cee89ce..d2f8029 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,6 +1,7 @@ from .helper import * import main, market +@unittest.skipUnless("unit" in limits, "Unit skipped") class MainTest(WebMockTestCase): def test_make_order(self): self.m.get_ticker.return_value = { diff --git a/tests/test_market.py b/tests/test_market.py index fd23162..14b23b5 100644 --- a/tests/test_market.py +++ b/tests/test_market.py @@ -2,6 +2,7 @@ from .helper import * import market, store, portfolio import datetime +@unittest.skipUnless("unit" in limits, "Unit skipped") class MarketTest(WebMockTestCase): def setUp(self): super().setUp() @@ -729,6 +730,7 @@ class MarketTest(WebMockTestCase): store_report.assert_called_once() +@unittest.skipUnless("unit" in limits, "Unit skipped") class ProcessorTest(WebMockTestCase): def test_values(self): processor = market.Processor(self.m) diff --git a/tests/test_portfolio.py b/tests/test_portfolio.py index 14dc995..4d78996 100644 --- a/tests/test_portfolio.py +++ b/tests/test_portfolio.py @@ -2,6 +2,7 @@ from .helper import * import portfolio import datetime +@unittest.skipUnless("unit" in limits, "Unit skipped") class ComputationTest(WebMockTestCase): def test_compute_value(self): compute = mock.Mock() @@ -25,6 +26,7 @@ class ComputationTest(WebMockTestCase): portfolio.Computation.compute_value("foo", "bid", compute_value="test") compute.assert_called_with("foo", "bid") +@unittest.skipUnless("unit" in limits, "Unit skipped") class TradeTest(WebMockTestCase): def test_values_assertion(self): @@ -609,6 +611,7 @@ class TradeTest(WebMockTestCase): self.assertEqual("ETH", as_json["currency"]) self.assertEqual("BTC", as_json["base_currency"]) +@unittest.skipUnless("unit" in limits, "Unit skipped") class BalanceTest(WebMockTestCase): def test_values(self): balance = portfolio.Balance("BTC", { @@ -684,6 +687,7 @@ class BalanceTest(WebMockTestCase): self.assertEqual(D(0), as_json["margin_available"]) self.assertEqual(D(0), as_json["margin_borrowed"]) +@unittest.skipUnless("unit" in limits, "Unit skipped") class OrderTest(WebMockTestCase): def test_values(self): order = portfolio.Order("buy", portfolio.Amount("ETH", 10), @@ -1745,6 +1749,7 @@ class OrderTest(WebMockTestCase): result = order.retrieve_order() self.assertFalse(result) +@unittest.skipUnless("unit" in limits, "Unit skipped") class MouvementTest(WebMockTestCase): def test_values(self): mouvement = portfolio.Mouvement("ETH", "BTC", { @@ -1802,6 +1807,7 @@ class MouvementTest(WebMockTestCase): self.assertEqual("BTC", as_json["base_currency"]) self.assertEqual("ETH", as_json["currency"]) +@unittest.skipUnless("unit" in limits, "Unit skipped") class AmountTest(WebMockTestCase): def test_values(self): amount = portfolio.Amount("BTC", "0.65") diff --git a/tests/test_store.py b/tests/test_store.py index e281adb..ffd2645 100644 --- a/tests/test_store.py +++ b/tests/test_store.py @@ -4,12 +4,14 @@ import datetime import threading import market, portfolio, store +@unittest.skipUnless("unit" in limits, "Unit skipped") class NoopLockTest(unittest.TestCase): def test_with(self): noop_lock = store.NoopLock() with noop_lock: self.assertTrue(True) +@unittest.skipUnless("unit" in limits, "Unit skipped") class LockedVarTest(unittest.TestCase): def test_values(self): @@ -61,6 +63,7 @@ class LockedVarTest(unittest.TestCase): thread3.join() self.assertEqual("Bar", locked_var.get()[0:3]) +@unittest.skipUnless("unit" in limits, "Unit skipped") class TradeStoreTest(WebMockTestCase): def test_compute_trades(self): self.m.balances.currencies.return_value = ["XMR", "DASH", "XVG", "BTC", "ETH"] @@ -285,6 +288,7 @@ class TradeStoreTest(WebMockTestCase): self.assertEqual([trade_mock1, trade_mock2], trade_store.pending) +@unittest.skipUnless("unit" in limits, "Unit skipped") class BalanceStoreTest(WebMockTestCase): def setUp(self): super().setUp() @@ -437,6 +441,7 @@ class BalanceStoreTest(WebMockTestCase): self.assertEqual(1, as_json["BTC"]) self.assertEqual(2, as_json["ETH"]) +@unittest.skipUnless("unit" in limits, "Unit skipped") class ReportStoreTest(WebMockTestCase): def test_add_log(self): with self.subTest(market=self.m): @@ -997,6 +1002,7 @@ class ReportStoreTest(WebMockTestCase): 'action': 'Hey' }) +@unittest.skipUnless("unit" in limits, "Unit skipped") class PortfolioTest(WebMockTestCase): def setUp(self): super().setUp() -- cgit v1.2.3