"""Acceptance tests that replay recorded JSON reports against ``main.main``.

Every directory under ``tests/acceptance/`` containing ``*.json`` report files
becomes a ``unittest.TestCase`` subclass (see the loop at the bottom of this
file).  Its single test runs ``main.main`` with a command line rebuilt from the
reports, while HTTP (``requests``), the database (``psycopg2.connect``),
file/stdout output and the clock are replaced by the mocks defined below.
"""
import datetime
import glob
import os
import re
import sys
import time
import unittest
from decimal import Decimal
from io import StringIO
from ssl import SSLError
from unittest import mock

import psycopg2
import requests
import requests_mock
import simplejson as json


# Nonce parameters are stripped from request bodies before comparison, so the
# replayed body does not have to reproduce the recorded nonce value.
_TRAILING_NONCE_RE = re.compile(r"&nonce=\d*$")
_NONCE_RE = re.compile(r"nonce=\d*&?")


class FileMock:
    """Patch ``market.open``, ``os.makedirs`` and ``sys.stdout`` for a test.

    ``sys.stdout`` is replaced by a fresh ``StringIO`` so anything printed by
    the code under test is captured instead of written to the terminal.
    """

    @classmethod
    def start(cls):
        """Create and start the three patches."""
        cls.file_mock = mock.patch("market.open")
        cls.os_mock = mock.patch("os.makedirs")
        cls.stdout_mock = mock.patch("sys.stdout", new_callable=StringIO)
        cls.stdout_mock.start()
        cls.os_mock.start()
        cls.file_mock.start()

    @classmethod
    def check_calls(cls, tester):
        """Verify file-system interactions.

        Currently performs no checks.
        """
        # TODO: assert on the calls recorded by file_mock / os_mock.

    @classmethod
    def stop(cls):
        """Undo the patches installed by :meth:`start`."""
        cls.file_mock.stop()
        cls.stdout_mock.stop()
        cls.os_mock.stop()


class DatabaseMock:
    """Replace ``psycopg2.connect`` with a mock connection.

    The mocked cursor iterates over one row per report, shaped
    ``(market_id, {"key": ..., "secret": ...}, user_id)``.  Every query passed
    to ``cursor.execute`` is appended to :attr:`requests`.
    """

    rows = []
    report_db = False
    db_patch = None
    cursor = None
    total_report_lines = 0
    db_mock = None
    requests = []  # queries received by cursor.execute (not the HTTP lib)

    @classmethod
    def start(cls, reports, report_db):
        """Install the ``psycopg2.connect`` patch.

        :param reports: dict whose values are
            ``[user_id, market_id, http_requests, report_lines]``.
        :param report_db: whether the run is expected to write report lines
            to the database (changes what :meth:`check_calls` expects).
        """
        cls.report_db = report_db
        cls.rows = []
        cls.total_report_lines = 0
        for user_id, market_id, _http_requests, report_lines in reports.values():
            cls.rows.append(
                (market_id, {"key": "key", "secret": "secret"}, user_id)
            )
            cls.total_report_lines += len(report_lines)

        connect_mock = mock.Mock()
        cls.cursor = mock.MagicMock()
        connect_mock.cursor.return_value = cls.cursor

        def _execute(request, *args):
            cls.requests.append(request)

        cls.cursor.execute.side_effect = _execute
        cls.cursor.__iter__.return_value = cls.rows

        cls.db_patch = mock.patch("psycopg2.connect")
        cls.db_mock = cls.db_patch.start()
        cls.db_mock.return_value = connect_mock

    @classmethod
    def check_calls(cls, tester):
        """Assert the number of connections and executed queries.

        With ``report_db`` the run must open ``1 + len(rows)`` connections and
        execute ``1 + len(rows) + total_report_lines`` queries; otherwise
        exactly one connection and one query.
        """
        if cls.report_db:
            tester.assertEqual(1 + len(cls.rows), cls.db_mock.call_count)
            tester.assertEqual(
                1 + len(cls.rows) + cls.total_report_lines,
                cls.cursor.execute.call_count,
            )
        else:
            tester.assertEqual(1, cls.db_mock.call_count)
            tester.assertEqual(1, cls.cursor.execute.call_count)

    @classmethod
    def stop(cls):
        """Remove the patch and reset all class-level state."""
        cls.db_patch.stop()
        cls.db_mock = None
        cls.cursor = None
        cls.total_report_lines = 0
        cls.rows = []
        cls.report_db = False
        cls.requests = []


