]>
Commit | Line | Data |
---|---|---|
1 | import requests | |
2 | import requests_mock | |
3 | import sys, os | |
4 | import time, datetime | |
5 | import unittest | |
6 | from unittest import mock | |
7 | from ssl import SSLError | |
8 | from decimal import Decimal | |
9 | import simplejson as json | |
10 | import psycopg2 | |
11 | from io import StringIO | |
12 | import re | |
13 | import functools | |
14 | import glob | |
15 | ||
16 | import main | |
17 | ||
class FileMock:
    """Capture stdout (and stub file writes) and compare output with expected logs.

    ``log_files`` are paths of reference ``.log`` files, ``quiet`` tells whether
    the run under test used ``--quiet`` (then nothing may be printed), and
    ``tester`` is the TestCase used for assertions.
    """

    def __init__(self, log_files, quiet, tester):
        self.tester = tester
        # One list of (timestamp-stripped) lines per expected log file.
        self.log_files = []
        if log_files:
            self.read_log_files(log_files)
        self.quiet = quiet
        self.patches = [
            mock.patch("market.open"),
            mock.patch("os.makedirs"),
            # Must stay last: start() takes the last started mock as stdout.
            mock.patch("sys.stdout", new_callable=StringIO),
        ]
        self.mocks = []

    def start(self):
        """Start every patch and keep a handle on the captured stdout."""
        for patch in self.patches:
            self.mocks.append(patch.start())
        self.stdout = self.mocks[-1]

    def check_calls(self):
        """Assert that what was printed matches the expected log files.

        In quiet mode nothing at all may have been printed. Otherwise the
        number of printed lines must equal the total number of expected lines,
        and every expected line must be found in the output. Expected lines
        starting with "[Worker] " are tolerated when missing, and checking
        continues with the following lines.
        """
        stdout = self.stdout.getvalue()
        if self.quiet:
            self.tester.assertEqual("", stdout)
            return
        if not self.log_files:
            return
        split_logs = self.strip_log(stdout).split("\n")
        self.tester.assertEqual(sum(len(f) for f in self.log_files), len(split_logs))
        for log_file in self.log_files:
            for line in log_file:
                # Handle a missing line per line, so a tolerated "[Worker] "
                # miss does not skip the check of the remaining lines.
                try:
                    split_logs.remove(line)
                except ValueError:
                    if not line.startswith("[Worker] "):
                        self.tester.fail("« {} » not found in log file {}".format(line, split_logs))
        # NOTE(review): the original (French) notes list these checks; they do
        # not appear to be implemented here:
        # - the log file is written
        # - the log file is printed only when not quiet
        # - the report is written when relevant
        # - the report contains the right number of lines

    def stop(self):
        """Stop the patches in reverse order of start()."""
        for patch in self.patches[::-1]:
            patch.stop()
            self.mocks.pop()

    def strip_log(self, log):
        """Collapse doubled newlines and drop leading "YYYY-MM-DD HH:MM:SS: " stamps."""
        log = log.replace("\n\n", "\n")
        return re.sub(r"^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}: ", "", log, flags=re.MULTILINE)

    def read_log_files(self, log_files):
        """Load each expected log file as a list of timestamp-stripped lines."""
        for log_file in log_files:
            # Explicit encoding: the default one is platform dependent.
            with open(log_file, "r", encoding="utf-8") as f:
                log = self.strip_log(f.read()).split("\n")
            # Drop the empty element produced by the file's trailing newline.
            if not log[-1]:
                log.pop()
            self.log_files.append(log)
