]> git.immae.eu Git - perso/Immae/Projets/Cryptomonnaies/Cryptoportfolio/Trader.git/blobdiff - test_acceptance.py
Add some acceptance tests
[perso/Immae/Projets/Cryptomonnaies/Cryptoportfolio/Trader.git] / test_acceptance.py
index 88a2dd4542b064c543efc813efa2d1d6b40b85dd..3633928c455b5cec8e8cbbeae35b68a44ebd048d 100644 (file)
@@ -9,166 +9,203 @@ from decimal import Decimal
 import simplejson as json
 import psycopg2
 from io import StringIO
+import re
+import functools
+import glob
+
+import main
 
 class FileMock:
-    @classmethod
-    def start(cls):
-        cls.file_mock = mock.patch("market.open")
-        cls.os_mock = mock.patch("os.makedirs")
-        cls.stdout_mock = mock.patch('sys.stdout', new_callable=StringIO)
-        cls.stdout_mock.start()
-        cls.os_mock.start()
-        cls.file_mock.start()
+    def __init__(self, log_files, quiet, tester):
+        self.tester = tester
+        self.log_files = []
+        if log_files is not None and len(log_files) > 0:
+            self.read_log_files(log_files)
+        self.quiet = quiet
+        self.patches = [
+                mock.patch("market.open"),
+                mock.patch("os.makedirs"),
+                mock.patch("sys.stdout", new_callable=StringIO),
+                ]
+        self.mocks = []
 
-    @classmethod
-    def check_calls(cls, tester):
-        pass
-        #raise NotImplementedError("Todo")
+    def start(self):
+        for patch in self.patches:
+            self.mocks.append(patch.start())
+        self.stdout = self.mocks[-1]
 
-    @classmethod
-    def stop(cls):
-        cls.file_mock.stop()
-        cls.stdout_mock.stop()
-        cls.os_mock.stop()
+    def check_calls(self):
+        stdout = self.stdout.getvalue()
+        if self.quiet:
+            self.tester.assertEqual("", stdout)
+        else:
+            log = self.strip_log(stdout)
+            if len(self.log_files) != 0:
+                split_logs = log.split("\n")
+                self.tester.assertEqual(sum(len(f) for f in self.log_files), len(split_logs))
+                try:
+                    for log_file in self.log_files:
+                        for line in log_file:
+                            split_logs.pop(split_logs.index(line))
+                except ValueError:
+                    if not line.startswith("[Worker] "):
+                        self.tester.fail("« {} » not found in log file {}".format(line, split_logs))
+    # Le fichier de log est écrit
+    # Le fichier de log est printed uniquement si non quiet
+    # Le rapport est écrit si pertinent
+    # Le rapport contient le bon nombre de lignes
+
+    def stop(self):
+        for patch in self.patches[::-1]:
+            patch.stop()
+            self.mocks.pop()
 
-class DatabaseMock:
-    rows = []
-    report_db = False
-    db_patch = None
-    cursor = None
-    total_report_lines = 0
-    db_mock = None
-    requests = []
+    def strip_log(self, log):
+        log = log.replace("\n\n", "\n")
+        return re.sub(r"^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}: ", "", log, flags=re.MULTILINE)
 
-    @classmethod
-    def start(cls, reports, report_db):
-        cls.report_db = report_db
-        cls.rows = []
-        cls.total_report_lines=  0
-        for user_id, market_id, http_requests, report_lines in reports.values():
-            cls.rows.append( (market_id, { "key": "key", "secret": "secret" }, user_id) )
-            cls.total_report_lines += len(report_lines)
+    def read_log_files(self, log_files):
+        for log_file in log_files:
+            with open(log_file, "r") as f:
+                log = self.strip_log(f.read()).split("\n")
+                if len(log[-1]) == 0:
+                    log.pop()
+                self.log_files.append(log)
+
+class DatabaseMock:
+    def __init__(self, tester, reports, report_db):
+        self.tester = tester
+        self.reports = reports
+        self.report_db = report_db
+        self.rows = []
+        self.total_report_lines=  0
+        self.requests = []
+        for user_id, market_id, http_requests, report_lines in self.reports.values():
+            self.rows.append( (market_id, { "key": "key", "secret": "secret" }, user_id) )
+            self.total_report_lines += len(report_lines)
 
+    def start(self):
         connect_mock = mock.Mock()
-        cls.cursor = mock.MagicMock()
-        connect_mock.cursor.return_value = cls.cursor
+        self.cursor = mock.MagicMock()
+        connect_mock.cursor.return_value = self.cursor
         def _execute(request, *args):
