]> git.immae.eu Git - perso/Immae/Projets/Cryptomonnaies/Cryptoportfolio/Trader.git/blame_incremental - test_acceptance.py
Merge branch 'dev'
[perso/Immae/Projets/Cryptomonnaies/Cryptoportfolio/Trader.git] / test_acceptance.py
... / ...
CommitLineData
1import requests
2import requests_mock
3import sys, os
4import time, datetime
5import unittest
6from unittest import mock
7from ssl import SSLError
8from decimal import Decimal
9import simplejson as json
10import psycopg2
11from io import StringIO
12import re
13import functools
14import glob
15
16import main
17
class FileMock:
    """Patch file access and stdout, then verify what was printed.

    While started, ``market.open``, ``os.makedirs`` and ``sys.stdout`` are
    replaced by mocks; stdout is captured in a ``StringIO`` so that
    ``check_calls`` can compare it with the expected log files.

    Args:
        log_files: paths of the expected log files (may be ``None`` or empty).
        quiet: when true, nothing at all is expected on stdout.
        tester: object providing ``assertEqual`` and ``fail`` (the test case).
    """

    def __init__(self, log_files, quiet, tester):
        self.tester = tester
        # One list of timestamp-stripped lines per expected log file.
        self.log_files = []
        if log_files is not None and len(log_files) > 0:
            self.read_log_files(log_files)
        self.quiet = quiet
        # The stdout patch must stay last: start() reads it back as mocks[-1].
        self.patches = [
            mock.patch("market.open"),
            mock.patch("os.makedirs"),
            mock.patch("sys.stdout", new_callable=StringIO),
        ]
        self.mocks = []

    def start(self):
        """Start every patch and keep a handle on the captured stdout."""
        for patch in self.patches:
            self.mocks.append(patch.start())
        self.stdout = self.mocks[-1]

    def check_calls(self):
        """Compare the captured stdout with the expected log lines.

        In quiet mode stdout must be empty.  Otherwise, when expected log
        files were given, the number of lines must match and every expected
        line must be present in the output (order is not checked).  Expected
        lines starting with "[Worker] " are allowed to be missing.
        """
        stdout = self.stdout.getvalue()
        if self.quiet:
            self.tester.assertEqual("", stdout)
            return
        if len(self.log_files) == 0:
            return

        split_logs = self.strip_log(stdout).split("\n")
        self.tester.assertEqual(sum(len(f) for f in self.log_files), len(split_logs))
        for log_file in self.log_files:
            for line in log_file:
                # The lookup is guarded per line: a tolerated missing
                # "[Worker] " line must not stop the verification of the
                # lines that follow it.
                try:
                    split_logs.remove(line)
                except ValueError:
                    if not line.startswith("[Worker] "):
                        self.tester.fail("« {} » not found in log file {}".format(line, split_logs))
        # NOTE(review): the original author listed the following checks here;
        # presumably they are still to be implemented:
        # - The log file is written
        # - The log file is printed only when not quiet
        # - The report is written when relevant
        # - The report contains the right number of lines

    def stop(self):
        """Stop the patches in reverse start order and drop their mocks."""
        for patch in self.patches[::-1]:
            patch.stop()
            self.mocks.pop()

    def strip_log(self, log):
        """Return ``log`` with doubled newlines collapsed and the leading
        "YYYY-MM-DD HH:MM:SS: " timestamp removed from every line."""
        log = log.replace("\n\n", "\n")
        return re.sub(r"^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}: ", "", log, flags=re.MULTILINE)

    def read_log_files(self, log_files):
        """Load each expected log file as a list of stripped lines."""
        for log_file in log_files:
            with open(log_file, "r") as f:
                log = self.strip_log(f.read()).split("\n")
            # A trailing newline in the file yields one empty last entry.
            if len(log[-1]) == 0:
                log.pop()
            self.log_files.append(log)
