]> git.immae.eu Git - perso/Immae/Projets/Cryptomonnaies/Cryptoportfolio/Trader.git/blobdiff - test_acceptance.py
Move acceptance tests to common directory
[perso/Immae/Projets/Cryptomonnaies/Cryptoportfolio/Trader.git] / test_acceptance.py
diff --git a/test_acceptance.py b/test_acceptance.py
deleted file mode 100644 (file)
index 3633928..0000000
+++ /dev/null
@@ -1,361 +0,0 @@
-import requests
-import requests_mock
-import sys, os
-import time, datetime
-import unittest
-from unittest import mock
-from ssl import SSLError
-from decimal import Decimal
-import simplejson as json
-import psycopg2
-from io import StringIO
-import re
-import functools
-import glob
-
-import main
-
-class FileMock:
-    def __init__(self, log_files, quiet, tester):
-        self.tester = tester
-        self.log_files = []
-        if log_files is not None and len(log_files) > 0:
-            self.read_log_files(log_files)
-        self.quiet = quiet
-        self.patches = [
-                mock.patch("market.open"),
-                mock.patch("os.makedirs"),
-                mock.patch("sys.stdout", new_callable=StringIO),
-                ]
-        self.mocks = []
-
-    def start(self):
-        for patch in self.patches:
-            self.mocks.append(patch.start())
-        self.stdout = self.mocks[-1]
-
-    def check_calls(self):
-        stdout = self.stdout.getvalue()
-        if self.quiet:
-            self.tester.assertEqual("", stdout)
-        else:
-            log = self.strip_log(stdout)
-            if len(self.log_files) != 0:
-                split_logs = log.split("\n")
-                self.tester.assertEqual(sum(len(f) for f in self.log_files), len(split_logs))
-                try:
-                    for log_file in self.log_files:
-                        for line in log_file:
-                            split_logs.pop(split_logs.index(line))
-                except ValueError:
-                    if not line.startswith("[Worker] "):
-                        self.tester.fail("« {} » not found in log file {}".format(line, split_logs))
-    # Le fichier de log est écrit
-    # Le fichier de log est printed uniquement si non quiet
-    # Le rapport est écrit si pertinent
-    # Le rapport contient le bon nombre de lignes
-
-    def stop(self):
-        for patch in self.patches[::-1]:
-            patch.stop()
-            self.mocks.pop()
-
-    def strip_log(self, log):
-        log = log.replace("\n\n", "\n")
-        return re.sub(r"^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}: ", "", log, flags=re.MULTILINE)
-
-    def read_log_files(self, log_files):
-        for log_file in log_files:
-            with open(log_file, "r") as f:
-                log = self.strip_log(f.read()).split("\n")
-                if len(log[-1]) == 0:
-                    log.pop()
-                self.log_files.append(log)
-
-class DatabaseMock:
-    def __init__(self, tester, reports, report_db):
-        self.tester = tester
-        self.reports = reports
-        self.report_db = report_db
-        self.rows = []
-        self.total_report_lines=  0
-        self.requests = []
-        for user_id, market_id, http_requests, report_lines in self.reports.values():
-            self.rows.append( (market_id, { "key": "key", "secret": "secret" }, user_id) )
-            self.total_report_lines += len(report_lines)
-
-    def start(self):
-        connect_mock = mock.Mock()
-        self.cursor = mock.MagicMock()
-        connect_mock.cursor.return_value = self.cursor
-        def _execute(request, *args):
-            self.requests.append(request)
-        self.cursor.execute.side_effect = _execute
-        self.cursor.__iter__.return_value = self.rows
-
-        self.db_patch = mock.patch("psycopg2.connect")
-        self.db_mock = self.db_patch.start()
-        self.db_mock.return_value = connect_mock
-
-    def check_calls(self):
-        if self.report_db:
-            self.tester.assertEqual(1 + len(self.rows), self.db_mock.call_count)
-            self.tester.assertEqual(1 + len(self.rows) + self.total_report_lines, self.cursor.execute.call_count)
-        else:
-            self.tester.assertEqual(1, self.db_mock.call_count)
-            self.tester.assertEqual(1, self.cursor.execute.call_count)
-
-    def stop(self):
-        self.db_patch.stop()
-
-class RequestsMock:
-    def __init__(self, tester):
-        self.tester = tester
-        self.reports = tester.requests_by_market()
-
-        self.last_https = {}
-        self.error_calls = []
-        self.mocker = requests_mock.Mocker()
-        def not_stubbed(*args):
-            self.error_calls.append([args[0].method, args[0].url])
-            raise requests_mock.exceptions.MockException("Not stubbed URL")
-        self.mocker.register_uri(requests_mock.ANY, requests_mock.ANY,
-                text=not_stubbed)
-
-        self.mocks = {}
-
-        for market_id, elements in self.reports.items():
-            for element in elements:
-                method = element["method"]
-                url = element["url"]
-                self.mocks \
-                        .setdefault((method, url), {}) \
-                        .setdefault(market_id, []) \
-                        .append(element)
-
-        for ((method, url), elements) in self.mocks.items():
-            self.mocker.register_uri(method, url, text=functools.partial(callback, self, elements), complete_qs=True)
-
-    def start(self):
-        self.mocker.start()
-
-    def stop(self):
-        self.mocker.stop()
-
-    def check_calls(self):
-        self.tester.assertEqual([], self.error_calls)
-        for (method, url), elements in self.mocks.items():
-            for market_id, element in elements.items():
-                self.tester.assertEqual(0, len(element), "Missing calls to {} {}, market_id {}".format(method, url, market_id))
-
-    def clean_body(self, body):
-        if body is None:
-            return None
-        if isinstance(body, bytes):
-            body = body.decode()
-        body = re.sub(r"&nonce=\d*$", "", body)
-        body = re.sub(r"nonce=\d*&?", "", body)
-        return body
-
-def callback(self, elements, request, context):
-    try:
-        element = elements[request.headers.get("X-market-id")].pop(0)
-    except (IndexError, KeyError):
-        self.error_calls.append([request.method, request.url, request.headers.get("X-market-id")])
-        raise RuntimeError("Unexpected call")
-    if element["response"] is None and element["response_same_as"] is not None:
-        element["response"] = self.last_https[element["response_same_as"]]
-    elif element["response"] is not None:
-        self.last_https[element["date"]] = element["response"]
-
-    assert self.clean_body(request.body) == \
-            self.clean_body(element["body"]), "Body does not match"
-    context.status_code = element["status"]
-    if "error" in element:
-        if element["error"] == "SSLError":
-            raise SSLError(element["error_message"])
-        else:
-            raise getattr(requests.exceptions, element["error"])(element["error_message"])
-    return element["response"]
-
-class GlobalVariablesMock:
-    def start(self):
-        import market
-        import store
-
-        self.patchers = [
-                mock.patch.multiple(market.Portfolio,
-                    data=store.LockedVar(None),
-                    liquidities=store.LockedVar({}),
-                    last_date=store.LockedVar(None),
-                    report=store.LockedVar(store.ReportStore(None, no_http_dup=True)),
-                    worker=None,
-                    worker_tag="",
-                    worker_notify=None,
-                    worker_started=False,
-                    callback=None)
-                ]
-        for patcher in self.patchers:
-            patcher.start()
-
-    def stop(self):
-        pass
-
-
-class TimeMock:
-    delta = 0
-    true_time = time.time
-    true_sleep = time.sleep
-    time_patch = None
-    datetime_patch = None
-
-    @classmethod
-    def start(cls, start_date):
-        cls.delta = (datetime.datetime.now() - start_date).total_seconds()
-
-        class fake_datetime(datetime.datetime):
-            @classmethod
-            def now(cls, tz=None):
-                if tz is None:
-                    return cls.fromtimestamp(time.time())
-                else:
-                    return tz.fromutc(cls.utcfromtimestamp(time.time()).replace(tzinfo=tz))
-
-        cls.time_patch = mock.patch.multiple(time, time=cls.fake_time, sleep=cls.fake_sleep)
-        cls.datetime_patch = mock.patch.multiple(datetime, datetime=fake_datetime)
-        cls.time_patch.start()
-        cls.datetime_patch.start()
-
-    @classmethod
-    def stop(cls):
-        cls.delta = 0
-        cls.datetime_patch.stop()
-        cls.time_patch.stop()
-
-    @classmethod
-    def fake_time(cls):
-        return cls.true_time() - cls.delta
-
-    @classmethod
-    def fake_sleep(cls, duration):
-        cls.delta -= duration
-        cls.true_sleep(0.2)
-
-class AcceptanceTestCase():
-    def parse_file(self, report_file):
-        with open(report_file, "rb") as f:
-            json_content = json.load(f, parse_float=Decimal)
-        config, user, date, market_id = self.parse_config(json_content)
-        http_requests = self.parse_requests(json_content)
-
-        return config, user, date, market_id, http_requests, json_content
-
-    def parse_requests(self, json_content):
-        http_requests = []
-        for element in json_content:
-            if element["type"] != "http_request":
-                continue
-            http_requests.append(element)
-        return http_requests
-
-    def parse_config(self, json_content):
-        market_info = None
-        for element in json_content:
-            if element["type"] != "market":
-                continue
-            market_info = element
-        assert market_info is not None, "Couldn't find market element"
-
-        args = market_info["args"]
-        config = []
-        for arg in ["before", "after", "quiet", "debug"]:
-            if args.get(arg, False):
-                config.append("--{}".format(arg))
-        for arg in ["parallel", "report_db"]:
-            if not args.get(arg, False):
-                config.append("--no-{}".format(arg.replace("_", "-")))
-        for action in (args.get("action", []) or []):
-            config.extend(["--action", action])
-        if args.get("report_path") is not None:
-            config.extend(["--report-path", args.get("report_path")])
-        if args.get("user") is not None:
-            config.extend(["--user", args.get("user")])
-        config.extend(["--config", ""])
-
-        user = market_info["user_id"]
-        date = datetime.datetime.strptime(market_info["date"], "%Y-%m-%dT%H:%M:%S.%f")
-        market_id = market_info["market_id"]
-        return config, user, date, market_id
-
-    def requests_by_market(self):
-        r = {
-                None: []
-                }
-        got_common = False
-        for user_id, market_id, http_requests, report_lines in self.reports.values():
-            r[str(market_id)] = []
-            for http_request in http_requests:
-                if http_request["market_id"] is None:
-                    if not got_common:
-                        r[None].append(http_request)
-                else:
-                    r[str(market_id)].append(http_request)
-            got_common = True
-        return r
-
-    def setUp(self):
-        if not hasattr(self, "files"):
-            raise "This class expects to be inherited with a class defining 'files' variable"
-        if not hasattr(self, "log_files"):
-            self.log_files = []
-
-        self.reports = {}
-        self.start_date = datetime.datetime.now()
-        self.config = []
-        for f in self.files:
-            self.config, user_id, date, market_id, http_requests, report_lines = self.parse_file(f)
-            if date < self.start_date:
-                self.start_date = date
-            self.reports[f] = [user_id, market_id, http_requests, report_lines]
-
-        self.database_mock = DatabaseMock(self, self.reports, "--no-report-db" not in self.config)
-        self.requests_mock = RequestsMock(self)
-        self.file_mock = FileMock(self.log_files, "--quiet" in self.config, self)
-        self.global_variables_mock = GlobalVariablesMock()
-
-        self.database_mock.start()
-        self.requests_mock.start()
-        self.file_mock.start()
-        self.global_variables_mock.start()
-        TimeMock.start(self.start_date)
-
-    def base_test(self):
-        main.main(self.config)
-        self.requests_mock.check_calls()
-        self.database_mock.check_calls()
-        self.file_mock.check_calls()
-
-    def tearDown(self):
-        TimeMock.stop()
-        self.global_variables_mock.stop()
-        self.file_mock.stop()
-        self.requests_mock.stop()
-        self.database_mock.stop()
-
-for dirfile in glob.glob("tests/acceptance/**/*/", recursive=True):
-    json_files = glob.glob("{}/*.json".format(dirfile))
-    log_files = glob.glob("{}/*.log".format(dirfile))
-    if len(json_files) > 0:
-        name = dirfile.replace("tests/acceptance/", "").replace("/", "_")[0:-1]
-        cname = "".join(list(map(lambda x: x.capitalize(), name.split("_"))))
-
-        globals()[cname] = type(cname,
-                (AcceptanceTestCase,unittest.TestCase), {
-            "log_files": log_files,
-            "files": json_files,
-            "test_{}".format(name): AcceptanceTestCase.base_test
-            })
-
-if __name__ == '__main__':
-    unittest.main()
-