import requests
import requests_mock
import sys
import os
import time
import datetime
from unittest import mock
from ssl import SSLError
from decimal import Decimal
import simplejson as json
import psycopg2
from io import StringIO
import re
import functools
import threading


class TimeMock:
    """Process-wide fake clock for the acceptance tests.

    ``start()`` replaces ``time.time``, ``time.sleep`` and
    ``datetime.datetime`` with fakes.  Each thread has its own offset, in
    seconds, between the real clock and its fake clock.  Offsets are created
    lazily from ``delta_init`` the first time a thread reads the time or
    sleeps.  ``fake_sleep`` moves the calling thread's fake clock forward by
    the requested duration but really sleeps at most 0.1 s.
    """

    # Per-thread offsets: thread object -> seconds subtracted from real time.
    delta = {}
    # Offset given to a thread that has no entry in ``delta`` yet.
    delta_init = 0
    # Real implementations, captured before anything is patched.
    true_time = time.time
    true_sleep = time.sleep
    time_patch = None
    datetime_patch = None

    @classmethod
    def travel(cls, start_date):
        """Move every thread's fake clock so that "now" becomes ``start_date``.

        The offset is computed from the *real* clock before the per-thread
        offsets are cleared.  Reading the (patched) ``datetime.now()`` after
        clearing ``delta`` would re-create the calling thread's entry with the
        stale ``delta_init``, so that thread would never travel.

        :param start_date: naive ``datetime`` the fake clocks should start at.
        """
        real_now = datetime.datetime.fromtimestamp(cls.true_time())
        cls.delta_init = (real_now - start_date).total_seconds()
        cls.delta = {}

    @classmethod
    def start(cls):
        """Reset the offsets and install the time/datetime patches."""
        cls.delta = {}
        cls.delta_init = 0

        class fake_datetime(datetime.datetime):
            """``datetime.datetime`` whose ``now()`` reads the fake clock."""

            @classmethod
            def now(cls, tz=None):
                if tz is None:
                    return cls.fromtimestamp(time.time())
                else:
                    return tz.fromutc(cls.utcfromtimestamp(time.time()).replace(tzinfo=tz))

        cls.time_patch = mock.patch.multiple(time, time=cls.fake_time, sleep=cls.fake_sleep)
        cls.datetime_patch = mock.patch.multiple(datetime, datetime=fake_datetime)
        cls.time_patch.start()
        cls.datetime_patch.start()

    @classmethod
    def stop(cls):
        """Reset the offsets so the fake clock reads real time again.

        The patches themselves are left active: they are installed once, at
        import time, for the whole module.
        """
        cls.delta = {}
        cls.delta_init = 0

    @classmethod
    def fake_time(cls):
        """Replacement for ``time.time``: real time minus this thread's offset."""
        cls.delta.setdefault(threading.current_thread(), cls.delta_init)
        return cls.true_time() - cls.delta[threading.current_thread()]

    @classmethod
    def fake_sleep(cls, duration):
        """Replacement for ``time.sleep``.

        Advances this thread's fake clock by ``duration`` seconds but really
        sleeps at most 0.1 s.
        """
        cls.delta.setdefault(threading.current_thread(), cls.delta_init)
        cls.delta[threading.current_thread()] -= float(duration)
        cls.true_sleep(min(float(duration), 0.1))


# Fake time is installed before the code under test is imported.
# NOTE(review): presumably so that anything ``main`` (or the modules it
# imports) grabs from ``time``/``datetime`` at import time is already faked.
TimeMock.start()
import main  # noqa: E402