74
class DatabaseMock:
    """Stand-in for ``psycopg2.connect`` that counts connections and queries.

    Args:
        tester: object providing ``assertEqual`` (the test case).
        reports: mapping whose values unpack as
            ``(user_id, market_id, http_requests, report_lines)``.
        report_db: whether the run is expected to write its report to the
            database (changes the expected call counts).
    """

    def __init__(self, tester, reports, report_db):
        self.tester = tester
        self.reports = reports
        self.report_db = report_db
        self.requests = []

        entries = list(self.reports.values())
        # Rows returned when the code under test iterates over the cursor.
        self.rows = [
            (market_id, {"key": "key", "secret": "secret"}, user_id)
            for user_id, market_id, _http_requests, _report_lines in entries
        ]
        self.total_report_lines = sum(
            len(report_lines)
            for _user_id, _market_id, _http_requests, report_lines in entries
        )

    def start(self):
        """Patch ``psycopg2.connect`` so it hands out a recording cursor."""
        def record_query(request, *args):
            self.requests.append(request)

        self.cursor = mock.MagicMock()
        self.cursor.execute.side_effect = record_query
        self.cursor.__iter__.return_value = self.rows

        connection = mock.Mock()
        connection.cursor.return_value = self.cursor

        self.db_patch = mock.patch("psycopg2.connect")
        self.db_mock = self.db_patch.start()
        self.db_mock.return_value = connection

    def check_calls(self):
        """Assert the number of connections opened and queries executed."""
        expected_connections = 1
        expected_queries = 1
        if self.report_db:
            # One extra connection and query per row, plus one query per
            # report line.
            expected_connections += len(self.rows)
            expected_queries += len(self.rows) + self.total_report_lines
        self.tester.assertEqual(expected_connections, self.db_mock.call_count)
        self.tester.assertEqual(expected_queries, self.cursor.execute.call_count)

    def stop(self):
        """Undo the ``psycopg2.connect`` patch."""
        self.db_patch.stop()
110
class RequestsMock:
    """Replay recorded HTTP exchanges through ``requests_mock``.

    The recorded requests come from ``tester.requests_by_market()`` and are
    grouped by ``(method, url)`` and then by market id.  Any request that
    was not recorded is logged in ``error_calls`` and rejected.
    """

    def __init__(self, tester):
        self.tester = tester
        self.reports = tester.requests_by_market()

        # Responses already served, keyed by the "date" of their element.
        self.last_https = {}
        self.error_calls = []
        self.mocker = requests_mock.Mocker()

        def not_stubbed(*args):
            request = args[0]
            self.error_calls.append([request.method, request.url])
            raise requests_mock.exceptions.MockException("Not stubbed URL")

        # Catch-all registered first; the specific URIs are registered below.
        self.mocker.register_uri(requests_mock.ANY, requests_mock.ANY, text=not_stubbed)

        self.mocks = {}
        for market_id, elements in self.reports.items():
            for element in elements:
                endpoint = (element["method"], element["url"])
                per_market = self.mocks.setdefault(endpoint, {})
                per_market.setdefault(market_id, []).append(element)

        for (method, url), elements in self.mocks.items():
            handler = functools.partial(callback, self, elements)
            self.mocker.register_uri(method, url, text=handler, complete_qs=True)

    def start(self):
        """Activate the request mocker."""
        self.mocker.start()

    def stop(self):
        """Deactivate the request mocker."""
        self.mocker.stop()

    def check_calls(self):
        """Assert no unexpected request happened and none is left unplayed."""
        self.tester.assertEqual([], self.error_calls)
        for (method, url), per_market in self.mocks.items():
            for market_id, pending in per_market.items():
                message = "Missing calls to {} {}, market_id {}".format(method, url, market_id)
                self.tester.assertEqual(0, len(pending), message)

    def clean_body(self, body):
        """Return ``body`` as text with its ``nonce`` parameter removed.

        ``None`` is returned unchanged; bytes are decoded first.
        """
        if body is None:
            return None
        text = body.decode() if isinstance(body, bytes) else body
        # First a trailing "&nonce=...", then a nonce anywhere else.
        for pattern in (r"&nonce=\d*$", r"nonce=\d*&?"):
            text = re.sub(pattern, "", text)
        return text
159
def callback(self, elements, request, context):
    """Serve the next recorded response for ``request``.

    Args:
        self: the owning mock; provides ``error_calls``, ``last_https`` and
            ``clean_body``.
        elements: recorded exchanges for this URI, keyed by market id.
        request: the intercepted request; its "X-market-id" header selects
            the queue to consume.
        context: response context whose ``status_code`` is set.

    Returns:
        The recorded response of the consumed element.

    Raises:
        RuntimeError: no recorded exchange is left for this market.
        AssertionError: the request body differs from the recorded one.
        SSLError or a ``requests.exceptions`` class: when the recorded
            element carries an "error" entry.
    """
    try:
        pending = elements[request.headers.get("X-market-id")]
        element = pending.pop(0)
    except (IndexError, KeyError):
        self.error_calls.append([request.method, request.url, request.headers.get("X-market-id")])
        raise RuntimeError("Unexpected call")

    recorded_response = element["response"]
    if recorded_response is None and element["response_same_as"] is not None:
        # The response was not stored again: reuse the one it refers to.
        element["response"] = self.last_https[element["response_same_as"]]
    elif recorded_response is not None:
        self.last_https[element["date"]] = recorded_response

    actual_body = self.clean_body(request.body)
    expected_body = self.clean_body(element["body"])
    assert actual_body == expected_body, "Body does not match"

    context.status_code = element["status"]
    if "error" in element:
        error_name = element["error"]
        if error_name == "SSLError":
            error_class = SSLError
        else:
            error_class = getattr(requests.exceptions, error_name)
        raise error_class(element["error_message"])
    return element["response"]