class RequestsMock:
    """Serve recorded HTTP responses through a ``requests_mock`` adapter.

    Every ``requests.Session`` created while started routes all URLs to the
    adapter.  Unregistered URLs raise ``MockException``.  Registered ones pop
    the next recorded exchange for the market named by the ``X-market-id``
    request header (``None`` when the header is absent).
    """

    adapter = None
    mocks = {}  # {(method, url): {market_id: [recorded request, ...]}}
    last_https = {}  # recorded response bodies keyed by element "date"
    request_patch = []

    @classmethod
    def start(cls, reports):
        """Register the recorded requests and patch ``requests.Session``.

        :param reports: dict mapping a market id (``str`` or ``None``) to the
            list of recorded ``http_request`` elements for that market.
        """
        cls.adapter = requests_mock.Adapter()
        # Catch-all registered first; the specific URLs below take precedence.
        cls.adapter.register_uri(
            requests_mock.ANY,
            requests_mock.ANY,
            exc=requests_mock.exceptions.MockException("Not stubbed URL"),
        )

        # Keep the real class: the patched factory below must still build real
        # sessions, only with their adapter lookup overridden.
        true_session = requests.Session

        cls.mocks = {}
        for market_id, elements in reports.items():
            for element in elements:
                key = (element["method"], element["url"])
                cls.mocks.setdefault(key, {}).setdefault(market_id, []).append(element)

        for (method, url), elements in cls.mocks.items():
            cls.adapter.register_uri(
                method, url, text=cls.callback_func(elements), complete_qs=True
            )

        def _session():
            session = true_session()
            session.get_adapter = lambda url: cls.adapter
            return session

        cls.request_patch = [
            mock.patch.object(requests.sessions, "Session", new=_session),
            mock.patch.object(requests, "Session", new=_session),
        ]
        for patch in cls.request_patch:
            patch.start()

    @classmethod
    def stop(cls):
        """Undo the ``Session`` patches and reset class-level state."""
        for patch in cls.request_patch:
            patch.stop()
        cls.request_patch = []
        cls.last_https = {}
        cls.mocks = {}
        cls.adapter = None

    @classmethod
    def check_calls(cls, tester):
        """Assert that every recorded request was consumed during the run."""
        for (method, url), elements in cls.mocks.items():
            for market_id, element in elements.items():
                tester.assertEqual(
                    0,
                    len(element),
                    "Missing calls to {} {}, market_id {}".format(method, url, market_id),
                )

    @classmethod
    def clean_body(cls, body):
        """Return *body* as ``str`` with any ``nonce=<digits>`` parameter removed.

        ``bytes`` are decoded first; ``None`` is returned unchanged.
        """
        if body is None:
            return None
        if isinstance(body, bytes):
            body = body.decode()
        body = _TRAILING_NONCE_RE.sub("", body)
        body = _NONCE_RE.sub("", body)
        return body

    @classmethod
    def callback_func(cls, elements):
        """Build a ``requests_mock`` text callback replaying *elements*.

        :param elements: dict mapping market id to the recorded exchanges for
            one ``(method, url)``; entries are consumed in order.
        """
        def callback(request, context):
            try:
                element = elements[request.headers.get("X-market-id")].pop(0)
            except (IndexError, KeyError):
                raise RuntimeError("Unexpected call")

            if element["response"] is None and element["response_same_as"] is not None:
                # The report stored this response by reference to an earlier one.
                element["response"] = cls.last_https[element["response_same_as"]]
            elif element["response"] is not None:
                cls.last_https[element["date"]] = element["response"]

            assert cls.clean_body(request.body) == \
                cls.clean_body(element["body"]), "Body does not match"

            context.status_code = element["status"]
            if "error" in element:
                if element["error"] == "SSLError":
                    raise SSLError(element["error_message"])
                raise getattr(requests.exceptions, element["error"])(element["error_message"])
            return element["response"]

        return callback


class TimeMock:
    """Shift the clock so the run appears to start at a given date.

    ``time.time`` and ``datetime.datetime.now`` return real time minus a fixed
    offset.  ``time.sleep`` does not block; it advances the fake clock by the
    requested duration.
    """

    delta = 0
    true_time = time.time  # captured before any patching
    time_patch = None
    datetime_patch = None

    @classmethod
    def start(cls, start_date):
        """Patch ``time`` and ``datetime`` so that "now" equals *start_date*."""
        cls.delta = (datetime.datetime.now() - start_date).total_seconds()

        class fake_datetime(datetime.datetime):
            @classmethod
            def now(cls, tz=None):
                # time.time is patched below, so this reads the fake clock.
                # fromtimestamp(ts, tz) is the documented equivalent of the
                # deprecated utcfromtimestamp + tz.fromutc combination.
                return cls.fromtimestamp(time.time(), tz)

        cls.time_patch = mock.patch.multiple(time, time=cls.fake_time, sleep=cls.fake_sleep)
        cls.datetime_patch = mock.patch.multiple(datetime, datetime=fake_datetime)
        cls.time_patch.start()
        cls.datetime_patch.start()

    @classmethod
    def stop(cls):
        """Restore the real ``time`` and ``datetime`` functions."""
        cls.delta = 0
        cls.datetime_patch.stop()
        cls.time_patch.stop()

    @classmethod
    def fake_time(cls):
        """Replacement for ``time.time``: real time shifted by ``delta``."""
        return cls.true_time() - cls.delta

    @classmethod
    def fake_sleep(cls, duration):
        """Replacement for ``time.sleep``: move the fake clock forward."""
        cls.delta -= duration