74 | ||
class DatabaseMock:
    """Replace ``psycopg2.connect`` with a mock fed from the report data.

    ``reports`` maps each report file to ``[user_id, market_id,
    http_requests, report_lines]``; the mocked cursor yields one
    ``(market_id, credentials, user_id)`` row per report. ``report_db`` tells
    check_calls() whether report lines were expected to be stored.
    """

    def __init__(self, tester, reports, report_db):
        self.tester = tester
        self.reports = reports
        self.report_db = report_db
        self.rows = [
            (market_id, {"key": "key", "secret": "secret"}, user_id)
            for user_id, market_id, _http_requests, _report_lines in self.reports.values()
        ]
        self.total_report_lines = sum(
            len(report_lines)
            for _user_id, _market_id, _http_requests, report_lines in self.reports.values()
        )
        # SQL statements received by the mocked cursor, in call order.
        self.requests = []

    def start(self):
        """Patch ``psycopg2.connect`` so it returns a connection to a fake cursor."""
        def record(query, *params):
            self.requests.append(query)

        connection = mock.Mock()
        self.cursor = mock.MagicMock()
        self.cursor.execute.side_effect = record
        self.cursor.__iter__.return_value = self.rows
        connection.cursor.return_value = self.cursor

        self.db_patch = mock.patch("psycopg2.connect")
        self.db_mock = self.db_patch.start()
        self.db_mock.return_value = connection

    def check_calls(self):
        """Assert the number of connections opened and statements executed."""
        if self.report_db:
            expected_connects = 1 + len(self.rows)
            expected_executes = expected_connects + self.total_report_lines
        else:
            expected_connects = expected_executes = 1
        self.tester.assertEqual(expected_connects, self.db_mock.call_count)
        self.tester.assertEqual(expected_executes, self.cursor.execute.call_count)

    def stop(self):
        """Undo the ``psycopg2.connect`` patch."""
        self.db_patch.stop()
110 | ||
class RequestsMock:
    """Replay the HTTP exchanges recorded in the acceptance report files.

    Every recorded request is registered with ``requests_mock``. Calls are
    dispatched on method and URL (query string included), then on the
    ``X-market-id`` request header, and answered with the recorded response.
    Calls to unregistered URLs, unexpected calls and body mismatches are
    collected in ``error_calls``.
    """

    def __init__(self, tester):
        self.tester = tester
        self.reports = tester.requests_by_market()

        # Recorded responses by their "date", for "response_same_as" lookups.
        self.last_https = {}
        self.error_calls = []
        self.mocker = requests_mock.Mocker()

        def not_stubbed(*args):
            self.error_calls.append([args[0].method, args[0].url])
            raise requests_mock.exceptions.MockException("Not stubbed URL")

        self.mocker.register_uri(requests_mock.ANY, requests_mock.ANY, text=not_stubbed)

        # (method, url) -> {market id -> [pending recorded exchanges]}
        self.mocks = {}
        for market_id, elements in self.reports.items():
            for element in elements:
                key = (element["method"], element["url"])
                self.mocks.setdefault(key, {}).setdefault(market_id, []).append(element)

        for (method, url), elements in self.mocks.items():
            self.mocker.register_uri(method, url, text=functools.partial(callback, self, elements), complete_qs=True)

    def start(self):
        """Start intercepting HTTP calls made through ``requests``."""
        self.mocker.start()

    def stop(self):
        """Stop intercepting HTTP calls."""
        self.mocker.stop()

    def check_calls(self):
        """Assert there were no erroneous calls and every recorded call was consumed."""
        self.tester.assertEqual([], self.error_calls)
        for (method, url), by_market in self.mocks.items():
            for market_id, pending in by_market.items():
                self.tester.assertEqual(0, len(pending), "Missing calls to {} {}, market_id {}".format(method, url, market_id))

    def clean_body(self, body):
        """Return ``body`` as text with any ``nonce=<digits>`` parameter removed.

        Nonces change on every run, so they are ignored when comparing bodies.
        """
        if body is None:
            return None
        if isinstance(body, bytes):
            body = body.decode()
        body = re.sub(r"&nonce=\d*$", "", body)
        body = re.sub(r"nonce=\d*&?", "", body)
        return body


def callback(self, elements, request, context):
    """``requests_mock`` text callback: answer with the next recorded exchange.

    ``self`` is the RequestsMock instance and ``elements`` maps market ids to
    the pending exchanges for one (method, url) pair.

    Raises RuntimeError for a call with no pending exchange, AssertionError
    when the request body differs from the recorded one, and the recorded
    exception when the exchange recorded an error.
    """
    market_id = request.headers.get("X-market-id")
    try:
        element = elements[market_id].pop(0)
    except (IndexError, KeyError):
        self.error_calls.append([request.method, request.url, market_id])
        raise RuntimeError("Unexpected call")
    if element["response"] is None and element["response_same_as"] is not None:
        element["response"] = self.last_https[element["response_same_as"]]
    elif element["response"] is not None:
        self.last_https[element["date"]] = element["response"]

    if self.clean_body(request.body) != self.clean_body(element["body"]):
        # Explicit raise (a bare assert vanishes under -O). Also record the
        # mismatch, as for unexpected calls, so check_calls() reports it even
        # if the caller swallows the exception.
        self.error_calls.append([request.method, request.url, market_id, "Body does not match"])
        raise AssertionError("Body does not match")
    context.status_code = element["status"]
    if "error" in element:
        if element["error"] == "SSLError":
            raise SSLError(element["error_message"])
        else:
            raise getattr(requests.exceptions, element["error"])(element["error_message"])
    return element["response"]
