X-Git-Url: https://git.immae.eu/?p=perso%2FImmae%2FProjets%2FCryptomonnaies%2FCryptoportfolio%2FTrader.git;a=blobdiff_plain;f=test_acceptance.py;h=88a2dd4542b064c543efc813efa2d1d6b40b85dd;hb=1d72880c097ea8259ce9cc63cfe55e6cc7516bd2;hpb=e7d7c0e5645da35adcbfec9e51deb68f012c422f diff --git a/test_acceptance.py b/test_acceptance.py new file mode 100644 index 0000000..88a2dd4 --- /dev/null +++ b/test_acceptance.py @@ -0,0 +1,314 @@ +import requests +import requests_mock +import sys, os +import time, datetime +import unittest +from unittest import mock +from ssl import SSLError +from decimal import Decimal +import simplejson as json +import psycopg2 +from io import StringIO + +class FileMock: + @classmethod + def start(cls): + cls.file_mock = mock.patch("market.open") + cls.os_mock = mock.patch("os.makedirs") + cls.stdout_mock = mock.patch('sys.stdout', new_callable=StringIO) + cls.stdout_mock.start() + cls.os_mock.start() + cls.file_mock.start() + + @classmethod + def check_calls(cls, tester): + pass + #raise NotImplementedError("Todo") + + @classmethod + def stop(cls): + cls.file_mock.stop() + cls.stdout_mock.stop() + cls.os_mock.stop() + +class DatabaseMock: + rows = [] + report_db = False + db_patch = None + cursor = None + total_report_lines = 0 + db_mock = None + requests = [] + + @classmethod + def start(cls, reports, report_db): + cls.report_db = report_db + cls.rows = [] + cls.total_report_lines= 0 + for user_id, market_id, http_requests, report_lines in reports.values(): + cls.rows.append( (market_id, { "key": "key", "secret": "secret" }, user_id) ) + cls.total_report_lines += len(report_lines) + + connect_mock = mock.Mock() + cls.cursor = mock.MagicMock() + connect_mock.cursor.return_value = cls.cursor + def _execute(request, *args): + cls.requests.append(request) + cls.cursor.execute.side_effect = _execute + cls.cursor.__iter__.return_value = cls.rows + + cls.db_patch = mock.patch("psycopg2.connect") + cls.db_mock = cls.db_patch.start() + cls.db_mock.return_value = connect_mock + + @classmethod + def check_calls(cls, tester): + if cls.report_db: + tester.assertEqual(1 + len(cls.rows), cls.db_mock.call_count) + tester.assertEqual(1 + len(cls.rows) + cls.total_report_lines, cls.cursor.execute.call_count) + else: + tester.assertEqual(1, cls.db_mock.call_count) + tester.assertEqual(1, cls.cursor.execute.call_count) + + @classmethod + def stop(cls): + cls.db_patch.stop() + cls.db_mock = None + cls.cursor = None + cls.total_report_lines = 0 + cls.rows = [] + cls.report_db = False + cls.requests = [] + +class RequestsMock: + adapter = None + mocks = {} + last_https = {} + request_patch = [] + + @classmethod + def start(cls, reports): + cls.adapter = requests_mock.Adapter() + cls.adapter.register_uri(requests_mock.ANY, requests_mock.ANY, + exc=requests_mock.exceptions.MockException("Not stubbed URL")) + true_session = requests.Session + + cls.mocks = {} + + for market_id, elements in reports.items(): + for element in elements: + method = element["method"] + url = element["url"] + cls.mocks \ + .setdefault((method, url), {}) \ + .setdefault(market_id, []) \ + .append(element) + + for ((method, url), elements) in cls.mocks.items(): + cls.adapter.register_uri(method, url, text=cls.callback_func(elements), complete_qs=True) + def _session(): + session = true_session() + session.get_adapter = lambda url: cls.adapter + return session + cls.request_patch = [ + mock.patch.object(requests.sessions, "Session", new=_session), + mock.patch.object(requests, "Session", new=_session) + ] + for patch in cls.request_patch: + patch.start() + + @classmethod + def stop(cls): + for patch in cls.request_patch: + patch.stop() + cls.request_patch = [] + cls.last_https = {} + cls.mocks = {} + cls.adapter = None + + @classmethod + def check_calls(cls, tester): + for (method, url), elements in cls.mocks.items(): + for market_id, element in elements.items(): + tester.assertEqual(0, len(element), "Missing calls to {} {}, market_id {}".format(method, url, market_id)) + + @classmethod + def clean_body(cls, body): + if body is None: + return None + import re + if isinstance(body, bytes): + body = body.decode() + body = re.sub(r"&nonce=\d*$", "", body) + body = re.sub(r"nonce=\d*&?", "", body) + return body + + @classmethod + def callback_func(cls, elements): + def callback(request, context): + try: + element = elements[request.headers.get("X-market-id")].pop(0) + except (IndexError, KeyError): + raise RuntimeError("Unexpected call") + if element["response"] is None and element["response_same_as"] is not None: + element["response"] = cls.last_https[element["response_same_as"]] + elif element["response"] is not None: + cls.last_https[element["date"]] = element["response"] + + assert cls.clean_body(request.body) == \ + cls.clean_body(element["body"]), "Body does not match" + context.status_code = element["status"] + if "error" in element: + if element["error"] == "SSLError": + raise SSLError(element["error_message"]) + else: + raise getattr(requests.exceptions, element["error"])(element["error_message"]) + return element["response"] + return callback + +class TimeMock: + delta = 0 + true_time = time.time + time_patch = None + datetime_patch = None + + @classmethod + def start(cls, start_date): + cls.delta = (datetime.datetime.now() - start_date).total_seconds() + + class fake_datetime(datetime.datetime): + @classmethod + def now(cls, tz=None): + if tz is None: + return cls.fromtimestamp(time.time()) + else: + return tz.fromutc(cls.utcfromtimestamp(time.time()).replace(tzinfo=tz)) + + cls.time_patch = mock.patch.multiple(time, time=cls.fake_time, sleep=cls.fake_sleep) + cls.datetime_patch = mock.patch.multiple(datetime, datetime=fake_datetime) + cls.time_patch.start() + cls.datetime_patch.start() + + @classmethod + def stop(cls): + cls.delta = 0 + cls.datetime_patch.stop() + cls.time_patch.stop() + + @classmethod + def fake_time(cls): + return cls.true_time() - cls.delta + + @classmethod + def fake_sleep(cls, duration): + cls.delta -= duration + +class AcceptanceTestCase(): + def parse_file(self, report_file): + with open(report_file, "rb") as f: + json_content = json.load(f, parse_float=Decimal) + config, user, date, market_id = self.parse_config(json_content) + http_requests = self.parse_requests(json_content) + + return config, user, date, market_id, http_requests, json_content + + def parse_requests(self, json_content): + http_requests = [] + for element in json_content: + if element["type"] != "http_request": + continue + http_requests.append(element) + return http_requests + + def parse_config(self, json_content): + market_info = None + for element in json_content: + if element["type"] != "market": + continue + market_info = element + assert market_info is not None, "Couldn't find market element" + + args = market_info["args"] + config = [] + for arg in ["before", "after", "quiet", "debug"]: + if args.get(arg, False): + config.append("--{}".format(arg)) + for arg in ["parallel", "report_db"]: + if not args.get(arg, False): + config.append("--no-{}".format(arg.replace("_", "-"))) + for action in args.get("action", []): + config.extend(["--action", action]) + if args.get("report_path") is not None: + config.extend(["--report-path", args.get("report_path")]) + if args.get("user") is not None: + config.extend(["--user", args.get("user")]) + config.extend(["--config", ""]) + + user = market_info["user_id"] + date = datetime.datetime.strptime(market_info["date"], "%Y-%m-%dT%H:%M:%S.%f") + market_id = market_info["market_id"] + return config, user, date, market_id + + def requests_by_market(self): + r = { + None: [] + } + got_common = False + for user_id, market_id, http_requests, report_lines in self.reports.values(): + r[str(market_id)] = [] + for http_request in http_requests: + if http_request["market_id"] is None: + if not got_common: + r[None].append(http_request) + else: + r[str(market_id)].append(http_request) + got_common = True + return r + + def setUp(self): + if not hasattr(self, "files"): + raise "This class expects to be inherited with a class defining self.files in setUp" + + self.reports = {} + self.start_date = datetime.datetime.now() + self.config = [] + for f in self.files: + self.config, user_id, date, market_id, http_requests, report_lines = self.parse_file(f) + if date < self.start_date: + self.start_date = date + self.reports[f] = [user_id, market_id, http_requests, report_lines] + + DatabaseMock.start(self.reports, "--no-report-db" not in self.config) + RequestsMock.start(self.requests_by_market()) + FileMock.start() + TimeMock.start(self.start_date) + + def base_test(self): + import main + main.main(self.config) + RequestsMock.check_calls(self) + DatabaseMock.check_calls(self) + FileMock.check_calls(self) + + def tearDown(self): + TimeMock.stop() + FileMock.stop() + RequestsMock.stop() + DatabaseMock.stop() + +import glob +for dirfile in glob.glob("tests/acceptance/**/*/", recursive=True): + json_files = glob.glob("{}/*.json".format(dirfile)) + if len(json_files) > 0: + name = dirfile.replace("tests/acceptance/", "").replace("/", "_")[0:-1] + cname = "".join(list(map(lambda x: x.capitalize(), name.split("_")))) + + globals()[cname] = type(cname, + (AcceptanceTestCase,unittest.TestCase), { + "files": json_files, + "test_{}".format(name): AcceptanceTestCase.base_test + }) + +if __name__ == '__main__': + unittest.main() +