180
class GlobalVariablesMock:
    """Replace the class-level attributes of ``market.Portfolio`` for a test.

    ``start`` patches the attributes listed below with fresh values and
    ``stop`` restores the previous ones, so that patches do not pile up from
    one test to the next.
    """

    def start(self):
        """Patch ``market.Portfolio`` attributes with fresh values."""
        # Imported here rather than at module level, as in the original.
        import market
        import store

        self.patchers = [
            mock.patch.multiple(market.Portfolio,
                data=store.LockedVar(None),
                liquidities=store.LockedVar({}),
                last_date=store.LockedVar(None),
                report=store.LockedVar(store.ReportStore(None, no_http_dup=True)),
                worker=None,
                worker_tag="",
                worker_notify=None,
                worker_started=False,
                callback=None)
        ]
        for patcher in self.patchers:
            patcher.start()

    def stop(self):
        """Stop the started patchers, most recent first.

        Safe to call when ``start`` was never called or when ``stop`` was
        already called: each patcher is removed from the list before being
        stopped.
        """
        patchers = getattr(self, "patchers", [])
        while patchers:
            patchers.pop().stop()
203
204
class TimeMock:
    """Shift the clock seen through ``time`` and ``datetime`` to a given date.

    While started, ``time.time`` returns the real time minus ``delta``,
    ``time.sleep`` advances the fake clock by the requested duration while
    really sleeping only 0.2 second, and ``datetime.datetime.now`` is
    derived from the patched ``time.time``.
    """

    # Number of seconds subtracted from the real clock.
    delta = 0
    # Real functions, captured before any patching.
    true_time = time.time
    true_sleep = time.sleep
    # Active patchers, or None when the mock is not started.
    time_patch = None
    datetime_patch = None

    @classmethod
    def start(cls, start_date):
        """Make the current time appear to be ``start_date``.

        Args:
            start_date: naive ``datetime.datetime`` the fake clock starts at.
        """
        # A previous start() without stop() would otherwise leak its patches
        # and make the delta below be computed from the already faked clock.
        cls.stop()
        cls.delta = (datetime.datetime.now() - start_date).total_seconds()

        class fake_datetime(datetime.datetime):
            @classmethod
            def now(cls, tz=None):
                # fromtimestamp(ts, tz) is the documented equivalent of
                # tz.fromutc(utcfromtimestamp(ts).replace(tzinfo=tz)) and
                # avoids the deprecated utcfromtimestamp; with tz=None it
                # returns the naive local time.
                return cls.fromtimestamp(time.time(), tz)

        cls.time_patch = mock.patch.multiple(time, time=cls.fake_time, sleep=cls.fake_sleep)
        cls.datetime_patch = mock.patch.multiple(datetime, datetime=fake_datetime)
        cls.time_patch.start()
        cls.datetime_patch.start()

    @classmethod
    def stop(cls):
        """Restore the real clock; does nothing when not started."""
        cls.delta = 0
        # Reverse of the start order: datetime first, then time.
        for attribute in ("datetime_patch", "time_patch"):
            patcher = getattr(cls, attribute)
            if patcher is not None:
                patcher.stop()
                setattr(cls, attribute, None)

    @classmethod
    def fake_time(cls):
        """Return the real time shifted back by ``delta`` seconds."""
        return cls.true_time() - cls.delta

    @classmethod
    def fake_sleep(cls, duration):
        """Advance the fake clock by ``duration`` seconds.

        Only 0.2 second is really slept, whatever ``duration`` is.
        """
        cls.delta -= duration
        cls.true_sleep(0.2)