class FileMock:
    """Patch file-system side effects and capture stdout during a test.

    Patches ``market.open``, ``os.makedirs`` and ``sys.stdout`` (replaced by a
    ``StringIO``).  ``check_calls`` compares the captured output with the
    expected log files given to the constructor.
    """

    def __init__(self, log_files, quiet, tester):
        """
        :param log_files: paths of the expected log files (may be None/empty).
        :param quiet: True when the run is expected to print nothing.
        :param tester: object providing ``assertEqual``/``fail`` (the test case).
        """
        self.tester = tester
        self.log_files = []
        if log_files is not None and len(log_files) > 0:
            self.read_log_files(log_files)
        self.quiet = quiet
        self.patches = [
            mock.patch("market.open"),
            mock.patch("os.makedirs"),
            # Must stay last: start() takes the stdout buffer from mocks[-1].
            mock.patch("sys.stdout", new_callable=StringIO),
        ]
        self.mocks = []

    def start(self):
        """Start every patch and keep a handle on the captured stdout."""
        for patch in self.patches:
            self.mocks.append(patch.start())
        self.stdout = self.mocks[-1]

    def check_calls(self):
        """Assert that the captured stdout matches the expected log files.

        In quiet mode stdout must be empty.  Otherwise, once timestamps are
        stripped, stdout must have as many lines as all expected log files
        together, and every expected line must be present (order is not
        checked).  Missing lines starting with ``"[Worker] "`` are tolerated.
        """
        stdout = self.stdout.getvalue()
        if self.quiet:
            self.tester.assertEqual("", stdout)
            return

        if len(self.log_files) != 0:
            split_logs = self.strip_log(stdout).split("\n")
            self.tester.assertEqual(sum(len(f) for f in self.log_files), len(split_logs))
            for log_file in self.log_files:
                for line in log_file:
                    try:
                        # Consume one occurrence so repeated lines must each appear.
                        split_logs.remove(line)
                    except ValueError:
                        # Tolerate only this line and keep checking the others;
                        # a try around the whole loop would stop at the first
                        # missing "[Worker] " line and skip every later check.
                        if not line.startswith("[Worker] "):
                            self.tester.fail("« {} » not found in log file {}".format(line, split_logs))
        # TODO: check that the log file is written.
        # TODO: check that the report is written when relevant.
        # TODO: check that the report contains the right number of lines.

    def stop(self):
        """Stop the patches in reverse start order."""
        for patch in self.patches[::-1]:
            patch.stop()
            self.mocks.pop()

    def strip_log(self, log):
        """Normalize log text for comparison.

        Replaces each ``"\\n\\n"`` with ``"\\n"`` and removes a leading
        ``"YYYY-MM-DD HH:MM:SS: "`` timestamp from every line.
        """
        log = log.replace("\n\n", "\n")
        return re.sub(r"^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}: ", "", log, flags=re.MULTILINE)

    def read_log_files(self, log_files):
        """Load each expected log file as a list of normalized lines.

        A trailing empty line (from a final newline) is dropped.
        """
        for log_file in log_files:
            with open(log_file, "r") as f:
                log = self.strip_log(f.read()).split("\n")
                if len(log[-1]) == 0:
                    log.pop()
                self.log_files.append(log)


class DatabaseMock:
    """Replace ``psycopg2.connect`` with a mock returning a fake cursor.

    Iterating the cursor yields one row per report:
    ``(market_id, {"key": "key", "secret": "secret"}, user_id)``.
    Every SQL request passed to ``cursor.execute`` is recorded in
    ``self.requests``.
    """

    def __init__(self, tester, reports, report_db):
        """
        :param tester: object providing ``assertEqual`` (the test case).
        :param reports: dict whose values are
            ``[user_id, market_id, http_requests, report_lines]``.
        :param report_db: whether the run is expected to write reports to the DB.
        """
        self.tester = tester
        self.reports = reports
        self.report_db = report_db
        self.rows = []
        self.total_report_lines = 0
        self.requests = []
        for user_id, market_id, http_requests, report_lines in self.reports.values():
            self.rows.append((market_id, {"key": "key", "secret": "secret"}, user_id))
            self.total_report_lines += len(report_lines)

    def start(self):
        """Install the ``psycopg2.connect`` patch."""
        connect_mock = mock.Mock()
        self.cursor = mock.MagicMock()
        connect_mock.cursor.return_value = self.cursor

        def _execute(request, *args):
            self.requests.append(request)

        self.cursor.execute.side_effect = _execute
        self.cursor.__iter__.return_value = self.rows

        self.db_patch = mock.patch("psycopg2.connect")
        self.db_mock = self.db_patch.start()
        self.db_mock.return_value = connect_mock

    def check_calls(self):
        """Assert the number of connections and executed statements.

        With ``report_db``: ``1 + len(rows)`` connections and
        ``1 + len(rows) + total_report_lines`` executes.  Without it: exactly
        one of each.
        """
        if self.report_db:
            self.tester.assertEqual(1 + len(self.rows), self.db_mock.call_count)
            self.tester.assertEqual(1 + len(self.rows) + self.total_report_lines, self.cursor.execute.call_count)
        else:
            self.tester.assertEqual(1, self.db_mock.call_count)
            self.tester.assertEqual(1, self.cursor.execute.call_count)

    def stop(self):
        """Remove the ``psycopg2.connect`` patch."""
        self.db_patch.stop()


