import simplejson as json
import psycopg2
from io import StringIO
+import re
+import functools
+import glob
+
+import main
class FileMock:
- @classmethod
- def start(cls):
- cls.file_mock = mock.patch("market.open")
- cls.os_mock = mock.patch("os.makedirs")
- cls.stdout_mock = mock.patch('sys.stdout', new_callable=StringIO)
- cls.stdout_mock.start()
- cls.os_mock.start()
- cls.file_mock.start()
+ def __init__(self, log_files, quiet, tester):
+ self.tester = tester
+ self.log_files = []
+ if log_files is not None and len(log_files) > 0:
+ self.read_log_files(log_files)
+ self.quiet = quiet
+ self.patches = [
+ mock.patch("market.open"),
+ mock.patch("os.makedirs"),
+ mock.patch("sys.stdout", new_callable=StringIO),
+ ]
+ self.mocks = []
- @classmethod
- def check_calls(cls, tester):
- pass
- #raise NotImplementedError("Todo")
+ def start(self):
+ for patch in self.patches:
+ self.mocks.append(patch.start())
+ self.stdout = self.mocks[-1]
- @classmethod
- def stop(cls):
- cls.file_mock.stop()
- cls.stdout_mock.stop()
- cls.os_mock.stop()
+ def check_calls(self):
+ stdout = self.stdout.getvalue()
+ if self.quiet:
+ self.tester.assertEqual("", stdout)
+ else:
+ log = self.strip_log(stdout)
+ if len(self.log_files) != 0:
+ split_logs = log.split("\n")
+ self.tester.assertEqual(sum(len(f) for f in self.log_files), len(split_logs))
+ try:
+ for log_file in self.log_files:
+ for line in log_file:
+ split_logs.pop(split_logs.index(line))
+ except ValueError:
+ if not line.startswith("[Worker] "):
+ self.tester.fail("« {} » not found in log file {}".format(line, split_logs))
+ # Le fichier de log est écrit
+ # Le fichier de log est printed uniquement si non quiet
+ # Le rapport est écrit si pertinent
+ # Le rapport contient le bon nombre de lignes
+
+ def stop(self):
+ for patch in self.patches[::-1]:
+ patch.stop()
+ self.mocks.pop()
-class DatabaseMock:
- rows = []
- report_db = False
- db_patch = None
- cursor = None
- total_report_lines = 0
- db_mock = None
- requests = []
+ def strip_log(self, log):
+ log = log.replace("\n\n", "\n")
+ return re.sub(r"^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}: ", "", log, flags=re.MULTILINE)
- @classmethod
- def start(cls, reports, report_db):
- cls.report_db = report_db
- cls.rows = []
- cls.total_report_lines= 0
- for user_id, market_id, http_requests, report_lines in reports.values():
- cls.rows.append( (market_id, { "key": "key", "secret": "secret" }, user_id) )
- cls.total_report_lines += len(report_lines)
+ def read_log_files(self, log_files):
+ for log_file in log_files:
+ with open(log_file, "r") as f:
+ log = self.strip_log(f.read()).split("\n")
+ if len(log[-1]) == 0:
+ log.pop()
+ self.log_files.append(log)
+
+class DatabaseMock:
+ def __init__(self, tester, reports, report_db):
+ self.tester = tester
+ self.reports = reports
+ self.report_db = report_db
+ self.rows = []
+ self.total_report_lines= 0
+ self.requests = []
+ for user_id, market_id, http_requests, report_lines in self.reports.values():
+ self.rows.append( (market_id, { "key": "key", "secret": "secret" }, user_id) )
+ self.total_report_lines += len(report_lines)
+ def start(self):
connect_mock = mock.Mock()
- cls.cursor = mock.MagicMock()
- connect_mock.cursor.return_value = cls.cursor
+ self.cursor = mock.MagicMock()
+ connect_mock.cursor.return_value = self.cursor
def _execute(request, *args):
- cls.requests.append(request)
- cls.cursor.execute.side_effect = _execute
- cls.cursor.__iter__.return_value = cls.rows
-
- cls.db_patch = mock.patch("psycopg2.connect")
- cls.db_mock = cls.db_patch.start()
- cls.db_mock.return_value = connect_mock
-
- @classmethod
- def check_calls(cls, tester):
- if cls.report_db:
- tester.assertEqual(1 + len(cls.rows), cls.db_mock.call_count)
- tester.assertEqual(1 + len(cls.rows) + cls.total_report_lines, cls.cursor.execute.call_count)
+ self.requests.append(request)
+ self.cursor.execute.side_effect = _execute
+ self.cursor.__iter__.return_value = self.rows
+
+ self.db_patch = mock.patch("psycopg2.connect")
+ self.db_mock = self.db_patch.start()
+ self.db_mock.return_value = connect_mock
+
+ def check_calls(self):
+ if self.report_db:
+ self.tester.assertEqual(1 + len(self.rows), self.db_mock.call_count)
+ self.tester.assertEqual(1 + len(self.rows) + self.total_report_lines, self.cursor.execute.call_count)
else:
- tester.assertEqual(1, cls.db_mock.call_count)
- tester.assertEqual(1, cls.cursor.execute.call_count)
+ self.tester.assertEqual(1, self.db_mock.call_count)
+ self.tester.assertEqual(1, self.cursor.execute.call_count)
- @classmethod
- def stop(cls):
- cls.db_patch.stop()
- cls.db_mock = None
- cls.cursor = None
- cls.total_report_lines = 0
- cls.rows = []
- cls.report_db = False
- cls.requests = []
+ def stop(self):
+ self.db_patch.stop()
class RequestsMock:
- adapter = None
- mocks = {}
- last_https = {}
- request_patch = []
-
- @classmethod
- def start(cls, reports):
- cls.adapter = requests_mock.Adapter()
- cls.adapter.register_uri(requests_mock.ANY, requests_mock.ANY,
- exc=requests_mock.exceptions.MockException("Not stubbed URL"))
- true_session = requests.Session
-
- cls.mocks = {}
-
- for market_id, elements in reports.items():
+ def __init__(self, tester):
+ self.tester = tester
+ self.reports = tester.requests_by_market()
+
+ self.last_https = {}
+ self.error_calls = []
+ self.mocker = requests_mock.Mocker()
+ def not_stubbed(*args):
+ self.error_calls.append([args[0].method, args[0].url])
+ raise requests_mock.exceptions.MockException("Not stubbed URL")
+ self.mocker.register_uri(requests_mock.ANY, requests_mock.ANY,
+ text=not_stubbed)
+
+ self.mocks = {}
+
+ for market_id, elements in self.reports.items():
for element in elements:
method = element["method"]
url = element["url"]
- cls.mocks \
+ self.mocks \
.setdefault((method, url), {}) \
.setdefault(market_id, []) \
.append(element)
- for ((method, url), elements) in cls.mocks.items():
- cls.adapter.register_uri(method, url, text=cls.callback_func(elements), complete_qs=True)
- def _session():
- session = true_session()
- session.get_adapter = lambda url: cls.adapter
- return session
- cls.request_patch = [
- mock.patch.object(requests.sessions, "Session", new=_session),
- mock.patch.object(requests, "Session", new=_session)
- ]
- for patch in cls.request_patch:
- patch.start()
+ for ((method, url), elements) in self.mocks.items():
+ self.mocker.register_uri(method, url, text=functools.partial(callback, self, elements), complete_qs=True)
- @classmethod
- def stop(cls):
- for patch in cls.request_patch:
- patch.stop()
- cls.request_patch = []
- cls.last_https = {}
- cls.mocks = {}
- cls.adapter = None
+ def start(self):
+ self.mocker.start()
- @classmethod
- def check_calls(cls, tester):
- for (method, url), elements in cls.mocks.items():
+ def stop(self):
+ self.mocker.stop()
+
+ def check_calls(self):
+ self.tester.assertEqual([], self.error_calls)
+ for (method, url), elements in self.mocks.items():
for market_id, element in elements.items():
- tester.assertEqual(0, len(element), "Missing calls to {} {}, market_id {}".format(method, url, market_id))
+ self.tester.assertEqual(0, len(element), "Missing calls to {} {}, market_id {}".format(method, url, market_id))
- @classmethod
- def clean_body(cls, body):
+ def clean_body(self, body):
if body is None:
return None
- import re
if isinstance(body, bytes):
body = body.decode()
body = re.sub(r"&nonce=\d*$", "", body)
body = re.sub(r"nonce=\d*&?", "", body)
return body
- @classmethod
- def callback_func(cls, elements):
- def callback(request, context):
- try:
- element = elements[request.headers.get("X-market-id")].pop(0)
- except (IndexError, KeyError):
- raise RuntimeError("Unexpected call")
- if element["response"] is None and element["response_same_as"] is not None:
- element["response"] = cls.last_https[element["response_same_as"]]
- elif element["response"] is not None:
- cls.last_https[element["date"]] = element["response"]
-
- assert cls.clean_body(request.body) == \
- cls.clean_body(element["body"]), "Body does not match"
- context.status_code = element["status"]
- if "error" in element:
- if element["error"] == "SSLError":
- raise SSLError(element["error_message"])
- else:
- raise getattr(requests.exceptions, element["error"])(element["error_message"])
- return element["response"]
- return callback
+def callback(self, elements, request, context):
+ try:
+ element = elements[request.headers.get("X-market-id")].pop(0)
+ except (IndexError, KeyError):
+ self.error_calls.append([request.method, request.url, request.headers.get("X-market-id")])
+ raise RuntimeError("Unexpected call")
+ if element["response"] is None and element["response_same_as"] is not None:
+ element["response"] = self.last_https[element["response_same_as"]]
+ elif element["response"] is not None:
+ self.last_https[element["date"]] = element["response"]
+
+ assert self.clean_body(request.body) == \
+ self.clean_body(element["body"]), "Body does not match"
+ context.status_code = element["status"]
+ if "error" in element:
+ if element["error"] == "SSLError":
+ raise SSLError(element["error_message"])
+ else:
+ raise getattr(requests.exceptions, element["error"])(element["error_message"])
+ return element["response"]
+
+class GlobalVariablesMock:
+ def start(self):
+ import market
+ import store
+
+ self.patchers = [
+ mock.patch.multiple(market.Portfolio,
+ data=store.LockedVar(None),
+ liquidities=store.LockedVar({}),
+ last_date=store.LockedVar(None),
+ report=store.LockedVar(store.ReportStore(None, no_http_dup=True)),
+ worker=None,
+ worker_tag="",
+ worker_notify=None,
+ worker_started=False,
+ callback=None)
+ ]
+ for patcher in self.patchers:
+ patcher.start()
+
+ def stop(self):
+ pass
+
class TimeMock:
delta = 0
true_time = time.time
+ true_sleep = time.sleep
time_patch = None
datetime_patch = None
@classmethod
def fake_sleep(cls, duration):
cls.delta -= duration
+ cls.true_sleep(0.2)
class AcceptanceTestCase():
def parse_file(self, report_file):
for arg in ["parallel", "report_db"]:
if not args.get(arg, False):
config.append("--no-{}".format(arg.replace("_", "-")))
- for action in args.get("action", []):
+ for action in (args.get("action", []) or []):
config.extend(["--action", action])
if args.get("report_path") is not None:
config.extend(["--report-path", args.get("report_path")])
def setUp(self):
if not hasattr(self, "files"):
- raise "This class expects to be inherited with a class defining self.files in setUp"
+ raise "This class expects to be inherited with a class defining 'files' variable"
+ if not hasattr(self, "log_files"):
+ self.log_files = []
self.reports = {}
self.start_date = datetime.datetime.now()
self.start_date = date
self.reports[f] = [user_id, market_id, http_requests, report_lines]
- DatabaseMock.start(self.reports, "--no-report-db" not in self.config)
- RequestsMock.start(self.requests_by_market())
- FileMock.start()
+ self.database_mock = DatabaseMock(self, self.reports, "--no-report-db" not in self.config)
+ self.requests_mock = RequestsMock(self)
+ self.file_mock = FileMock(self.log_files, "--quiet" in self.config, self)
+ self.global_variables_mock = GlobalVariablesMock()
+
+ self.database_mock.start()
+ self.requests_mock.start()
+ self.file_mock.start()
+ self.global_variables_mock.start()
TimeMock.start(self.start_date)
def base_test(self):
- import main
main.main(self.config)
- RequestsMock.check_calls(self)
- DatabaseMock.check_calls(self)
- FileMock.check_calls(self)
+ self.requests_mock.check_calls()
+ self.database_mock.check_calls()
+ self.file_mock.check_calls()
def tearDown(self):
TimeMock.stop()
- FileMock.stop()
- RequestsMock.stop()
- DatabaseMock.stop()
+ self.global_variables_mock.stop()
+ self.file_mock.stop()
+ self.requests_mock.stop()
+ self.database_mock.stop()
-import glob
for dirfile in glob.glob("tests/acceptance/**/*/", recursive=True):
json_files = glob.glob("{}/*.json".format(dirfile))
+ log_files = glob.glob("{}/*.log".format(dirfile))
if len(json_files) > 0:
name = dirfile.replace("tests/acceptance/", "").replace("/", "_")[0:-1]
cname = "".join(list(map(lambda x: x.capitalize(), name.split("_"))))
globals()[cname] = type(cname,
(AcceptanceTestCase,unittest.TestCase), {
+ "log_files": log_files,
"files": json_files,
"test_{}".format(name): AcceptanceTestCase.base_test
})