-            cls.requests.append(request)
-        cls.cursor.execute.side_effect = _execute
-        cls.cursor.__iter__.return_value = cls.rows
-
-        cls.db_patch = mock.patch("psycopg2.connect")
-        cls.db_mock = cls.db_patch.start()
-        cls.db_mock.return_value = connect_mock
-
-    @classmethod
-    def check_calls(cls, tester):
-        if cls.report_db:
-            tester.assertEqual(1 + len(cls.rows), cls.db_mock.call_count)
-            tester.assertEqual(1 + len(cls.rows) + cls.total_report_lines, cls.cursor.execute.call_count)
+            self.requests.append(request)
+        self.cursor.execute.side_effect = _execute
+        self.cursor.__iter__.return_value = self.rows
+
+        self.db_patch = mock.patch("psycopg2.connect")
+        self.db_mock = self.db_patch.start()
+        self.db_mock.return_value = connect_mock
+
+    def check_calls(self):
+        if self.report_db:
+            self.tester.assertEqual(1 + len(self.rows), self.db_mock.call_count)
+            self.tester.assertEqual(1 + len(self.rows) + self.total_report_lines, self.cursor.execute.call_count)
         else:
-            tester.assertEqual(1, cls.db_mock.call_count)
-            tester.assertEqual(1, cls.cursor.execute.call_count)
+            self.tester.assertEqual(1, self.db_mock.call_count)
+            self.tester.assertEqual(1, self.cursor.execute.call_count)
 
-    @classmethod
-    def stop(cls):
-        cls.db_patch.stop()
-        cls.db_mock = None
-        cls.cursor = None
-        cls.total_report_lines = 0
-        cls.rows = []
-        cls.report_db = False
-        cls.requests = []
+    def stop(self):
+        self.db_patch.stop()
 
 class RequestsMock:
-    adapter = None
-    mocks = {}
-    last_https = {}
-    request_patch = []
-
-    @classmethod
-    def start(cls, reports):
-        cls.adapter = requests_mock.Adapter()
-        cls.adapter.register_uri(requests_mock.ANY, requests_mock.ANY,
-                exc=requests_mock.exceptions.MockException("Not stubbed URL"))
-        true_session = requests.Session
-
-        cls.mocks = {}
-
-        for market_id, elements in reports.items():
+    def __init__(self, tester):
+        self.tester = tester
+        self.reports = tester.requests_by_market()
+
+        self.last_https = {}
+        self.error_calls = []
+        self.mocker = requests_mock.Mocker()
+        def not_stubbed(*args):
+            self.error_calls.append([args[0].method, args[0].url])
+            raise requests_mock.exceptions.MockException("Not stubbed URL")
+        self.mocker.register_uri(requests_mock.ANY, requests_mock.ANY,
+                text=not_stubbed)
+
+        self.mocks = {}
+
+        for market_id, elements in self.reports.items():
             for element in elements:
                 method = element["method"]
                 url = element["url"]
-                cls.mocks \
+                self.mocks \
                         .setdefault((method, url), {}) \
                         .setdefault(market_id, []) \
                         .append(element)
 
-        for ((method, url), elements) in cls.mocks.items():
-            cls.adapter.register_uri(method, url, text=cls.callback_func(elements), complete_qs=True)
-        def _session():
-            session = true_session()
-            session.get_adapter = lambda url: cls.adapter
-            return session
-        cls.request_patch = [
-                mock.patch.object(requests.sessions, "Session", new=_session),
-                mock.patch.object(requests, "Session", new=_session)
-                ]
-        for patch in cls.request_patch:
-            patch.start()
+        for ((method, url), elements) in self.mocks.items():
+            self.mocker.register_uri(method, url, text=functools.partial(callback, self, elements), complete_qs=True)
 
-    @classmethod
-    def stop(cls):
-        for patch in cls.request_patch:
-            patch.stop()
-        cls.request_patch = []
-        cls.last_https = {}
-        cls.mocks = {}
-        cls.adapter = None
+    def start(self):
+        self.mocker.start()
 
-    @classmethod
-    def check_calls(cls, tester):
-        for (method, url), elements in cls.mocks.items():
+    def stop(self):
+        self.mocker.stop()
+
+    def check_calls(self):
+        self.tester.assertEqual([], self.error_calls)
+        for (method, url), elements in self.mocks.items():
             for market_id, element in elements.items():
-                tester.assertEqual(0, len(element), "Missing calls to {} {}, market_id {}".format(method, url, market_id))
+                self.tester.assertEqual(0, len(element), "Missing calls to {} {}, market_id {}".format(method, url, market_id))
 
-    @classmethod
-    def clean_body(cls, body):
+    def clean_body(self, body):
         if body is None:
             return None
-        import re
         if isinstance(body, bytes):
             body = body.decode()
         body = re.sub(r"&nonce=\d*$", "", body)
         body = re.sub(r"nonce=\d*&?", "", body)
         return body
 