180 | ||
class GlobalVariablesMock:
    """Give ``market.Portfolio``'s class-level attributes fresh values for one test."""

    def start(self):
        """Patch ``market.Portfolio`` class attributes with fresh initial values."""
        import market
        import store

        self.patchers = [
            mock.patch.multiple(market.Portfolio,
                data=store.LockedVar(None),
                liquidities=store.LockedVar({}),
                last_date=store.LockedVar(None),
                report=store.LockedVar(store.ReportStore(None, no_http_dup=True)),
                worker=None,
                worker_tag="",
                worker_notify=None,
                worker_started=False,
                callback=None)
        ]
        for patcher in self.patchers:
            patcher.start()

    def stop(self):
        """Undo the patches applied by start(), most recent first.

        Previously a no-op, which left the patches active after the test and
        stacked a new layer on every start(). Safe to call before start().
        """
        for patcher in reversed(getattr(self, "patchers", [])):
            patcher.stop()
        self.patchers = []
203 | ||
204 | ||
class TimeMock:
    """Run the code under test on a fake clock.

    start() patches ``time.time``, ``time.sleep`` and ``datetime.datetime`` so
    that "now" begins at the given start date and then advances with the real
    clock. ``time.sleep(d)`` moves the fake clock forward by ``d`` seconds but
    only sleeps 0.2 s for real.
    """
    # Seconds subtracted from the real clock to obtain the fake time.
    delta = 0
    # Real implementations, captured when the class is created (before patching).
    true_time = time.time
    true_sleep = time.sleep
    time_patch = None
    datetime_patch = None

    @classmethod
    def start(cls, start_date):
        """Install the fake clock so that "now" is ``start_date``."""
        cls.delta = (datetime.datetime.now() - start_date).total_seconds()

        class fake_datetime(datetime.datetime):
            @classmethod
            def now(cls, tz=None):
                # time.time is patched below, so this reads the fake clock.
                if tz is None:
                    return cls.fromtimestamp(time.time())
                else:
                    return tz.fromutc(cls.utcfromtimestamp(time.time()).replace(tzinfo=tz))

        cls.time_patch = mock.patch.multiple(time, time=cls.fake_time, sleep=cls.fake_sleep)
        cls.datetime_patch = mock.patch.multiple(datetime, datetime=fake_datetime)
        cls.time_patch.start()
        cls.datetime_patch.start()

    @classmethod
    def stop(cls):
        """Restore the real clock; safe to call when start() was never called."""
        cls.delta = 0
        if cls.datetime_patch is not None:
            cls.datetime_patch.stop()
            cls.datetime_patch = None
        if cls.time_patch is not None:
            cls.time_patch.stop()
            cls.time_patch = None

    @classmethod
    def fake_time(cls):
        """Replacement for ``time.time``: real time shifted by ``delta``."""
        return cls.true_time() - cls.delta

    @classmethod
    def fake_sleep(cls, duration):
        """Replacement for ``time.sleep``: advance the fake clock, sleep 0.2 s."""
        cls.delta -= duration
        cls.true_sleep(0.2)