class RequestsMock:
    """Stub every HTTP call recorded in the report files with requests_mock.

    Recorded requests are grouped by ``(method, url)``, then by market id
    (compared with the ``X-market-id`` request header).  Each stub replays the
    recorded responses in order through ``callback``.  Any URL that was not
    recorded is logged in ``error_calls`` and raises.
    """

    def __init__(self, tester):
        """
        :param tester: test case; ``tester.requests_by_market()`` supplies
            the recorded requests and ``tester.assertEqual`` is used by
            ``check_calls``.
        """
        self.tester = tester
        self.reports = tester.requests_by_market()

        # Recorded responses indexed by the "date" of their request, so later
        # elements can reuse them through "response_same_as".
        self.last_https = {}
        self.error_calls = []
        self.mocker = requests_mock.Mocker()

        def not_stubbed(*args):
            self.error_calls.append([args[0].method, args[0].url])
            raise requests_mock.exceptions.MockException("Not stubbed URL")

        # Catch-all; the specific registrations below take precedence over it.
        self.mocker.register_uri(requests_mock.ANY, requests_mock.ANY,
                text=not_stubbed)

        # (method, url) -> market_id -> list of recorded request elements
        self.mocks = {}

        for market_id, elements in self.reports.items():
            for element in elements:
                method = element["method"]
                url = element["url"]
                self.mocks \
                        .setdefault((method, url), {}) \
                        .setdefault(market_id, []) \
                        .append(element)

        for ((method, url), elements) in self.mocks.items():
            self.mocker.register_uri(method, url, text=functools.partial(callback, self, elements),
                    complete_qs=True)

    def start(self):
        """Start intercepting HTTP calls."""
        self.mocker.start()

    def stop(self):
        """Stop intercepting HTTP calls."""
        self.mocker.stop()

    # URLs whose recorded responses do not all have to be consumed.
    lazy_calls = [
        "https://cryptoportfolio.io/wp-content/uploads/portfolio/json/cryptoportfolio.json",
        "https://poloniex.com/public?command=returnTicker",
    ]

    def check_calls(self):
        """Assert there were no unexpected calls and no unconsumed stubs.

        URLs listed in ``lazy_calls`` are exempt from the unconsumed check.
        """
        self.tester.assertEqual([], self.error_calls)
        for (method, url), elements in self.mocks.items():
            for market_id, element in elements.items():
                if url not in self.lazy_calls:
                    self.tester.assertEqual(0, len(element), "Missing calls to {} {}, market_id {}".format(method, url, market_id))

    def clean_body(self, body):
        """Normalize a request body for comparison.

        Decodes bytes and removes every ``nonce=<digits>`` parameter.
        """
        if body is None:
            return None
        if isinstance(body, bytes):
            body = body.decode()
        body = re.sub(r"&nonce=\d*$", "", body)
        body = re.sub(r"nonce=\d*&?", "", body)
        return body


def callback(self, elements, request, context):
    """requests_mock dynamic response: replay the next recorded element.

    Module-level function bound to a ``RequestsMock`` instance with
    ``functools.partial(callback, mock, elements)``.

    :param self: the ``RequestsMock`` instance.
    :param elements: market id -> queue of recorded elements for one
        ``(method, url)``.
    :raises RuntimeError: no recorded element is left for this market id.
    :raises SSLError: or the named ``requests.exceptions`` class, when the
        recorded element carries an ``"error"``.
    """
    try:
        element = elements[request.headers.get("X-market-id")].pop(0)
    except (IndexError, KeyError):
        self.error_calls.append([request.method, request.url, request.headers.get("X-market-id")])
        raise RuntimeError("Unexpected call")
    if element["response"] is None and element["response_same_as"] is not None:
        # Reuse the response recorded for an earlier request with that "date".
        element["response"] = self.last_https[element["response_same_as"]]
    elif element["response"] is not None:
        self.last_https[element["date"]] = element["response"]

    # Patched sleep: advances the fake clock by the recorded duration.
    time.sleep(element.get("duration", 0))

    assert self.clean_body(request.body) == \
            self.clean_body(element["body"]), "Body does not match"
    context.status_code = element["status"]
    if "error" in element:
        if element["error"] == "SSLError":
            raise SSLError(element["error_message"])
        else:
            raise getattr(requests.exceptions, element["error"])(element["error_message"])
    return element["response"]


class GlobalVariablesMock:
    """Give ``market.Portfolio``'s class-level state fresh values per test."""

    def start(self):
        """Patch ``market.Portfolio``'s shared attributes with fresh values."""
        import market
        import store

        self.patchers = [
            mock.patch.multiple(market.Portfolio,
                data=store.LockedVar(None),
                liquidities=store.LockedVar({}),
                last_date=store.LockedVar(None),
                report=store.LockedVar(store.ReportStore(None, no_http_dup=True)),
                worker=None,
                worker_tag="",
                worker_notify=None,
                worker_started=False,
                callback=None)
        ]

        for patcher in self.patchers:
            patcher.start()

    def stop(self):
        """Undo the patches applied by ``start()``, most recent first.

        The patches were previously never stopped, so they piled up across
        tests.  Calling this when ``start()`` never ran is a no-op.
        """
        for patcher in reversed(getattr(self, "patchers", [])):
            patcher.stop()
        self.patchers = []


