diff options
author | Ismaël Bouya <ismael.bouya@normalesup.org> | 2018-04-20 20:09:13 +0200 |
---|---|---|
committer | Ismaël Bouya <ismael.bouya@normalesup.org> | 2018-04-20 20:16:18 +0200 |
commit | 3080f31d1ee74104640dcff451922cd0ae88ee22 (patch) | |
tree | 81570ba2eb909b05e7aa4805f5535e47e1df6a11 /test_acceptance.py | |
parent | 9fe90554ff1c8c7aea9e1e1e210419a845579edd (diff) | |
download | Trader-3080f31d1ee74104640dcff451922cd0ae88ee22.tar.gz Trader-3080f31d1ee74104640dcff451922cd0ae88ee22.tar.zst Trader-3080f31d1ee74104640dcff451922cd0ae88ee22.zip |
Move acceptance tests to common directory
Diffstat (limited to 'test_acceptance.py')
-rw-r--r-- | test_acceptance.py | 361 |
1 files changed, 0 insertions, 361 deletions
diff --git a/test_acceptance.py b/test_acceptance.py deleted file mode 100644 index 3633928..0000000 --- a/test_acceptance.py +++ /dev/null | |||
@@ -1,361 +0,0 @@ | |||
1 | import requests | ||
2 | import requests_mock | ||
3 | import sys, os | ||
4 | import time, datetime | ||
5 | import unittest | ||
6 | from unittest import mock | ||
7 | from ssl import SSLError | ||
8 | from decimal import Decimal | ||
9 | import simplejson as json | ||
10 | import psycopg2 | ||
11 | from io import StringIO | ||
12 | import re | ||
13 | import functools | ||
14 | import glob | ||
15 | |||
16 | import main | ||
17 | |||
18 | class FileMock: | ||
19 | def __init__(self, log_files, quiet, tester): | ||
20 | self.tester = tester | ||
21 | self.log_files = [] | ||
22 | if log_files is not None and len(log_files) > 0: | ||
23 | self.read_log_files(log_files) | ||
24 | self.quiet = quiet | ||
25 | self.patches = [ | ||
26 | mock.patch("market.open"), | ||
27 | mock.patch("os.makedirs"), | ||
28 | mock.patch("sys.stdout", new_callable=StringIO), | ||
29 | ] | ||
30 | self.mocks = [] | ||
31 | |||
32 | def start(self): | ||
33 | for patch in self.patches: | ||
34 | self.mocks.append(patch.start()) | ||
35 | self.stdout = self.mocks[-1] | ||
36 | |||
37 | def check_calls(self): | ||
38 | stdout = self.stdout.getvalue() | ||
39 | if self.quiet: | ||
40 | self.tester.assertEqual("", stdout) | ||
41 | else: | ||
42 | log = self.strip_log(stdout) | ||
43 | if len(self.log_files) != 0: | ||
44 | split_logs = log.split("\n") | ||
45 | self.tester.assertEqual(sum(len(f) for f in self.log_files), len(split_logs)) | ||
46 | try: | ||
47 | for log_file in self.log_files: | ||
48 | for line in log_file: | ||
49 | split_logs.pop(split_logs.index(line)) | ||
50 | except ValueError: | ||
51 | if not line.startswith("[Worker] "): | ||
52 | self.tester.fail("« {} » not found in log file {}".format(line, split_logs)) | ||
53 | # Le fichier de log est écrit | ||
54 | # Le fichier de log est printed uniquement si non quiet | ||
55 | # Le rapport est écrit si pertinent | ||
56 | # Le rapport contient le bon nombre de lignes | ||
57 | |||
58 | def stop(self): | ||
59 | for patch in self.patches[::-1]: | ||
60 | patch.stop() | ||
61 | self.mocks.pop() | ||
62 | |||
63 | def strip_log(self, log): | ||
64 | log = log.replace("\n\n", "\n") | ||
65 | return re.sub(r"^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}: ", "", log, flags=re.MULTILINE) | ||
66 | |||
67 | def read_log_files(self, log_files): | ||
68 | for log_file in log_files: | ||
69 | with open(log_file, "r") as f: | ||
70 | log = self.strip_log(f.read()).split("\n") | ||
71 | if len(log[-1]) == 0: | ||
72 | log.pop() | ||
73 | self.log_files.append(log) | ||
74 | |||
75 | class DatabaseMock: | ||
76 | def __init__(self, tester, reports, report_db): | ||
77 | self.tester = tester | ||
78 | self.reports = reports | ||
79 | self.report_db = report_db | ||
80 | self.rows = [] | ||
81 | self.total_report_lines= 0 | ||
82 | self.requests = [] | ||
83 | for user_id, market_id, http_requests, report_lines in self.reports.values(): | ||
84 | self.rows.append( (market_id, { "key": "key", "secret": "secret" }, user_id) ) | ||
85 | self.total_report_lines += len(report_lines) | ||
86 | |||
87 | def start(self): | ||
88 | connect_mock = mock.Mock() | ||
89 | self.cursor = mock.MagicMock() | ||
90 | connect_mock.cursor.return_value = self.cursor | ||
91 | def _execute(request, *args): | ||
92 | self.requests.append(request) | ||
93 | self.cursor.execute.side_effect = _execute | ||
94 | self.cursor.__iter__.return_value = self.rows | ||
95 | |||
96 | self.db_patch = mock.patch("psycopg2.connect") | ||
97 | self.db_mock = self.db_patch.start() | ||
98 | self.db_mock.return_value = connect_mock | ||
99 | |||
100 | def check_calls(self): | ||
101 | if self.report_db: | ||
102 | self.tester.assertEqual(1 + len(self.rows), self.db_mock.call_count) | ||
103 | self.tester.assertEqual(1 + len(self.rows) + self.total_report_lines, self.cursor.execute.call_count) | ||
104 | else: | ||
105 | self.tester.assertEqual(1, self.db_mock.call_count) | ||
106 | self.tester.assertEqual(1, self.cursor.execute.call_count) | ||
107 | |||
108 | def stop(self): | ||
109 | self.db_patch.stop() | ||
110 | |||
111 | class RequestsMock: | ||
112 | def __init__(self, tester): | ||
113 | self.tester = tester | ||
114 | self.reports = tester.requests_by_market() | ||
115 | |||
116 | self.last_https = {} | ||
117 | self.error_calls = [] | ||
118 | self.mocker = requests_mock.Mocker() | ||
119 | def not_stubbed(*args): | ||
120 | self.error_calls.append([args[0].method, args[0].url]) | ||
121 | raise requests_mock.exceptions.MockException("Not stubbed URL") | ||
122 | self.mocker.register_uri(requests_mock.ANY, requests_mock.ANY, | ||
123 | text=not_stubbed) | ||
124 | |||
125 | self.mocks = {} | ||
126 | |||
127 | for market_id, elements in self.reports.items(): | ||
128 | for element in elements: | ||
129 | method = element["method"] | ||
130 | url = element["url"] | ||
131 | self.mocks \ | ||
132 | .setdefault((method, url), {}) \ | ||
133 | .setdefault(market_id, []) \ | ||
134 | .append(element) | ||
135 | |||
136 | for ((method, url), elements) in self.mocks.items(): | ||
137 | self.mocker.register_uri(method, url, text=functools.partial(callback, self, elements), complete_qs=True) | ||
138 | |||
139 | def start(self): | ||
140 | self.mocker.start() | ||
141 | |||
142 | def stop(self): | ||
143 | self.mocker.stop() | ||
144 | |||
145 | def check_calls(self): | ||
146 | self.tester.assertEqual([], self.error_calls) | ||
147 | for (method, url), elements in self.mocks.items(): | ||
148 | for market_id, element in elements.items(): | ||
149 | self.tester.assertEqual(0, len(element), "Missing calls to {} {}, market_id {}".format(method, url, market_id)) | ||
150 | |||
151 | def clean_body(self, body): | ||
152 | if body is None: | ||
153 | return None | ||
154 | if isinstance(body, bytes): | ||
155 | body = body.decode() | ||
156 | body = re.sub(r"&nonce=\d*$", "", body) | ||
157 | body = re.sub(r"nonce=\d*&?", "", body) | ||
158 | return body | ||
159 | |||
160 | def callback(self, elements, request, context): | ||
161 | try: | ||
162 | element = elements[request.headers.get("X-market-id")].pop(0) | ||
163 | except (IndexError, KeyError): | ||
164 | self.error_calls.append([request.method, request.url, request.headers.get("X-market-id")]) | ||
165 | raise RuntimeError("Unexpected call") | ||
166 | if element["response"] is None and element["response_same_as"] is not None: | ||
167 | element["response"] = self.last_https[element["response_same_as"]] | ||
168 | elif element["response"] is not None: | ||
169 | self.last_https[element["date"]] = element["response"] | ||
170 | |||
171 | assert self.clean_body(request.body) == \ | ||
172 | self.clean_body(element["body"]), "Body does not match" | ||
173 | context.status_code = element["status"] | ||
174 | if "error" in element: | ||
175 | if element["error"] == "SSLError": | ||
176 | raise SSLError(element["error_message"]) | ||
177 | else: | ||
178 | raise getattr(requests.exceptions, element["error"])(element["error_message"]) | ||
179 | return element["response"] | ||
180 | |||
181 | class GlobalVariablesMock: | ||
182 | def start(self): | ||
183 | import market | ||
184 | import store | ||
185 | |||
186 | self.patchers = [ | ||
187 | mock.patch.multiple(market.Portfolio, | ||
188 | data=store.LockedVar(None), | ||
189 | liquidities=store.LockedVar({}), | ||
190 | last_date=store.LockedVar(None), | ||
191 | report=store.LockedVar(store.ReportStore(None, no_http_dup=True)), | ||
192 | worker=None, | ||
193 | worker_tag="", | ||
194 | worker_notify=None, | ||
195 | worker_started=False, | ||
196 | callback=None) | ||
197 | ] | ||
198 | for patcher in self.patchers: | ||
199 | patcher.start() | ||
200 | |||
201 | def stop(self): | ||
202 | pass | ||
203 | |||
204 | |||
205 | class TimeMock: | ||
206 | delta = 0 | ||
207 | true_time = time.time | ||
208 | true_sleep = time.sleep | ||
209 | time_patch = None | ||
210 | datetime_patch = None | ||
211 | |||
212 | @classmethod | ||
213 | def start(cls, start_date): | ||
214 | cls.delta = (datetime.datetime.now() - start_date).total_seconds() | ||
215 | |||
216 | class fake_datetime(datetime.datetime): | ||
217 | @classmethod | ||
218 | def now(cls, tz=None): | ||
219 | if tz is None: | ||
220 | return cls.fromtimestamp(time.time()) | ||
221 | else: | ||
222 | return tz.fromutc(cls.utcfromtimestamp(time.time()).replace(tzinfo=tz)) | ||
223 | |||
224 | cls.time_patch = mock.patch.multiple(time, time=cls.fake_time, sleep=cls.fake_sleep) | ||
225 | cls.datetime_patch = mock.patch.multiple(datetime, datetime=fake_datetime) | ||
226 | cls.time_patch.start() | ||
227 | cls.datetime_patch.start() | ||
228 | |||
229 | @classmethod | ||
230 | def stop(cls): | ||
231 | cls.delta = 0 | ||
232 | cls.datetime_patch.stop() | ||
233 | cls.time_patch.stop() | ||
234 | |||
235 | @classmethod | ||
236 | def fake_time(cls): | ||
237 | return cls.true_time() - cls.delta | ||
238 | |||
239 | @classmethod | ||
240 | def fake_sleep(cls, duration): | ||
241 | cls.delta -= duration | ||
242 | cls.true_sleep(0.2) | ||
243 | |||
244 | class AcceptanceTestCase(): | ||
245 | def parse_file(self, report_file): | ||
246 | with open(report_file, "rb") as f: | ||
247 | json_content = json.load(f, parse_float=Decimal) | ||
248 | config, user, date, market_id = self.parse_config(json_content) | ||
249 | http_requests = self.parse_requests(json_content) | ||
250 | |||
251 | return config, user, date, market_id, http_requests, json_content | ||
252 | |||
253 | def parse_requests(self, json_content): | ||
254 | http_requests = [] | ||
255 | for element in json_content: | ||
256 | if element["type"] != "http_request": | ||
257 | continue | ||
258 | http_requests.append(element) | ||
259 | return http_requests | ||
260 | |||
261 | def parse_config(self, json_content): | ||
262 | market_info = None | ||
263 | for element in json_content: | ||
264 | if element["type"] != "market": | ||
265 | continue | ||
266 | market_info = element | ||
267 | assert market_info is not None, "Couldn't find market element" | ||
268 | |||
269 | args = market_info["args"] | ||
270 | config = [] | ||
271 | for arg in ["before", "after", "quiet", "debug"]: | ||
272 | if args.get(arg, False): | ||
273 | config.append("--{}".format(arg)) | ||
274 | for arg in ["parallel", "report_db"]: | ||
275 | if not args.get(arg, False): | ||
276 | config.append("--no-{}".format(arg.replace("_", "-"))) | ||
277 | for action in (args.get("action", []) or []): | ||
278 | config.extend(["--action", action]) | ||
279 | if args.get("report_path") is not None: | ||
280 | config.extend(["--report-path", args.get("report_path")]) | ||
281 | if args.get("user") is not None: | ||
282 | config.extend(["--user", args.get("user")]) | ||
283 | config.extend(["--config", ""]) | ||
284 | |||
285 | user = market_info["user_id"] | ||
286 | date = datetime.datetime.strptime(market_info["date"], "%Y-%m-%dT%H:%M:%S.%f") | ||
287 | market_id = market_info["market_id"] | ||
288 | return config, user, date, market_id | ||
289 | |||
290 | def requests_by_market(self): | ||
291 | r = { | ||
292 | None: [] | ||
293 | } | ||
294 | got_common = False | ||
295 | for user_id, market_id, http_requests, report_lines in self.reports.values(): | ||
296 | r[str(market_id)] = [] | ||
297 | for http_request in http_requests: | ||
298 | if http_request["market_id"] is None: | ||
299 | if not got_common: | ||
300 | r[None].append(http_request) | ||
301 | else: | ||
302 | r[str(market_id)].append(http_request) | ||
303 | got_common = True | ||
304 | return r | ||
305 | |||
306 | def setUp(self): | ||
307 | if not hasattr(self, "files"): | ||
308 | raise "This class expects to be inherited with a class defining 'files' variable" | ||
309 | if not hasattr(self, "log_files"): | ||
310 | self.log_files = [] | ||
311 | |||
312 | self.reports = {} | ||
313 | self.start_date = datetime.datetime.now() | ||
314 | self.config = [] | ||
315 | for f in self.files: | ||
316 | self.config, user_id, date, market_id, http_requests, report_lines = self.parse_file(f) | ||
317 | if date < self.start_date: | ||
318 | self.start_date = date | ||
319 | self.reports[f] = [user_id, market_id, http_requests, report_lines] | ||
320 | |||
321 | self.database_mock = DatabaseMock(self, self.reports, "--no-report-db" not in self.config) | ||
322 | self.requests_mock = RequestsMock(self) | ||
323 | self.file_mock = FileMock(self.log_files, "--quiet" in self.config, self) | ||
324 | self.global_variables_mock = GlobalVariablesMock() | ||
325 | |||
326 | self.database_mock.start() | ||
327 | self.requests_mock.start() | ||
328 | self.file_mock.start() | ||
329 | self.global_variables_mock.start() | ||
330 | TimeMock.start(self.start_date) | ||
331 | |||
332 | def base_test(self): | ||
333 | main.main(self.config) | ||
334 | self.requests_mock.check_calls() | ||
335 | self.database_mock.check_calls() | ||
336 | self.file_mock.check_calls() | ||
337 | |||
338 | def tearDown(self): | ||
339 | TimeMock.stop() | ||
340 | self.global_variables_mock.stop() | ||
341 | self.file_mock.stop() | ||
342 | self.requests_mock.stop() | ||
343 | self.database_mock.stop() | ||
344 | |||
345 | for dirfile in glob.glob("tests/acceptance/**/*/", recursive=True): | ||
346 | json_files = glob.glob("{}/*.json".format(dirfile)) | ||
347 | log_files = glob.glob("{}/*.log".format(dirfile)) | ||
348 | if len(json_files) > 0: | ||
349 | name = dirfile.replace("tests/acceptance/", "").replace("/", "_")[0:-1] | ||
350 | cname = "".join(list(map(lambda x: x.capitalize(), name.split("_")))) | ||
351 | |||
352 | globals()[cname] = type(cname, | ||
353 | (AcceptanceTestCase,unittest.TestCase), { | ||
354 | "log_files": log_files, | ||
355 | "files": json_files, | ||
356 | "test_{}".format(name): AcceptanceTestCase.base_test | ||
357 | }) | ||
358 | |||
359 | if __name__ == '__main__': | ||
360 | unittest.main() | ||
361 | |||