Move acceptance tests to common directory
diff --git a/tests/acceptance.py b/tests/acceptance.py
new file mode 100644
index 0000000..66014ca
--- /dev/null
@@ -0,0 +1,358 @@
+import requests
+import requests_mock
+import sys, os
+import time, datetime
+from unittest import mock
+from ssl import SSLError
+from decimal import Decimal
+import simplejson as json
+import psycopg2
+from io import StringIO
+import re
+import functools
+import threading
+
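+# Replaces time.time, time.sleep and datetime.datetime so the test suite can
+# travel back to the date recorded in the acceptance files: each thread keeps
+# its own offset, sleep() advances the simulated clock and real sleeps are
+# capped at 0.1s.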
+class TimeMock:
+    delta = {}
+    delta_init = 0
+    true_time = time.time
+    true_sleep = time.sleep
+    time_patch = None
+    datetime_patch = None
+
+    @classmethod
+    def travel(cls, start_date):
+        cls.delta = {}
+        cls.delta_init = (datetime.datetime.now() - start_date).total_seconds()
+
+    @classmethod
+    def start(cls):
+        cls.delta = {}
+        cls.delta_init = 0
+
+        class fake_datetime(datetime.datetime):
+            @classmethod
+            def now(cls, tz=None):
+                if tz is None:
+                    return cls.fromtimestamp(time.time())
+                else:
+                    return tz.fromutc(cls.utcfromtimestamp(time.time()).replace(tzinfo=tz))
+
+        cls.time_patch = mock.patch.multiple(time, time=cls.fake_time, sleep=cls.fake_sleep)
+        cls.datetime_patch = mock.patch.multiple(datetime, datetime=fake_datetime)
+        cls.time_patch.start()
+        cls.datetime_patch.start()
+
+    @classmethod
+    def stop(cls):
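+        # Only the simulated offsets are reset here; the patches installed by
+        # start() remain active for the whole test run.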
+        cls.delta = {}
+        cls.delta_init = 0
+
+    @classmethod
+    def fake_time(cls):
+        cls.delta.setdefault(threading.current_thread(), cls.delta_init)
+        return cls.true_time() - cls.delta[threading.current_thread()]
+
+    @classmethod
+    def fake_sleep(cls, duration):
+        cls.delta.setdefault(threading.current_thread(), cls.delta_init)
+        cls.delta[threading.current_thread()] -= float(duration)
+        cls.true_sleep(min(float(duration), 0.1))
+
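+# The fake clock is installed before main (and everything it imports) is loaded.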
+TimeMock.start()
+import main
+
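+# Silences file and directory creation (market.open, os.makedirs), captures
+# stdout, and compares the captured log (timestamps stripped) against the
+# expected log files listed by the test case.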
+class FileMock:
+    def __init__(self, log_files, quiet, tester):
+        self.tester = tester
+        self.log_files = []
+        if log_files is not None and len(log_files) > 0:
+            self.read_log_files(log_files)
+        self.quiet = quiet
+        self.patches = [
+                mock.patch("market.open"),
+                mock.patch("os.makedirs"),
+                mock.patch("sys.stdout", new_callable=StringIO),
+                ]
+        self.mocks = []
+
+    def start(self):
+        for patch in self.patches:
+            self.mocks.append(patch.start())
+        self.stdout = self.mocks[-1]
+
+    def check_calls(self):
+        stdout = self.stdout.getvalue()
+        if self.quiet:
+            self.tester.assertEqual("", stdout)
+        else:
+            log = self.strip_log(stdout)
+            if len(self.log_files) != 0:
+                split_logs = log.split("\n")
+                self.tester.assertEqual(sum(len(f) for f in self.log_files), len(split_logs))
+                try:
+                    for log_file in self.log_files:
+                        for line in log_file:
+                            split_logs.pop(split_logs.index(line))
+                except ValueError:
+                    if not line.startswith("[Worker] "):
+                        self.tester.fail("« {} » not found in log file {}".format(line, split_logs))
+    # The log file is written
+    # The report is written when relevant
+    # The report contains the right number of lines
+
+    def stop(self):
+        for patch in self.patches[::-1]:
+            patch.stop()
+            self.mocks.pop()
+
+    def strip_log(self, log):
+        log = log.replace("\n\n", "\n")
+        return re.sub(r"^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}: ", "", log, flags=re.MULTILINE)
+
+    def read_log_files(self, log_files):
+        for log_file in log_files:
+            with open(log_file, "r") as f:
+                log = self.strip_log(f.read()).split("\n")
+                if len(log[-1]) == 0:
+                    log.pop()
+                self.log_files.append(log)
+
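+# Replaces psycopg2.connect with a cursor that records every executed request
+# and yields one (market_id, credentials, user_id) row per configured report.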
+class DatabaseMock:
+    def __init__(self, tester, reports, report_db):
+        self.tester = tester
+        self.reports = reports
+        self.report_db = report_db
+        self.rows = []
+        self.total_report_lines = 0
+        self.requests = []
+        for user_id, market_id, http_requests, report_lines in self.reports.values():
+            self.rows.append( (market_id, { "key": "key", "secret": "secret" }, user_id) )
+            self.total_report_lines += len(report_lines)
+
+    def start(self):
+        connect_mock = mock.Mock()
+        self.cursor = mock.MagicMock()
+        connect_mock.cursor.return_value = self.cursor
+        def _execute(request, *args):
+            self.requests.append(request)
+        self.cursor.execute.side_effect = _execute
+        self.cursor.__iter__.return_value = self.rows
+
+        self.db_patch = mock.patch("psycopg2.connect")
+        self.db_mock = self.db_patch.start()
+        self.db_mock.return_value = connect_mock
+
+    def check_calls(self):
+        if self.report_db:
+            self.tester.assertEqual(1 + len(self.rows), self.db_mock.call_count)
+            self.tester.assertEqual(1 + len(self.rows) + self.total_report_lines, self.cursor.execute.call_count)
+        else:
+            self.tester.assertEqual(1, self.db_mock.call_count)
+            self.tester.assertEqual(1, self.cursor.execute.call_count)
+
+    def stop(self):
+        self.db_patch.stop()
+
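+# Stubs every (method, url) pair found in the report files with requests_mock,
+# dispatching per market through callback(); unstubbed URLs are recorded as
+# errors and check_calls() verifies that all expected calls were consumed.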
+class RequestsMock:
+    def __init__(self, tester):
+        self.tester = tester
+        self.reports = tester.requests_by_market()
+
+        self.last_https = {}
+        self.error_calls = []
+        self.mocker = requests_mock.Mocker()
+        def not_stubbed(*args):
+            self.error_calls.append([args[0].method, args[0].url])
+            raise requests_mock.exceptions.MockException("Not stubbed URL")
+        self.mocker.register_uri(requests_mock.ANY, requests_mock.ANY,
+                text=not_stubbed)
+
+        self.mocks = {}
+
+        for market_id, elements in self.reports.items():
+            for element in elements:
+                method = element["method"]
+                url = element["url"]
+                self.mocks \
+                        .setdefault((method, url), {}) \
+                        .setdefault(market_id, []) \
+                        .append(element)
+
+        for ((method, url), elements) in self.mocks.items():
+            self.mocker.register_uri(method, url, text=functools.partial(callback, self, elements), complete_qs=True)
+
+    def start(self):
+        self.mocker.start()
+
+    def stop(self):
+        self.mocker.stop()
+
+    lazy_calls = [
+            "https://cryptoportfolio.io/wp-content/uploads/portfolio/json/cryptoportfolio.json",
+            "https://poloniex.com/public?command=returnTicker",
+            ]
+    def check_calls(self):
+        self.tester.assertEqual([], self.error_calls)
+        for (method, url), elements in self.mocks.items():
+            for market_id, element in elements.items():
+                if url not in self.lazy_calls:
+                    self.tester.assertEqual(0, len(element), "Missing calls to {} {}, market_id {}".format(method, url, market_id))
+
+    def clean_body(self, body):
+        if body is None:
+            return None
+        if isinstance(body, bytes):
+            body = body.decode()
+        body = re.sub(r"&nonce=\d*$", "", body)
+        body = re.sub(r"nonce=\d*&?", "", body)
+        return body
+
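+# Serves the next recorded exchange for the calling market (identified by the
+# X-market-id header), replaying "response_same_as" bodies and reproducing the
+# recorded duration, status code and request errors.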
+def callback(self, elements, request, context):
+    try:
+        element = elements[request.headers.get("X-market-id")].pop(0)
+    except (IndexError, KeyError):
+        self.error_calls.append([request.method, request.url, request.headers.get("X-market-id")])
+        raise RuntimeError("Unexpected call")
+    if element["response"] is None and element["response_same_as"] is not None:
+        element["response"] = self.last_https[element["response_same_as"]]
+    elif element["response"] is not None:
+        self.last_https[element["date"]] = element["response"]
+
+    time.sleep(element.get("duration", 0))
+
+    assert self.clean_body(request.body) == \
+            self.clean_body(element["body"]), "Body does not match"
+    context.status_code = element["status"]
+    if "error" in element:
+        if element["error"] == "SSLError":
+            raise SSLError(element["error_message"])
+        else:
+            raise getattr(requests.exceptions, element["error"])(element["error_message"])
+    return element["response"]
+
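+# Resets the class-level shared state of market.Portfolio (locked variables,
+# worker thread flags and callback) so each test starts from a clean slate.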
+class GlobalVariablesMock:
+    def start(self):
+        import market
+        import store
+
+        self.patchers = [
+                mock.patch.multiple(market.Portfolio,
+                    data=store.LockedVar(None),
+                    liquidities=store.LockedVar({}),
+                    last_date=store.LockedVar(None),
+                    report=store.LockedVar(store.ReportStore(None, no_http_dup=True)),
+                    worker=None,
+                    worker_tag="",
+                    worker_notify=None,
+                    worker_started=False,
+                    callback=None)
+                ]
+        for patcher in self.patchers:
+            patcher.start()
+
+    def stop(self):
+        pass
+
+
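+# Base class for the acceptance tests: subclasses define a "files" attribute
+# listing JSON report files; setUp() rebuilds the command line and the expected
+# HTTP requests from those files and installs all the mocks above, and
+# base_test() runs main.main() before checking every recorded call.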
+class AcceptanceTestCase():
+    def parse_file(self, report_file):
+        with open(report_file, "rb") as f:
+            json_content = json.load(f, parse_float=Decimal)
+        config, user, date, market_id = self.parse_config(json_content)
+        http_requests = self.parse_requests(json_content)
+
+        return config, user, date, market_id, http_requests, json_content
+
+    def parse_requests(self, json_content):
+        http_requests = []
+        for element in json_content:
+            if element["type"] != "http_request":
+                continue
+            http_requests.append(element)
+        return http_requests
+
+    def parse_config(self, json_content):
+        market_info = None
+        for element in json_content:
+            if element["type"] != "market":
+                continue
+            market_info = element
+        assert market_info is not None, "Couldn't find market element"
+
+        args = market_info["args"]
+        config = []
+        for arg in ["before", "after", "quiet", "debug"]:
+            if args.get(arg, False):
+                config.append("--{}".format(arg))
+        for arg in ["parallel", "report_db"]:
+            if not args.get(arg, False):
+                config.append("--no-{}".format(arg.replace("_", "-")))
+        for action in (args.get("action", []) or []):
+            config.extend(["--action", action])
+        if args.get("report_path") is not None:
+            config.extend(["--report-path", args.get("report_path")])
+        if args.get("user") is not None:
+            config.extend(["--user", args.get("user")])
+        config.extend(["--config", ""])
+
+        user = market_info["user_id"]
+        date = datetime.datetime.strptime(market_info["date"], "%Y-%m-%dT%H:%M:%S.%f")
+        market_id = market_info["market_id"]
+        return config, user, date, market_id
+
+    def requests_by_market(self):
+        r = {
+                None: []
+                }
+        got_common = False
+        for user_id, market_id, http_requests, report_lines in self.reports.values():
+            r[str(market_id)] = []
+            for http_request in http_requests:
+                if http_request["market_id"] is None:
+                    if not got_common:
+                        r[None].append(http_request)
+                else:
+                    r[str(market_id)].append(http_request)
+            got_common = True
+        return r
+
+    def setUp(self):
+        if not hasattr(self, "files"):
+            raise "This class expects to be inherited with a class defining 'files' variable"
+        if not hasattr(self, "log_files"):
+            self.log_files = []
+
+        self.reports = {}
+        self.start_date = datetime.datetime.now()
+        self.config = []
+        for f in self.files:
+            self.config, user_id, date, market_id, http_requests, report_lines = self.parse_file(f)
+            if date < self.start_date:
+                self.start_date = date
+            self.reports[f] = [user_id, market_id, http_requests, report_lines]
+
+        self.database_mock = DatabaseMock(self, self.reports, "--no-report-db" not in self.config)
+        self.requests_mock = RequestsMock(self)
+        self.file_mock = FileMock(self.log_files, "--quiet" in self.config, self)
+        self.global_variables_mock = GlobalVariablesMock()
+
+        self.database_mock.start()
+        self.requests_mock.start()
+        self.file_mock.start()
+        self.global_variables_mock.start()
+        TimeMock.travel(self.start_date)
+
+    def base_test(self):
+        main.main(self.config)
+        self.requests_mock.check_calls()
+        self.database_mock.check_calls()
+        self.file_mock.check_calls()
+
+    def tearDown(self):
+        TimeMock.stop()
+        self.global_variables_mock.stop()
+        self.file_mock.stop()
+        self.requests_mock.stop()
+        self.database_mock.stop()
+