class AcceptanceTestCase:
    """Mixin that replays recorded report files through ``main.main``.

    Subclasses must define ``files`` (paths of JSON report files) and may
    define ``log_files`` (expected stdout logs).
    NOTE(review): subclasses are presumably also ``unittest.TestCase``s —
    ``self`` is used as the tester for ``assertEqual``/``fail`` and
    ``setUp``/``tearDown`` follow unittest naming.
    """

    def parse_file(self, report_file):
        """Load one report file.

        Floats are parsed as ``Decimal``.

        :returns: ``(config, user, date, market_id, http_requests, json_content)``
        """
        with open(report_file, "rb") as f:
            json_content = json.load(f, parse_float=Decimal)
        config, user, date, market_id = self.parse_config(json_content)
        http_requests = self.parse_requests(json_content)

        return config, user, date, market_id, http_requests, json_content

    def parse_requests(self, json_content):
        """Return the elements whose ``"type"`` is ``"http_request"``."""
        http_requests = []
        for element in json_content:
            if element["type"] != "http_request":
                continue
            http_requests.append(element)
        return http_requests

    def parse_config(self, json_content):
        """Build the command-line arguments from the ``"market"`` element.

        If several ``"market"`` elements exist, the last one wins.

        :returns: ``(config, user_id, date, market_id)``; ``config`` is the
            argument list given to ``main.main``.
        """
        market_info = None
        for element in json_content:
            if element["type"] != "market":
                continue
            market_info = element
        assert market_info is not None, "Couldn't find market element"

        args = market_info["args"]
        config = []
        for arg in ["before", "after", "quiet", "debug"]:
            if args.get(arg, False):
                config.append("--{}".format(arg))
        for arg in ["parallel", "report_db"]:
            if not args.get(arg, False):
                config.append("--no-{}".format(arg.replace("_", "-")))
        for action in (args.get("action", []) or []):
            config.extend(["--action", action])
        if args.get("report_path") is not None:
            config.extend(["--report-path", args.get("report_path")])
        if args.get("user") is not None:
            config.extend(["--user", args.get("user")])
        config.extend(["--config", ""])

        user = market_info["user_id"]
        date = datetime.datetime.strptime(market_info["date"], "%Y-%m-%dT%H:%M:%S.%f")
        market_id = market_info["market_id"]
        return config, user, date, market_id

    def requests_by_market(self):
        """Group recorded HTTP requests by market id.

        Keys are ``str(market_id)``, plus ``None`` for requests recorded
        without a market id.  Those shared requests are taken from the first
        report only, so they are stubbed once.
        """
        r = {None: []}
        got_common = False
        for user_id, market_id, http_requests, report_lines in self.reports.values():
            r[str(market_id)] = []
            for http_request in http_requests:
                if http_request["market_id"] is None:
                    if not got_common:
                        r[None].append(http_request)
                else:
                    r[str(market_id)].append(http_request)
            got_common = True
        return r

    def setUp(self):
        """Parse the report files, start every mock and travel in time.

        The fake clocks are moved to the earliest date found in the reports.

        :raises NotImplementedError: the subclass does not define ``files``.
        """
        if not hasattr(self, "files"):
            raise NotImplementedError("This class expects to be inherited with a class defining 'files' variable")
        if not hasattr(self, "log_files"):
            self.log_files = []

        self.reports = {}
        self.start_date = datetime.datetime.now()
        self.config = []
        for f in self.files:
            self.config, user_id, date, market_id, http_requests, report_lines = self.parse_file(f)
            if date < self.start_date:
                self.start_date = date
            self.reports[f] = [user_id, market_id, http_requests, report_lines]

        self.database_mock = DatabaseMock(self, self.reports, "--no-report-db" not in self.config)
        self.requests_mock = RequestsMock(self)
        self.file_mock = FileMock(self.log_files, "--quiet" in self.config, self)
        self.global_variables_mock = GlobalVariablesMock()

        self.database_mock.start()
        self.requests_mock.start()
        self.file_mock.start()
        self.global_variables_mock.start()
        TimeMock.travel(self.start_date)

    def base_test(self):
        """Run ``main.main`` with the parsed config, then check every mock."""
        main.main(self.config)
        self.requests_mock.check_calls()
        self.database_mock.check_calls()
        self.file_mock.check_calls()

    def tearDown(self):
        """Reset the fake clock and stop the mocks in reverse start order."""
        TimeMock.stop()
        self.global_variables_mock.stop()
        self.file_mock.stop()
        self.requests_mock.stop()
        self.database_mock.stop()