X-Git-Url: https://git.immae.eu/?a=blobdiff_plain;f=test_acceptance.py;h=3633928c455b5cec8e8cbbeae35b68a44ebd048d;hb=a0dcf4e0978331709da164fb0e29ae008b90fc88;hp=88a2dd4542b064c543efc813efa2d1d6b40b85dd;hpb=c8df27385e02b22d36b240fe29532e97dbba1f43;p=perso%2FImmae%2FProjets%2FCryptomonnaies%2FCryptoportfolio%2FTrader.git diff --git a/test_acceptance.py b/test_acceptance.py index 88a2dd4..3633928 100644 --- a/test_acceptance.py +++ b/test_acceptance.py @@ -9,166 +9,203 @@ from decimal import Decimal import simplejson as json import psycopg2 from io import StringIO +import re +import functools +import glob + +import main class FileMock: - @classmethod - def start(cls): - cls.file_mock = mock.patch("market.open") - cls.os_mock = mock.patch("os.makedirs") - cls.stdout_mock = mock.patch('sys.stdout', new_callable=StringIO) - cls.stdout_mock.start() - cls.os_mock.start() - cls.file_mock.start() + def __init__(self, log_files, quiet, tester): + self.tester = tester + self.log_files = [] + if log_files is not None and len(log_files) > 0: + self.read_log_files(log_files) + self.quiet = quiet + self.patches = [ + mock.patch("market.open"), + mock.patch("os.makedirs"), + mock.patch("sys.stdout", new_callable=StringIO), + ] + self.mocks = [] - @classmethod - def check_calls(cls, tester): - pass - #raise NotImplementedError("Todo") + def start(self): + for patch in self.patches: + self.mocks.append(patch.start()) + self.stdout = self.mocks[-1] - @classmethod - def stop(cls): - cls.file_mock.stop() - cls.stdout_mock.stop() - cls.os_mock.stop() + def check_calls(self): + stdout = self.stdout.getvalue() + if self.quiet: + self.tester.assertEqual("", stdout) + else: + log = self.strip_log(stdout) + if len(self.log_files) != 0: + split_logs = log.split("\n") + self.tester.assertEqual(sum(len(f) for f in self.log_files), len(split_logs)) + try: + for log_file in self.log_files: + for line in log_file: + split_logs.pop(split_logs.index(line)) + except ValueError: + if not line.startswith("[Worker] "): + self.tester.fail("« {} » not found in log file {}".format(line, split_logs)) + # Le fichier de log est écrit + # Le fichier de log est printed uniquement si non quiet + # Le rapport est écrit si pertinent + # Le rapport contient le bon nombre de lignes + + def stop(self): + for patch in self.patches[::-1]: + patch.stop() + self.mocks.pop() -class DatabaseMock: - rows = [] - report_db = False - db_patch = None - cursor = None - total_report_lines = 0 - db_mock = None - requests = [] + def strip_log(self, log): + log = log.replace("\n\n", "\n") + return re.sub(r"^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}: ", "", log, flags=re.MULTILINE) - @classmethod - def start(cls, reports, report_db): - cls.report_db = report_db - cls.rows = [] - cls.total_report_lines= 0 - for user_id, market_id, http_requests, report_lines in reports.values(): - cls.rows.append( (market_id, { "key": "key", "secret": "secret" }, user_id) ) - cls.total_report_lines += len(report_lines) + def read_log_files(self, log_files): + for log_file in log_files: + with open(log_file, "r") as f: + log = self.strip_log(f.read()).split("\n") + if len(log[-1]) == 0: + log.pop() + self.log_files.append(log) + +class DatabaseMock: + def __init__(self, tester, reports, report_db): + self.tester = tester + self.reports = reports + self.report_db = report_db + self.rows = [] + self.total_report_lines= 0 + self.requests = [] + for user_id, market_id, http_requests, report_lines in self.reports.values(): + self.rows.append( (market_id, { "key": "key", "secret": "secret" }, user_id) ) + self.total_report_lines += len(report_lines) + def start(self): connect_mock = mock.Mock() - cls.cursor = mock.MagicMock() - connect_mock.cursor.return_value = cls.cursor + self.cursor = mock.MagicMock() + connect_mock.cursor.return_value = self.cursor def _execute(request, *args): - cls.requests.append(request) - cls.cursor.execute.side_effect = _execute - cls.cursor.__iter__.return_value = cls.rows - - cls.db_patch = mock.patch("psycopg2.connect") - cls.db_mock = cls.db_patch.start() - cls.db_mock.return_value = connect_mock - - @classmethod - def check_calls(cls, tester): - if cls.report_db: - tester.assertEqual(1 + len(cls.rows), cls.db_mock.call_count) - tester.assertEqual(1 + len(cls.rows) + cls.total_report_lines, cls.cursor.execute.call_count) + self.requests.append(request) + self.cursor.execute.side_effect = _execute + self.cursor.__iter__.return_value = self.rows + + self.db_patch = mock.patch("psycopg2.connect") + self.db_mock = self.db_patch.start() + self.db_mock.return_value = connect_mock + + def check_calls(self): + if self.report_db: + self.tester.assertEqual(1 + len(self.rows), self.db_mock.call_count) + self.tester.assertEqual(1 + len(self.rows) + self.total_report_lines, self.cursor.execute.call_count) else: - tester.assertEqual(1, cls.db_mock.call_count) - tester.assertEqual(1, cls.cursor.execute.call_count) + self.tester.assertEqual(1, self.db_mock.call_count) + self.tester.assertEqual(1, self.cursor.execute.call_count) - @classmethod - def stop(cls): - cls.db_patch.stop() - cls.db_mock = None - cls.cursor = None - cls.total_report_lines = 0 - cls.rows = [] - cls.report_db = False - cls.requests = [] + def stop(self): + self.db_patch.stop() class RequestsMock: - adapter = None - mocks = {} - last_https = {} - request_patch = [] - - @classmethod - def start(cls, reports): - cls.adapter = requests_mock.Adapter() - cls.adapter.register_uri(requests_mock.ANY, requests_mock.ANY, - exc=requests_mock.exceptions.MockException("Not stubbed URL")) - true_session = requests.Session - - cls.mocks = {} - - for market_id, elements in reports.items(): + def __init__(self, tester): + self.tester = tester + self.reports = tester.requests_by_market() + + self.last_https = {} + self.error_calls = [] + self.mocker = requests_mock.Mocker() + def not_stubbed(*args): + self.error_calls.append([args[0].method, args[0].url]) + raise requests_mock.exceptions.MockException("Not stubbed URL") + self.mocker.register_uri(requests_mock.ANY, requests_mock.ANY, + text=not_stubbed) + + self.mocks = {} + + for market_id, elements in self.reports.items(): for element in elements: method = element["method"] url = element["url"] - cls.mocks \ + self.mocks \ .setdefault((method, url), {}) \ .setdefault(market_id, []) \ .append(element) - for ((method, url), elements) in cls.mocks.items(): - cls.adapter.register_uri(method, url, text=cls.callback_func(elements), complete_qs=True) - def _session(): - session = true_session() - session.get_adapter = lambda url: cls.adapter - return session - cls.request_patch = [ - mock.patch.object(requests.sessions, "Session", new=_session), - mock.patch.object(requests, "Session", new=_session) - ] - for patch in cls.request_patch: - patch.start() + for ((method, url), elements) in self.mocks.items(): + self.mocker.register_uri(method, url, text=functools.partial(callback, self, elements), complete_qs=True) - @classmethod - def stop(cls): - for patch in cls.request_patch: - patch.stop() - cls.request_patch = [] - cls.last_https = {} - cls.mocks = {} - cls.adapter = None + def start(self): + self.mocker.start() - @classmethod - def check_calls(cls, tester): - for (method, url), elements in cls.mocks.items(): + def stop(self): + self.mocker.stop() + + def check_calls(self): + self.tester.assertEqual([], self.error_calls) + for (method, url), elements in self.mocks.items(): for market_id, element in elements.items(): - tester.assertEqual(0, len(element), "Missing calls to {} {}, market_id {}".format(method, url, market_id)) + self.tester.assertEqual(0, len(element), "Missing calls to {} {}, market_id {}".format(method, url, market_id)) - @classmethod - def clean_body(cls, body): + def clean_body(self, body): if body is None: return None - import re if isinstance(body, bytes): body = body.decode() body = re.sub(r"&nonce=\d*$", "", body) body = re.sub(r"nonce=\d*&?", "", body) return body - @classmethod - def callback_func(cls, elements): - def callback(request, context): - try: - element = elements[request.headers.get("X-market-id")].pop(0) - except (IndexError, KeyError): - raise RuntimeError("Unexpected call") - if element["response"] is None and element["response_same_as"] is not None: - element["response"] = cls.last_https[element["response_same_as"]] - elif element["response"] is not None: - cls.last_https[element["date"]] = element["response"] - - assert cls.clean_body(request.body) == \ - cls.clean_body(element["body"]), "Body does not match" - context.status_code = element["status"] - if "error" in element: - if element["error"] == "SSLError": - raise SSLError(element["error_message"]) - else: - raise getattr(requests.exceptions, element["error"])(element["error_message"]) - return element["response"] - return callback +def callback(self, elements, request, context): + try: + element = elements[request.headers.get("X-market-id")].pop(0) + except (IndexError, KeyError): + self.error_calls.append([request.method, request.url, request.headers.get("X-market-id")]) + raise RuntimeError("Unexpected call") + if element["response"] is None and element["response_same_as"] is not None: + element["response"] = self.last_https[element["response_same_as"]] + elif element["response"] is not None: + self.last_https[element["date"]] = element["response"] + + assert self.clean_body(request.body) == \ + self.clean_body(element["body"]), "Body does not match" + context.status_code = element["status"] + if "error" in element: + if element["error"] == "SSLError": + raise SSLError(element["error_message"]) + else: + raise getattr(requests.exceptions, element["error"])(element["error_message"]) + return element["response"] + +class GlobalVariablesMock: + def start(self): + import market + import store + + self.patchers = [ + mock.patch.multiple(market.Portfolio, + data=store.LockedVar(None), + liquidities=store.LockedVar({}), + last_date=store.LockedVar(None), + report=store.LockedVar(store.ReportStore(None, no_http_dup=True)), + worker=None, + worker_tag="", + worker_notify=None, + worker_started=False, + callback=None) + ] + for patcher in self.patchers: + patcher.start() + + def stop(self): + pass + class TimeMock: delta = 0 true_time = time.time + true_sleep = time.sleep time_patch = None datetime_patch = None @@ -202,6 +239,7 @@ class TimeMock: @classmethod def fake_sleep(cls, duration): cls.delta -= duration + cls.true_sleep(0.2) class AcceptanceTestCase(): def parse_file(self, report_file): @@ -236,7 +274,7 @@ class AcceptanceTestCase(): for arg in ["parallel", "report_db"]: if not args.get(arg, False): config.append("--no-{}".format(arg.replace("_", "-"))) - for action in args.get("action", []): + for action in (args.get("action", []) or []): config.extend(["--action", action]) if args.get("report_path") is not None: config.extend(["--report-path", args.get("report_path")]) @@ -267,7 +305,9 @@ class AcceptanceTestCase(): def setUp(self): if not hasattr(self, "files"): - raise "This class expects to be inherited with a class defining self.files in setUp" + raise "This class expects to be inherited with a class defining 'files' variable" + if not hasattr(self, "log_files"): + self.log_files = [] self.reports = {} self.start_date = datetime.datetime.now() @@ -278,33 +318,40 @@ class AcceptanceTestCase(): self.start_date = date self.reports[f] = [user_id, market_id, http_requests, report_lines] - DatabaseMock.start(self.reports, "--no-report-db" not in self.config) - RequestsMock.start(self.requests_by_market()) - FileMock.start() + self.database_mock = DatabaseMock(self, self.reports, "--no-report-db" not in self.config) + self.requests_mock = RequestsMock(self) + self.file_mock = FileMock(self.log_files, "--quiet" in self.config, self) + self.global_variables_mock = GlobalVariablesMock() + + self.database_mock.start() + self.requests_mock.start() + self.file_mock.start() + self.global_variables_mock.start() TimeMock.start(self.start_date) def base_test(self): - import main main.main(self.config) - RequestsMock.check_calls(self) - DatabaseMock.check_calls(self) - FileMock.check_calls(self) + self.requests_mock.check_calls() + self.database_mock.check_calls() + self.file_mock.check_calls() def tearDown(self): TimeMock.stop() - FileMock.stop() - RequestsMock.stop() - DatabaseMock.stop() + self.global_variables_mock.stop() + self.file_mock.stop() + self.requests_mock.stop() + self.database_mock.stop() -import glob for dirfile in glob.glob("tests/acceptance/**/*/", recursive=True): json_files = glob.glob("{}/*.json".format(dirfile)) + log_files = glob.glob("{}/*.log".format(dirfile)) if len(json_files) > 0: name = dirfile.replace("tests/acceptance/", "").replace("/", "_")[0:-1] cname = "".join(list(map(lambda x: x.capitalize(), name.split("_")))) globals()[cname] = type(cname, (AcceptanceTestCase,unittest.TestCase), { + "log_files": log_files, "files": json_files, "test_{}".format(name): AcceptanceTestCase.base_test })