From 3080f31d1ee74104640dcff451922cd0ae88ee22 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Isma=C3=ABl=20Bouya?= Date: Fri, 20 Apr 2018 20:09:13 +0200 Subject: Move acceptance tests to common directory --- tests/acceptance.py | 358 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 358 insertions(+) create mode 100644 tests/acceptance.py (limited to 'tests/acceptance.py') diff --git a/tests/acceptance.py b/tests/acceptance.py new file mode 100644 index 0000000..66014ca --- /dev/null +++ b/tests/acceptance.py @@ -0,0 +1,358 @@ +import requests +import requests_mock +import sys, os +import time, datetime +from unittest import mock +from ssl import SSLError +from decimal import Decimal +import simplejson as json +import psycopg2 +from io import StringIO +import re +import functools +import threading + +class TimeMock: + delta = {} + delta_init = 0 + true_time = time.time + true_sleep = time.sleep + time_patch = None + datetime_patch = None + + @classmethod + def travel(cls, start_date): + cls.delta = {} + cls.delta_init = (datetime.datetime.now() - start_date).total_seconds() + + @classmethod + def start(cls): + cls.delta = {} + cls.delta_init = 0 + + class fake_datetime(datetime.datetime): + @classmethod + def now(cls, tz=None): + if tz is None: + return cls.fromtimestamp(time.time()) + else: + return tz.fromutc(cls.utcfromtimestamp(time.time()).replace(tzinfo=tz)) + + cls.time_patch = mock.patch.multiple(time, time=cls.fake_time, sleep=cls.fake_sleep) + cls.datetime_patch = mock.patch.multiple(datetime, datetime=fake_datetime) + cls.time_patch.start() + cls.datetime_patch.start() + + @classmethod + def stop(cls): + cls.delta = {} + cls.delta_init = 0 + + @classmethod + def fake_time(cls): + cls.delta.setdefault(threading.current_thread(), cls.delta_init) + return cls.true_time() - cls.delta[threading.current_thread()] + + @classmethod + def fake_sleep(cls, duration): + cls.delta.setdefault(threading.current_thread(), cls.delta_init) + cls.delta[threading.current_thread()] -= float(duration) + cls.true_sleep(min(float(duration), 0.1)) + +TimeMock.start() +import main + +class FileMock: + def __init__(self, log_files, quiet, tester): + self.tester = tester + self.log_files = [] + if log_files is not None and len(log_files) > 0: + self.read_log_files(log_files) + self.quiet = quiet + self.patches = [ + mock.patch("market.open"), + mock.patch("os.makedirs"), + mock.patch("sys.stdout", new_callable=StringIO), + ] + self.mocks = [] + + def start(self): + for patch in self.patches: + self.mocks.append(patch.start()) + self.stdout = self.mocks[-1] + + def check_calls(self): + stdout = self.stdout.getvalue() + if self.quiet: + self.tester.assertEqual("", stdout) + else: + log = self.strip_log(stdout) + if len(self.log_files) != 0: + split_logs = log.split("\n") + self.tester.assertEqual(sum(len(f) for f in self.log_files), len(split_logs)) + try: + for log_file in self.log_files: + for line in log_file: + split_logs.pop(split_logs.index(line)) + except ValueError: + if not line.startswith("[Worker] "): + self.tester.fail("« {} » not found in log file {}".format(line, split_logs)) + # Le fichier de log est écrit + # Le rapport est écrit si pertinent + # Le rapport contient le bon nombre de lignes + + def stop(self): + for patch in self.patches[::-1]: + patch.stop() + self.mocks.pop() + + def strip_log(self, log): + log = log.replace("\n\n", "\n") + return re.sub(r"^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}: ", "", log, flags=re.MULTILINE) + + def read_log_files(self, log_files): + for log_file in log_files: + with open(log_file, "r") as f: + log = self.strip_log(f.read()).split("\n") + if len(log[-1]) == 0: + log.pop() + self.log_files.append(log) + +class DatabaseMock: + def __init__(self, tester, reports, report_db): + self.tester = tester + self.reports = reports + self.report_db = report_db + self.rows = [] + self.total_report_lines= 0 + self.requests = [] + for user_id, market_id, http_requests, report_lines in self.reports.values(): + self.rows.append( (market_id, { "key": "key", "secret": "secret" }, user_id) ) + self.total_report_lines += len(report_lines) + + def start(self): + connect_mock = mock.Mock() + self.cursor = mock.MagicMock() + connect_mock.cursor.return_value = self.cursor + def _execute(request, *args): + self.requests.append(request) + self.cursor.execute.side_effect = _execute + self.cursor.__iter__.return_value = self.rows + + self.db_patch = mock.patch("psycopg2.connect") + self.db_mock = self.db_patch.start() + self.db_mock.return_value = connect_mock + + def check_calls(self): + if self.report_db: + self.tester.assertEqual(1 + len(self.rows), self.db_mock.call_count) + self.tester.assertEqual(1 + len(self.rows) + self.total_report_lines, self.cursor.execute.call_count) + else: + self.tester.assertEqual(1, self.db_mock.call_count) + self.tester.assertEqual(1, self.cursor.execute.call_count) + + def stop(self): + self.db_patch.stop() + +class RequestsMock: + def __init__(self, tester): + self.tester = tester + self.reports = tester.requests_by_market() + + self.last_https = {} + self.error_calls = [] + self.mocker = requests_mock.Mocker() + def not_stubbed(*args): + self.error_calls.append([args[0].method, args[0].url]) + raise requests_mock.exceptions.MockException("Not stubbed URL") + self.mocker.register_uri(requests_mock.ANY, requests_mock.ANY, + text=not_stubbed) + + self.mocks = {} + + for market_id, elements in self.reports.items(): + for element in elements: + method = element["method"] + url = element["url"] + self.mocks \ + .setdefault((method, url), {}) \ + .setdefault(market_id, []) \ + .append(element) + + for ((method, url), elements) in self.mocks.items(): + self.mocker.register_uri(method, url, text=functools.partial(callback, self, elements), complete_qs=True) + + def start(self): + self.mocker.start() + + def stop(self): + self.mocker.stop() + + lazy_calls = [ + "https://cryptoportfolio.io/wp-content/uploads/portfolio/json/cryptoportfolio.json", + "https://poloniex.com/public?command=returnTicker", + ] + def check_calls(self): + self.tester.assertEqual([], self.error_calls) + for (method, url), elements in self.mocks.items(): + for market_id, element in elements.items(): + if url not in self.lazy_calls: + self.tester.assertEqual(0, len(element), "Missing calls to {} {}, market_id {}".format(method, url, market_id)) + + def clean_body(self, body): + if body is None: + return None + if isinstance(body, bytes): + body = body.decode() + body = re.sub(r"&nonce=\d*$", "", body) + body = re.sub(r"nonce=\d*&?", "", body) + return body + +def callback(self, elements, request, context): + try: + element = elements[request.headers.get("X-market-id")].pop(0) + except (IndexError, KeyError): + self.error_calls.append([request.method, request.url, request.headers.get("X-market-id")]) + raise RuntimeError("Unexpected call") + if element["response"] is None and element["response_same_as"] is not None: + element["response"] = self.last_https[element["response_same_as"]] + elif element["response"] is not None: + self.last_https[element["date"]] = element["response"] + + time.sleep(element.get("duration", 0)) + + assert self.clean_body(request.body) == \ + self.clean_body(element["body"]), "Body does not match" + context.status_code = element["status"] + if "error" in element: + if element["error"] == "SSLError": + raise SSLError(element["error_message"]) + else: + raise getattr(requests.exceptions, element["error"])(element["error_message"]) + return element["response"] + +class GlobalVariablesMock: + def start(self): + import market + import store + + self.patchers = [ + mock.patch.multiple(market.Portfolio, + data=store.LockedVar(None), + liquidities=store.LockedVar({}), + last_date=store.LockedVar(None), + report=store.LockedVar(store.ReportStore(None, no_http_dup=True)), + worker=None, + worker_tag="", + worker_notify=None, + worker_started=False, + callback=None) + ] + for patcher in self.patchers: + patcher.start() + + def stop(self): + pass + + +class AcceptanceTestCase(): + def parse_file(self, report_file): + with open(report_file, "rb") as f: + json_content = json.load(f, parse_float=Decimal) + config, user, date, market_id = self.parse_config(json_content) + http_requests = self.parse_requests(json_content) + + return config, user, date, market_id, http_requests, json_content + + def parse_requests(self, json_content): + http_requests = [] + for element in json_content: + if element["type"] != "http_request": + continue + http_requests.append(element) + return http_requests + + def parse_config(self, json_content): + market_info = None + for element in json_content: + if element["type"] != "market": + continue + market_info = element + assert market_info is not None, "Couldn't find market element" + + args = market_info["args"] + config = [] + for arg in ["before", "after", "quiet", "debug"]: + if args.get(arg, False): + config.append("--{}".format(arg)) + for arg in ["parallel", "report_db"]: + if not args.get(arg, False): + config.append("--no-{}".format(arg.replace("_", "-"))) + for action in (args.get("action", []) or []): + config.extend(["--action", action]) + if args.get("report_path") is not None: + config.extend(["--report-path", args.get("report_path")]) + if args.get("user") is not None: + config.extend(["--user", args.get("user")]) + config.extend(["--config", ""]) + + user = market_info["user_id"] + date = datetime.datetime.strptime(market_info["date"], "%Y-%m-%dT%H:%M:%S.%f") + market_id = market_info["market_id"] + return config, user, date, market_id + + def requests_by_market(self): + r = { + None: [] + } + got_common = False + for user_id, market_id, http_requests, report_lines in self.reports.values(): + r[str(market_id)] = [] + for http_request in http_requests: + if http_request["market_id"] is None: + if not got_common: + r[None].append(http_request) + else: + r[str(market_id)].append(http_request) + got_common = True + return r + + def setUp(self): + if not hasattr(self, "files"): + raise "This class expects to be inherited with a class defining 'files' variable" + if not hasattr(self, "log_files"): + self.log_files = [] + + self.reports = {} + self.start_date = datetime.datetime.now() + self.config = [] + for f in self.files: + self.config, user_id, date, market_id, http_requests, report_lines = self.parse_file(f) + if date < self.start_date: + self.start_date = date + self.reports[f] = [user_id, market_id, http_requests, report_lines] + + self.database_mock = DatabaseMock(self, self.reports, "--no-report-db" not in self.config) + self.requests_mock = RequestsMock(self) + self.file_mock = FileMock(self.log_files, "--quiet" in self.config, self) + self.global_variables_mock = GlobalVariablesMock() + + self.database_mock.start() + self.requests_mock.start() + self.file_mock.start() + self.global_variables_mock.start() + TimeMock.travel(self.start_date) + + def base_test(self): + main.main(self.config) + self.requests_mock.check_calls() + self.database_mock.check_calls() + self.file_mock.check_calls() + + def tearDown(self): + TimeMock.stop() + self.global_variables_mock.stop() + self.file_mock.stop() + self.requests_mock.stop() + self.database_mock.stop() + -- cgit v1.2.3