+++ /dev/null
-import requests
-import requests_mock
-import sys, os
-import time, datetime
-import unittest
-from unittest import mock
-from ssl import SSLError
-from decimal import Decimal
-import simplejson as json
-import psycopg2
-from io import StringIO
-import re
-import functools
-import glob
-
-import main
-
-class FileMock:
- def __init__(self, log_files, quiet, tester):
- self.tester = tester
- self.log_files = []
- if log_files is not None and len(log_files) > 0:
- self.read_log_files(log_files)
- self.quiet = quiet
- self.patches = [
- mock.patch("market.open"),
- mock.patch("os.makedirs"),
- mock.patch("sys.stdout", new_callable=StringIO),
- ]
- self.mocks = []
-
- def start(self):
- for patch in self.patches:
- self.mocks.append(patch.start())
- self.stdout = self.mocks[-1]
-
- def check_calls(self):
- stdout = self.stdout.getvalue()
- if self.quiet:
- self.tester.assertEqual("", stdout)
- else:
- log = self.strip_log(stdout)
- if len(self.log_files) != 0:
- split_logs = log.split("\n")
- self.tester.assertEqual(sum(len(f) for f in self.log_files), len(split_logs))
- try:
- for log_file in self.log_files:
- for line in log_file:
- split_logs.pop(split_logs.index(line))
- except ValueError:
- if not line.startswith("[Worker] "):
- self.tester.fail("« {} » not found in log file {}".format(line, split_logs))
- # Le fichier de log est écrit
- # Le fichier de log est printed uniquement si non quiet
- # Le rapport est écrit si pertinent
- # Le rapport contient le bon nombre de lignes
-
- def stop(self):
- for patch in self.patches[::-1]:
- patch.stop()
- self.mocks.pop()
-
- def strip_log(self, log):
- log = log.replace("\n\n", "\n")
- return re.sub(r"^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}: ", "", log, flags=re.MULTILINE)
-
- def read_log_files(self, log_files):
- for log_file in log_files:
- with open(log_file, "r") as f:
- log = self.strip_log(f.read()).split("\n")
- if len(log[-1]) == 0:
- log.pop()
- self.log_files.append(log)
-
-class DatabaseMock:
- def __init__(self, tester, reports, report_db):
- self.tester = tester
- self.reports = reports
- self.report_db = report_db
- self.rows = []
- self.total_report_lines= 0
- self.requests = []
- for user_id, market_id, http_requests, report_lines in self.reports.values():
- self.rows.append( (market_id, { "key": "key", "secret": "secret" }, user_id) )
- self.total_report_lines += len(report_lines)
-
- def start(self):
- connect_mock = mock.Mock()
- self.cursor = mock.MagicMock()
- connect_mock.cursor.return_value = self.cursor
- def _execute(request, *args):
- self.requests.append(request)
- self.cursor.execute.side_effect = _execute
- self.cursor.__iter__.return_value = self.rows
-
- self.db_patch = mock.patch("psycopg2.connect")
- self.db_mock = self.db_patch.start()
- self.db_mock.return_value = connect_mock
-
- def check_calls(self):
- if self.report_db:
- self.tester.assertEqual(1 + len(self.rows), self.db_mock.call_count)
- self.tester.assertEqual(1 + len(self.rows) + self.total_report_lines, self.cursor.execute.call_count)
- else:
- self.tester.assertEqual(1, self.db_mock.call_count)
- self.tester.assertEqual(1, self.cursor.execute.call_count)
-
- def stop(self):
- self.db_patch.stop()
-
-class RequestsMock:
- def __init__(self, tester):
- self.tester = tester
- self.reports = tester.requests_by_market()
-
- self.last_https = {}
- self.error_calls = []
- self.mocker = requests_mock.Mocker()
- def not_stubbed(*args):
- self.error_calls.append([args[0].method, args[0].url])
- raise requests_mock.exceptions.MockException("Not stubbed URL")
- self.mocker.register_uri(requests_mock.ANY, requests_mock.ANY,
- text=not_stubbed)
-
- self.mocks = {}
-
- for market_id, elements in self.reports.items():
- for element in elements:
- method = element["method"]
- url = element["url"]
- self.mocks \
- .setdefault((method, url), {}) \
- .setdefault(market_id, []) \
- .append(element)
-
- for ((method, url), elements) in self.mocks.items():
- self.mocker.register_uri(method, url, text=functools.partial(callback, self, elements), complete_qs=True)
-
- def start(self):
- self.mocker.start()
-
- def stop(self):
- self.mocker.stop()
-
- def check_calls(self):
- self.tester.assertEqual([], self.error_calls)
- for (method, url), elements in self.mocks.items():
- for market_id, element in elements.items():
- self.tester.assertEqual(0, len(element), "Missing calls to {} {}, market_id {}".format(method, url, market_id))
-
- def clean_body(self, body):
- if body is None:
- return None
- if isinstance(body, bytes):
- body = body.decode()
- body = re.sub(r"&nonce=\d*$", "", body)
- body = re.sub(r"nonce=\d*&?", "", body)
- return body
-
-def callback(self, elements, request, context):
- try:
- element = elements[request.headers.get("X-market-id")].pop(0)
- except (IndexError, KeyError):
- self.error_calls.append([request.method, request.url, request.headers.get("X-market-id")])
- raise RuntimeError("Unexpected call")
- if element["response"] is None and element["response_same_as"] is not None:
- element["response"] = self.last_https[element["response_same_as"]]
- elif element["response"] is not None:
- self.last_https[element["date"]] = element["response"]
-
- assert self.clean_body(request.body) == \
- self.clean_body(element["body"]), "Body does not match"
- context.status_code = element["status"]
- if "error" in element:
- if element["error"] == "SSLError":
- raise SSLError(element["error_message"])
- else:
- raise getattr(requests.exceptions, element["error"])(element["error_message"])
- return element["response"]
-
-class GlobalVariablesMock:
- def start(self):
- import market
- import store
-
- self.patchers = [
- mock.patch.multiple(market.Portfolio,
- data=store.LockedVar(None),
- liquidities=store.LockedVar({}),
- last_date=store.LockedVar(None),
- report=store.LockedVar(store.ReportStore(None, no_http_dup=True)),
- worker=None,
- worker_tag="",
- worker_notify=None,
- worker_started=False,
- callback=None)
- ]
- for patcher in self.patchers:
- patcher.start()
-
- def stop(self):
- pass
-
-
-class TimeMock:
- delta = 0
- true_time = time.time
- true_sleep = time.sleep
- time_patch = None
- datetime_patch = None
-
- @classmethod
- def start(cls, start_date):
- cls.delta = (datetime.datetime.now() - start_date).total_seconds()
-
- class fake_datetime(datetime.datetime):
- @classmethod
- def now(cls, tz=None):
- if tz is None:
- return cls.fromtimestamp(time.time())
- else:
- return tz.fromutc(cls.utcfromtimestamp(time.time()).replace(tzinfo=tz))
-
- cls.time_patch = mock.patch.multiple(time, time=cls.fake_time, sleep=cls.fake_sleep)
- cls.datetime_patch = mock.patch.multiple(datetime, datetime=fake_datetime)
- cls.time_patch.start()
- cls.datetime_patch.start()
-
- @classmethod
- def stop(cls):
- cls.delta = 0
- cls.datetime_patch.stop()
- cls.time_patch.stop()
-
- @classmethod
- def fake_time(cls):
- return cls.true_time() - cls.delta
-
- @classmethod
- def fake_sleep(cls, duration):
- cls.delta -= duration
- cls.true_sleep(0.2)
-
-class AcceptanceTestCase():
- def parse_file(self, report_file):
- with open(report_file, "rb") as f:
- json_content = json.load(f, parse_float=Decimal)
- config, user, date, market_id = self.parse_config(json_content)
- http_requests = self.parse_requests(json_content)
-
- return config, user, date, market_id, http_requests, json_content
-
- def parse_requests(self, json_content):
- http_requests = []
- for element in json_content:
- if element["type"] != "http_request":
- continue
- http_requests.append(element)
- return http_requests
-
- def parse_config(self, json_content):
- market_info = None
- for element in json_content:
- if element["type"] != "market":
- continue
- market_info = element
- assert market_info is not None, "Couldn't find market element"
-
- args = market_info["args"]
- config = []
- for arg in ["before", "after", "quiet", "debug"]:
- if args.get(arg, False):
- config.append("--{}".format(arg))
- for arg in ["parallel", "report_db"]:
- if not args.get(arg, False):
- config.append("--no-{}".format(arg.replace("_", "-")))
- for action in (args.get("action", []) or []):
- config.extend(["--action", action])
- if args.get("report_path") is not None:
- config.extend(["--report-path", args.get("report_path")])
- if args.get("user") is not None:
- config.extend(["--user", args.get("user")])
- config.extend(["--config", ""])
-
- user = market_info["user_id"]
- date = datetime.datetime.strptime(market_info["date"], "%Y-%m-%dT%H:%M:%S.%f")
- market_id = market_info["market_id"]
- return config, user, date, market_id
-
- def requests_by_market(self):
- r = {
- None: []
- }
- got_common = False
- for user_id, market_id, http_requests, report_lines in self.reports.values():
- r[str(market_id)] = []
- for http_request in http_requests:
- if http_request["market_id"] is None:
- if not got_common:
- r[None].append(http_request)
- else:
- r[str(market_id)].append(http_request)
- got_common = True
- return r
-
- def setUp(self):
- if not hasattr(self, "files"):
- raise "This class expects to be inherited with a class defining 'files' variable"
- if not hasattr(self, "log_files"):
- self.log_files = []
-
- self.reports = {}
- self.start_date = datetime.datetime.now()
- self.config = []
- for f in self.files:
- self.config, user_id, date, market_id, http_requests, report_lines = self.parse_file(f)
- if date < self.start_date:
- self.start_date = date
- self.reports[f] = [user_id, market_id, http_requests, report_lines]
-
- self.database_mock = DatabaseMock(self, self.reports, "--no-report-db" not in self.config)
- self.requests_mock = RequestsMock(self)
- self.file_mock = FileMock(self.log_files, "--quiet" in self.config, self)
- self.global_variables_mock = GlobalVariablesMock()
-
- self.database_mock.start()
- self.requests_mock.start()
- self.file_mock.start()
- self.global_variables_mock.start()
- TimeMock.start(self.start_date)
-
- def base_test(self):
- main.main(self.config)
- self.requests_mock.check_calls()
- self.database_mock.check_calls()
- self.file_mock.check_calls()
-
- def tearDown(self):
- TimeMock.stop()
- self.global_variables_mock.stop()
- self.file_mock.stop()
- self.requests_mock.stop()
- self.database_mock.stop()
-
-for dirfile in glob.glob("tests/acceptance/**/*/", recursive=True):
- json_files = glob.glob("{}/*.json".format(dirfile))
- log_files = glob.glob("{}/*.log".format(dirfile))
- if len(json_files) > 0:
- name = dirfile.replace("tests/acceptance/", "").replace("/", "_")[0:-1]
- cname = "".join(list(map(lambda x: x.capitalize(), name.split("_"))))
-
- globals()[cname] = type(cname,
- (AcceptanceTestCase,unittest.TestCase), {
- "log_files": log_files,
- "files": json_files,
- "test_{}".format(name): AcceptanceTestCase.base_test
- })
-
-if __name__ == '__main__':
- unittest.main()
-