-    @classmethod
-    def callback_func(cls, elements):
-        def callback(request, context):
-            try:
-                element = elements[request.headers.get("X-market-id")].pop(0)
-            except (IndexError, KeyError):
-                raise RuntimeError("Unexpected call")
-            if element["response"] is None and element["response_same_as"] is not None:
-                element["response"] = cls.last_https[element["response_same_as"]]
-            elif element["response"] is not None:
-                cls.last_https[element["date"]] = element["response"]
-
-            assert cls.clean_body(request.body) == \
-                    cls.clean_body(element["body"]), "Body does not match"
-            context.status_code = element["status"]
-            if "error" in element:
-                if element["error"] == "SSLError":
-                    raise SSLError(element["error_message"])
-                else:
-                    raise getattr(requests.exceptions, element["error"])(element["error_message"])
-            return element["response"]
-        return callback
+def callback(self, elements, request, context):
+    try:
+        element = elements[request.headers.get("X-market-id")].pop(0)
+    except (IndexError, KeyError):
+        self.error_calls.append([request.method, request.url, request.headers.get("X-market-id")])
+        raise RuntimeError("Unexpected call")
+    if element["response"] is None and element["response_same_as"] is not None:
+        element["response"] = self.last_https[element["response_same_as"]]
+    elif element["response"] is not None:
+        self.last_https[element["date"]] = element["response"]
+
+    assert self.clean_body(request.body) == \
+            self.clean_body(element["body"]), "Body does not match"
+    context.status_code = element["status"]
+    if "error" in element:
+        if element["error"] == "SSLError":
+            raise SSLError(element["error_message"])
+        else:
+            raise getattr(requests.exceptions, element["error"])(element["error_message"])
+    return element["response"]
+
+class GlobalVariablesMock:
+    def start(self):
+        import market
+        import store
+
+        self.patchers = [
+                mock.patch.multiple(market.Portfolio,
+                    data=store.LockedVar(None),
+                    liquidities=store.LockedVar({}),
+                    last_date=store.LockedVar(None),
+                    report=store.LockedVar(store.ReportStore(None, no_http_dup=True)),
+                    worker=None,
+                    worker_tag="",
+                    worker_notify=None,
+                    worker_started=False,
+                    callback=None)
+                ]
+        for patcher in self.patchers:
+            patcher.start()
+
+    def stop(self):
+        pass
+
 
 class TimeMock:
     delta = 0
     true_time = time.time
+    true_sleep = time.sleep
     time_patch = None
     datetime_patch = None
 
@@ -202,6 +239,7 @@ class TimeMock:
     @classmethod
     def fake_sleep(cls, duration):
         cls.delta -= duration
+        cls.true_sleep(0.2)
 
 class AcceptanceTestCase():
     def parse_file(self, report_file):
@@ -236,7 +274,7 @@ class AcceptanceTestCase():
         for arg in ["parallel", "report_db"]:
             if not args.get(arg, False):
                 config.append("--no-{}".format(arg.replace("_", "-")))
-        for action in args.get("action", []):
+        for action in (args.get("action", []) or []):
             config.extend(["--action", action])
         if args.get("report_path") is not None:
             config.extend(["--report-path", args.get("report_path")])
@@ -267,7 +305,9 @@ class AcceptanceTestCase():
 
     def setUp(self):
         if not hasattr(self, "files"):
-            raise "This class expects to be inherited with a class defining self.files in setUp"
+            raise "This class expects to be inherited with a class defining 'files' variable"
+        if not hasattr(self, "log_files"):
+            self.log_files = []
 
         self.reports = {}
         self.start_date = datetime.datetime.now()
@@ -278,33 +318,40 @@ class AcceptanceTestCase():
                 self.start_date = date
             self.reports[f] = [user_id, market_id, http_requests, report_lines]
 
-        DatabaseMock.start(self.reports, "--no-report-db" not in self.config)
-        RequestsMock.start(self.requests_by_market())
-        FileMock.start()
+        self.database_mock = DatabaseMock(self, self.reports, "--no-report-db" not in self.config)
+        self.requests_mock = RequestsMock(self)
+        self.file_mock = FileMock(self.log_files, "--quiet" in self.config, self)
+        self.global_variables_mock = GlobalVariablesMock()
+
+        self.database_mock.start()
+        self.requests_mock.start()
+        self.file_mock.start()
+        self.global_variables_mock.start()
         TimeMock.start(self.start_date)
 
     def base_test(self):
-        import main
         main.main(self.config)
-        RequestsMock.check_calls(self)
-        DatabaseMock.check_calls(self)
-        FileMock.check_calls(self)
+        self.requests_mock.check_calls()
+        self.database_mock.check_calls()
+        self.file_mock.check_calls()
 
     def tearDown(self):
         TimeMock.stop()
-        FileMock.stop()
-        RequestsMock.stop()
-        DatabaseMock.stop()
+        self.global_variables_mock.stop()
+        self.file_mock.stop()
+        self.requests_mock.stop()
+        self.database_mock.stop()
 
-import glob
 for dirfile in glob.glob("tests/acceptance/**/*/", recursive=True):
     json_files = glob.glob("{}/*.json".format(dirfile))
+    log_files = glob.glob("{}/*.log".format(dirfile))
     if len(json_files) > 0:
         name = dirfile.replace("tests/acceptance/", "").replace("/", "_")[0:-1]
         cname = "".join(list(map(lambda x: x.capitalize(), name.split("_"))))
 
         globals()[cname] = type(cname,
                 (AcceptanceTestCase,unittest.TestCase), {
+            "log_files": log_files,
             "files": json_files,
             "test_{}".format(name): AcceptanceTestCase.base_test
             })