243
class AcceptanceTestCase():
    """Mixin replaying recorded report files against ``main.main``.

    Subclasses must define ``files`` (paths of JSON report files) and may
    define ``log_files`` (paths of expected log files).  It is meant to be
    combined with ``unittest.TestCase``, which supplies the assertions used
    by the mocks.
    """

    def parse_file(self, report_file):
        """Load one JSON report file.

        Returns:
            ``(config, user, date, market_id, http_requests, json_content)``
            where ``json_content`` is the complete list of report elements.
        """
        with open(report_file, "rb") as f:
            # Floats are parsed as Decimal.
            json_content = json.load(f, parse_float=Decimal)
        config, user, date, market_id = self.parse_config(json_content)
        http_requests = self.parse_requests(json_content)

        return config, user, date, market_id, http_requests, json_content

    def parse_requests(self, json_content):
        """Return the report elements whose type is "http_request"."""
        return [element for element in json_content if element["type"] == "http_request"]

    def parse_config(self, json_content):
        """Rebuild the command line of a run from its "market" element.

        When several "market" elements exist, the last one is used.

        Returns:
            ``(config, user, date, market_id)`` where ``config`` is the list
            of command-line arguments.

        Raises:
            AssertionError: the report has no "market" element.
        """
        market_info = None
        for element in json_content:
            if element["type"] == "market":
                market_info = element
        assert market_info is not None, "Couldn't find market element"

        args = market_info["args"]
        config = []
        # Flags passed only when set in the recorded arguments.
        for arg in ["before", "after", "quiet", "debug"]:
            if args.get(arg, False):
                config.append("--{}".format(arg))
        # Options disabled through their --no-* form when not set.
        for arg in ["parallel", "report_db"]:
            if not args.get(arg, False):
                config.append("--no-{}".format(arg.replace("_", "-")))
        for action in (args.get("action", []) or []):
            config.extend(["--action", action])
        if args.get("report_path") is not None:
            config.extend(["--report-path", args.get("report_path")])
        if args.get("user") is not None:
            config.extend(["--user", args.get("user")])
        config.extend(["--config", ""])

        user = market_info["user_id"]
        date = datetime.datetime.strptime(market_info["date"], "%Y-%m-%dT%H:%M:%S.%f")
        market_id = market_info["market_id"]
        return config, user, date, market_id

    def requests_by_market(self):
        """Group the recorded HTTP requests by market id.

        Returns:
            A dict mapping ``str(market_id)`` to the requests of that
            market, plus a ``None`` key holding the requests that carry no
            market id.  Those are taken from the first report only.
        """
        r = {
            None: []
        }
        got_common = False
        for user_id, market_id, http_requests, report_lines in self.reports.values():
            r[str(market_id)] = []
            for http_request in http_requests:
                if http_request["market_id"] is None:
                    if not got_common:
                        r[None].append(http_request)
                else:
                    r[str(market_id)].append(http_request)
            # Requests without market id are only kept from the first report.
            got_common = True
        return r

    def setUp(self):
        """Parse the report files, then build and start every mock.

        Raises:
            TypeError: the class does not define ``files``.
        """
        if not hasattr(self, "files"):
            # Raising a bare string is itself a TypeError in Python 3 and
            # loses the message; raise a real exception of that type.
            raise TypeError("This class expects to be inherited with a class defining 'files' variable")
        if not hasattr(self, "log_files"):
            self.log_files = []

        self.reports = {}
        self.start_date = datetime.datetime.now()
        self.config = []
        for f in self.files:
            # The configuration of the last parsed file is the one kept.
            self.config, user_id, date, market_id, http_requests, report_lines = self.parse_file(f)
            # The fake clock starts at the earliest report date.
            if date < self.start_date:
                self.start_date = date
            self.reports[f] = [user_id, market_id, http_requests, report_lines]

        self.database_mock = DatabaseMock(self, self.reports, "--no-report-db" not in self.config)
        self.requests_mock = RequestsMock(self)
        self.file_mock = FileMock(self.log_files, "--quiet" in self.config, self)
        self.global_variables_mock = GlobalVariablesMock()

        self.database_mock.start()
        self.requests_mock.start()
        self.file_mock.start()
        self.global_variables_mock.start()
        TimeMock.start(self.start_date)

    def base_test(self):
        """Run ``main.main`` with the recorded configuration and check the
        HTTP, database and output expectations."""
        main.main(self.config)
        self.requests_mock.check_calls()
        self.database_mock.check_calls()
        self.file_mock.check_calls()

    def tearDown(self):
        """Stop the mocks in the reverse order of their start."""
        TimeMock.stop()
        self.global_variables_mock.stop()
        self.file_mock.stop()
        self.requests_mock.stop()
        self.database_mock.stop()
344
# Generate one unittest.TestCase subclass per directory found under
# tests/acceptance/ that contains at least one JSON report file.  The class
# is stored in the module globals under a CamelCase name derived from the
# directory path, and its single test method runs
# AcceptanceTestCase.base_test.
for dirfile in glob.glob("tests/acceptance/**/*/", recursive=True):
    json_files = glob.glob("{}/*.json".format(dirfile))
    log_files = glob.glob("{}/*.log".format(dirfile))
    if not json_files:
        continue

    # "tests/acceptance/a/b/" -> "a_b_" -> "a_b" (trailing "_" dropped).
    name = dirfile.replace("tests/acceptance/", "").replace("/", "_")[:-1]
    cname = "".join(part.capitalize() for part in name.split("_"))

    globals()[cname] = type(
        cname,
        (AcceptanceTestCase, unittest.TestCase),
        {
            "log_files": log_files,
            "files": json_files,
            "test_{}".format(name): AcceptanceTestCase.base_test,
        },
    )

if __name__ == '__main__':
    unittest.main()
361