243 | ||
class AcceptanceTestCase:
    """Mixin replaying recorded report files against ``main.main``.

    Concrete test classes must inherit from this and ``unittest.TestCase`` and
    define ``files`` (JSON report paths); they may define ``log_files``
    (expected console logs).
    """

    def parse_file(self, report_file):
        """Load one JSON report (floats parsed as Decimal) and extract its data.

        Returns ``(config, user, date, market_id, http_requests, json_content)``.
        """
        with open(report_file, "rb") as f:
            json_content = json.load(f, parse_float=Decimal)
        config, user, date, market_id = self.parse_config(json_content)
        http_requests = self.parse_requests(json_content)

        return config, user, date, market_id, http_requests, json_content

    def parse_requests(self, json_content):
        """Return the report elements whose type is "http_request", in order."""
        return [element for element in json_content if element["type"] == "http_request"]

    def parse_config(self, json_content):
        """Rebuild the command-line arguments from the report's "market" element.

        Uses the last "market" element. Returns ``(config, user, date, market_id)``.
        """
        market_info = None
        for element in json_content:
            if element["type"] != "market":
                continue
            market_info = element
        assert market_info is not None, "Couldn't find market element"

        args = market_info["args"]
        config = []
        for arg in ["before", "after", "quiet", "debug"]:
            if args.get(arg, False):
                config.append("--{}".format(arg))
        for arg in ["parallel", "report_db"]:
            if not args.get(arg, False):
                config.append("--no-{}".format(arg.replace("_", "-")))
        for action in (args.get("action", []) or []):
            config.extend(["--action", action])
        if args.get("report_path") is not None:
            config.extend(["--report-path", args.get("report_path")])
        if args.get("user") is not None:
            config.extend(["--user", args.get("user")])
        config.extend(["--config", ""])

        user = market_info["user_id"]
        date = datetime.datetime.strptime(market_info["date"], "%Y-%m-%dT%H:%M:%S.%f")
        market_id = market_info["market_id"]
        return config, user, date, market_id

    def requests_by_market(self):
        """Group the recorded HTTP requests by market id (as a string).

        Requests without a market id are stored under the ``None`` key and
        taken from the first report only.
        """
        r = {
            None: []
        }
        got_common = False
        for user_id, market_id, http_requests, report_lines in self.reports.values():
            r[str(market_id)] = []
            for http_request in http_requests:
                if http_request["market_id"] is None:
                    if not got_common:
                        r[None].append(http_request)
                else:
                    r[str(market_id)].append(http_request)
            got_common = True
        return r

    def setUp(self):
        """Parse the scenario's report files and start every mock."""
        if not hasattr(self, "files"):
            # Raising a bare string is itself a TypeError in Python 3 and hides
            # this message; raise a real exception instead.
            raise TypeError("This class expects to be inherited with a class defining 'files' variable")
        if not hasattr(self, "log_files"):
            self.log_files = []

        self.reports = {}
        self.start_date = datetime.datetime.now()
        self.config = []
        for f in self.files:
            # NOTE(review): self.config is overwritten for each file, so the
            # arguments of the last parsed report are the ones used.
            self.config, user_id, date, market_id, http_requests, report_lines = self.parse_file(f)
            # The fake clock starts at the earliest report date.
            if date < self.start_date:
                self.start_date = date
            self.reports[f] = [user_id, market_id, http_requests, report_lines]

        self.database_mock = DatabaseMock(self, self.reports, "--no-report-db" not in self.config)
        self.requests_mock = RequestsMock(self)
        self.file_mock = FileMock(self.log_files, "--quiet" in self.config, self)
        self.global_variables_mock = GlobalVariablesMock()

        self.database_mock.start()
        self.requests_mock.start()
        self.file_mock.start()
        self.global_variables_mock.start()
        TimeMock.start(self.start_date)

    def base_test(self):
        """Run ``main.main`` with the rebuilt arguments and verify every mock."""
        main.main(self.config)
        self.requests_mock.check_calls()
        self.database_mock.check_calls()
        self.file_mock.check_calls()

    def tearDown(self):
        """Stop the mocks in reverse order of setUp()."""
        TimeMock.stop()
        self.global_variables_mock.stop()
        self.file_mock.stop()
        self.requests_mock.stop()
        self.database_mock.stop()
344 | ||
# Build one TestCase subclass per scenario directory below tests/acceptance/
# that contains *.json report files; *.log files in the same directory are the
# expected console output. Paths are sorted because glob() returns them in
# arbitrary filesystem order, and setUp() keeps the config of the last report.
for dirfile in sorted(glob.glob("tests/acceptance/**/*/", recursive=True)):
    json_files = sorted(glob.glob("{}/*.json".format(dirfile)))
    log_files = sorted(glob.glob("{}/*.log".format(dirfile)))
    if json_files:
        # e.g. "tests/acceptance/foo/bar_baz/" -> "foo_bar_baz" -> "FooBarBaz"
        name = dirfile.replace("tests/acceptance/", "").replace("/", "_")[0:-1]
        cname = "".join(part.capitalize() for part in name.split("_"))

        globals()[cname] = type(cname,
                (AcceptanceTestCase, unittest.TestCase), {
                    "log_files": log_files,
                    "files": json_files,
                    "test_{}".format(name): AcceptanceTestCase.base_test,
                })

if __name__ == '__main__':
    unittest.main()
361 |