]> git.immae.eu Git - perso/Immae/Projets/Cryptomonnaies/Cryptoportfolio/Trader.git/blobdiff - test_acceptance.py
Add acceptance tests
[perso/Immae/Projets/Cryptomonnaies/Cryptoportfolio/Trader.git] / test_acceptance.py
diff --git a/test_acceptance.py b/test_acceptance.py
new file mode 100644 (file)
index 0000000..88a2dd4
--- /dev/null
@@ -0,0 +1,314 @@
+import requests
+import requests_mock
+import sys, os
+import time, datetime
+import unittest
+from unittest import mock
+from ssl import SSLError
+from decimal import Decimal
+import simplejson as json
+import psycopg2
+from io import StringIO
+
+class FileMock:
+    @classmethod
+    def start(cls):
+        cls.file_mock = mock.patch("market.open")
+        cls.os_mock = mock.patch("os.makedirs")
+        cls.stdout_mock = mock.patch('sys.stdout', new_callable=StringIO)
+        cls.stdout_mock.start()
+        cls.os_mock.start()
+        cls.file_mock.start()
+
+    @classmethod
+    def check_calls(cls, tester):
+        pass
+        #raise NotImplementedError("Todo")
+
+    @classmethod
+    def stop(cls):
+        cls.file_mock.stop()
+        cls.stdout_mock.stop()
+        cls.os_mock.stop()
+
+class DatabaseMock:
+    rows = []
+    report_db = False
+    db_patch = None
+    cursor = None
+    total_report_lines = 0
+    db_mock = None
+    requests = []
+
+    @classmethod
+    def start(cls, reports, report_db):
+        cls.report_db = report_db
+        cls.rows = []
+        cls.total_report_lines=  0
+        for user_id, market_id, http_requests, report_lines in reports.values():
+            cls.rows.append( (market_id, { "key": "key", "secret": "secret" }, user_id) )
+            cls.total_report_lines += len(report_lines)
+
+        connect_mock = mock.Mock()
+        cls.cursor = mock.MagicMock()
+        connect_mock.cursor.return_value = cls.cursor
+        def _execute(request, *args):
+            cls.requests.append(request)
+        cls.cursor.execute.side_effect = _execute
+        cls.cursor.__iter__.return_value = cls.rows
+
+        cls.db_patch = mock.patch("psycopg2.connect")
+        cls.db_mock = cls.db_patch.start()
+        cls.db_mock.return_value = connect_mock
+
+    @classmethod
+    def check_calls(cls, tester):
+        if cls.report_db:
+            tester.assertEqual(1 + len(cls.rows), cls.db_mock.call_count)
+            tester.assertEqual(1 + len(cls.rows) + cls.total_report_lines, cls.cursor.execute.call_count)
+        else:
+            tester.assertEqual(1, cls.db_mock.call_count)
+            tester.assertEqual(1, cls.cursor.execute.call_count)
+
+    @classmethod
+    def stop(cls):
+        cls.db_patch.stop()
+        cls.db_mock = None
+        cls.cursor = None
+        cls.total_report_lines = 0
+        cls.rows = []
+        cls.report_db = False
+        cls.requests = []
+
+class RequestsMock:
+    adapter = None
+    mocks = {}
+    last_https = {}
+    request_patch = []
+
+    @classmethod
+    def start(cls, reports):
+        cls.adapter = requests_mock.Adapter()
+        cls.adapter.register_uri(requests_mock.ANY, requests_mock.ANY,
+                exc=requests_mock.exceptions.MockException("Not stubbed URL"))
+        true_session = requests.Session
+
+        cls.mocks = {}
+
+        for market_id, elements in reports.items():
+            for element in elements:
+                method = element["method"]
+                url = element["url"]
+                cls.mocks \
+                        .setdefault((method, url), {}) \
+                        .setdefault(market_id, []) \
+                        .append(element)
+
+        for ((method, url), elements) in cls.mocks.items():
+            cls.adapter.register_uri(method, url, text=cls.callback_func(elements), complete_qs=True)
+        def _session():
+            session = true_session()
+            session.get_adapter = lambda url: cls.adapter
+            return session
+        cls.request_patch = [
+                mock.patch.object(requests.sessions, "Session", new=_session),
+                mock.patch.object(requests, "Session", new=_session)
+                ]
+        for patch in cls.request_patch:
+            patch.start()
+
+    @classmethod
+    def stop(cls):
+        for patch in cls.request_patch:
+            patch.stop()
+        cls.request_patch = []
+        cls.last_https = {}
+        cls.mocks = {}
+        cls.adapter = None
+
+    @classmethod
+    def check_calls(cls, tester):
+        for (method, url), elements in cls.mocks.items():
+            for market_id, element in elements.items():
+                tester.assertEqual(0, len(element), "Missing calls to {} {}, market_id {}".format(method, url, market_id))
+
+    @classmethod
+    def clean_body(cls, body):
+        if body is None:
+            return None
+        import re
+        if isinstance(body, bytes):
+            body = body.decode()
+        body = re.sub(r"&nonce=\d*$", "", body)
+        body = re.sub(r"nonce=\d*&?", "", body)
+        return body
+
+    @classmethod
+    def callback_func(cls, elements):
+        def callback(request, context):
+            try:
+                element = elements[request.headers.get("X-market-id")].pop(0)
+            except (IndexError, KeyError):
+                raise RuntimeError("Unexpected call")
+            if element["response"] is None and element["response_same_as"] is not None:
+                element["response"] = cls.last_https[element["response_same_as"]]
+            elif element["response"] is not None:
+                cls.last_https[element["date"]] = element["response"]
+
+            assert cls.clean_body(request.body) == \
+                    cls.clean_body(element["body"]), "Body does not match"
+            context.status_code = element["status"]
+            if "error" in element:
+                if element["error"] == "SSLError":
+                    raise SSLError(element["error_message"])
+                else:
+                    raise getattr(requests.exceptions, element["error"])(element["error_message"])
+            return element["response"]
+        return callback
+
+class TimeMock:
+    delta = 0
+    true_time = time.time
+    time_patch = None
+    datetime_patch = None
+
+    @classmethod
+    def start(cls, start_date):
+        cls.delta = (datetime.datetime.now() - start_date).total_seconds()
+
+        class fake_datetime(datetime.datetime):
+            @classmethod
+            def now(cls, tz=None):
+                if tz is None:
+                    return cls.fromtimestamp(time.time())
+                else:
+                    return tz.fromutc(cls.utcfromtimestamp(time.time()).replace(tzinfo=tz))
+
+        cls.time_patch = mock.patch.multiple(time, time=cls.fake_time, sleep=cls.fake_sleep)
+        cls.datetime_patch = mock.patch.multiple(datetime, datetime=fake_datetime)
+        cls.time_patch.start()
+        cls.datetime_patch.start()
+
+    @classmethod
+    def stop(cls):
+        cls.delta = 0
+        cls.datetime_patch.stop()
+        cls.time_patch.stop()
+
+    @classmethod
+    def fake_time(cls):
+        return cls.true_time() - cls.delta
+
+    @classmethod
+    def fake_sleep(cls, duration):
+        cls.delta -= duration
+
+class AcceptanceTestCase():
+    def parse_file(self, report_file):
+        with open(report_file, "rb") as f:
+            json_content = json.load(f, parse_float=Decimal)
+        config, user, date, market_id = self.parse_config(json_content)
+        http_requests = self.parse_requests(json_content)
+
+        return config, user, date, market_id, http_requests, json_content
+
+    def parse_requests(self, json_content):
+        http_requests = []
+        for element in json_content:
+            if element["type"] != "http_request":
+                continue
+            http_requests.append(element)
+        return http_requests
+
+    def parse_config(self, json_content):
+        market_info = None
+        for element in json_content:
+            if element["type"] != "market":
+                continue
+            market_info = element
+        assert market_info is not None, "Couldn't find market element"
+
+        args = market_info["args"]
+        config = []
+        for arg in ["before", "after", "quiet", "debug"]:
+            if args.get(arg, False):
+                config.append("--{}".format(arg))
+        for arg in ["parallel", "report_db"]:
+            if not args.get(arg, False):
+                config.append("--no-{}".format(arg.replace("_", "-")))
+        for action in args.get("action", []):
+            config.extend(["--action", action])
+        if args.get("report_path") is not None:
+            config.extend(["--report-path", args.get("report_path")])
+        if args.get("user") is not None:
+            config.extend(["--user", args.get("user")])
+        config.extend(["--config", ""])
+
+        user = market_info["user_id"]
+        date = datetime.datetime.strptime(market_info["date"], "%Y-%m-%dT%H:%M:%S.%f")
+        market_id = market_info["market_id"]
+        return config, user, date, market_id
+
+    def requests_by_market(self):
+        r = {
+                None: []
+                }
+        got_common = False
+        for user_id, market_id, http_requests, report_lines in self.reports.values():
+            r[str(market_id)] = []
+            for http_request in http_requests:
+                if http_request["market_id"] is None:
+                    if not got_common:
+                        r[None].append(http_request)
+                else:
+                    r[str(market_id)].append(http_request)
+            got_common = True
+        return r
+
+    def setUp(self):
+        if not hasattr(self, "files"):
+            raise "This class expects to be inherited with a class defining self.files in setUp"
+
+        self.reports = {}
+        self.start_date = datetime.datetime.now()
+        self.config = []
+        for f in self.files:
+            self.config, user_id, date, market_id, http_requests, report_lines = self.parse_file(f)
+            if date < self.start_date:
+                self.start_date = date
+            self.reports[f] = [user_id, market_id, http_requests, report_lines]
+
+        DatabaseMock.start(self.reports, "--no-report-db" not in self.config)
+        RequestsMock.start(self.requests_by_market())
+        FileMock.start()
+        TimeMock.start(self.start_date)
+
+    def base_test(self):
+        import main
+        main.main(self.config)
+        RequestsMock.check_calls(self)
+        DatabaseMock.check_calls(self)
+        FileMock.check_calls(self)
+
+    def tearDown(self):
+        TimeMock.stop()
+        FileMock.stop()
+        RequestsMock.stop()
+        DatabaseMock.stop()
+
+import glob
+for dirfile in glob.glob("tests/acceptance/**/*/", recursive=True):
+    json_files = glob.glob("{}/*.json".format(dirfile))
+    if len(json_files) > 0:
+        name = dirfile.replace("tests/acceptance/", "").replace("/", "_")[0:-1]
+        cname = "".join(list(map(lambda x: x.capitalize(), name.split("_"))))
+
+        globals()[cname] = type(cname,
+                (AcceptanceTestCase,unittest.TestCase), {
+            "files": json_files,
+            "test_{}".format(name): AcceptanceTestCase.base_test
+            })
+
+if __name__ == '__main__':
+    unittest.main()
+