class AcceptanceTestCase():
    """Mixin replaying the report files listed in ``self.files``.

    Combined with ``unittest.TestCase`` by the loop at the bottom of this
    module, which also supplies the ``files`` class attribute.
    """

    def parse_file(self, report_file):
        """Load one JSON report (floats parsed as ``Decimal``).

        :returns: ``(config, user, date, market_id, http_requests, json_content)``.
        """
        with open(report_file, "rb") as f:
            json_content = json.load(f, parse_float=Decimal)
        config, user, date, market_id = self.parse_config(json_content)
        http_requests = self.parse_requests(json_content)

        return config, user, date, market_id, http_requests, json_content

    def parse_requests(self, json_content):
        """Return the report elements whose ``type`` is ``"http_request"``."""
        return [element for element in json_content if element["type"] == "http_request"]

    def parse_config(self, json_content):
        """Rebuild the ``main.main`` command line from the ``market`` element.

        The last element of type ``"market"`` is used.

        :returns: ``(config, user_id, date, market_id)``.
        :raises AssertionError: if the report has no ``market`` element.
        """
        market_info = None
        for element in json_content:
            if element["type"] == "market":
                market_info = element
        assert market_info is not None, "Couldn't find market element"

        args = market_info["args"]
        config = []
        for arg in ["before", "after", "quiet", "debug"]:
            if args.get(arg, False):
                config.append("--{}".format(arg))
        for arg in ["parallel", "report_db"]:
            if not args.get(arg, False):
                config.append("--no-{}".format(arg.replace("_", "-")))
        # A recorded value of None is treated like an absent key.
        for action in args.get("action") or []:
            config.extend(["--action", action])
        report_path = args.get("report_path")
        if report_path is not None:
            config.extend(["--report-path", report_path])
        user_arg = args.get("user")
        if user_arg is not None:
            config.extend(["--user", user_arg])
        config.extend(["--config", ""])

        user = market_info["user_id"]
        date = datetime.datetime.strptime(market_info["date"], "%Y-%m-%dT%H:%M:%S.%f")
        market_id = market_info["market_id"]
        return config, user, date, market_id

    def requests_by_market(self):
        """Group the recorded HTTP requests by market for :class:`RequestsMock`.

        Market-specific requests are keyed by ``str(market_id)``.  Requests
        recorded without a market id are keyed by ``None`` and collected from
        the first report only (presumably every report records the same shared
        calls — confirm against the report writer).
        """
        by_market = {None: []}
        got_common = False
        for _user_id, market_id, http_requests, _report_lines in self.reports.values():
            market_key = str(market_id)
            by_market[market_key] = []
            for http_request in http_requests:
                if http_request["market_id"] is not None:
                    by_market[market_key].append(http_request)
                elif not got_common:
                    by_market[None].append(http_request)
            got_common = True
        return by_market

    def setUp(self):
        """Parse every report file and start all mocks.

        The command line comes from the last parsed report.  The fake clock
        starts at the earliest report date, or at the current time if that is
        earlier.

        :raises NotImplementedError: if the class does not define ``files``.
        """
        if not hasattr(self, "files"):
            # Raising a plain str is itself a TypeError in Python 3; raise a
            # real exception so the intended message reaches the user.
            raise NotImplementedError(
                "This class expects to be inherited with a class defining self.files in setUp"
            )

        self.reports = {}
        self.start_date = datetime.datetime.now()
        self.config = []
        for report_file in self.files:
            self.config, user_id, date, market_id, http_requests, report_lines = \
                self.parse_file(report_file)
            if date < self.start_date:
                self.start_date = date
            self.reports[report_file] = [user_id, market_id, http_requests, report_lines]

        DatabaseMock.start(self.reports, "--no-report-db" not in self.config)
        RequestsMock.start(self.requests_by_market())
        FileMock.start()
        TimeMock.start(self.start_date)

    def base_test(self):
        """Run ``main.main`` and check that every mock saw the expected calls."""
        import main
        main.main(self.config)
        RequestsMock.check_calls(self)
        DatabaseMock.check_calls(self)
        FileMock.check_calls(self)

    def tearDown(self):
        """Stop the mocks in reverse order of :meth:`setUp`."""
        TimeMock.stop()
        FileMock.stop()
        RequestsMock.stop()
        DatabaseMock.stop()


# Generate one TestCase class per directory holding report files, e.g.
# tests/acceptance/foo_bar/ -> class FooBar with method test_foo_bar.
for dirfile in glob.glob("tests/acceptance/**/*/", recursive=True):
    # Sorted so the config/common-request selection in setUp does not depend
    # on filesystem listing order.
    json_files = sorted(glob.glob("{}/*.json".format(dirfile)))
    if json_files:
        # [0:-1] drops the "_" produced by the trailing "/".
        name = dirfile.replace("tests/acceptance/", "").replace("/", "_")[0:-1]
        cname = "".join(part.capitalize() for part in name.split("_"))

        globals()[cname] = type(cname, (AcceptanceTestCase, unittest.TestCase), {
            "files": json_files,
            "test_{}".format(name): AcceptanceTestCase.base_test,
        })

if __name__ == '__main__':
    unittest.main()