diff options
Diffstat (limited to 'tests/acceptance.py')
-rw-r--r-- | tests/acceptance.py | 358 |
1 files changed, 358 insertions, 0 deletions
diff --git a/tests/acceptance.py b/tests/acceptance.py new file mode 100644 index 0000000..66014ca --- /dev/null +++ b/tests/acceptance.py | |||
@@ -0,0 +1,358 @@ | |||
1 | import requests | ||
2 | import requests_mock | ||
3 | import sys, os | ||
4 | import time, datetime | ||
5 | from unittest import mock | ||
6 | from ssl import SSLError | ||
7 | from decimal import Decimal | ||
8 | import simplejson as json | ||
9 | import psycopg2 | ||
10 | from io import StringIO | ||
11 | import re | ||
12 | import functools | ||
13 | import threading | ||
14 | |||
15 | class TimeMock: | ||
16 | delta = {} | ||
17 | delta_init = 0 | ||
18 | true_time = time.time | ||
19 | true_sleep = time.sleep | ||
20 | time_patch = None | ||
21 | datetime_patch = None | ||
22 | |||
23 | @classmethod | ||
24 | def travel(cls, start_date): | ||
25 | cls.delta = {} | ||
26 | cls.delta_init = (datetime.datetime.now() - start_date).total_seconds() | ||
27 | |||
28 | @classmethod | ||
29 | def start(cls): | ||
30 | cls.delta = {} | ||
31 | cls.delta_init = 0 | ||
32 | |||
33 | class fake_datetime(datetime.datetime): | ||
34 | @classmethod | ||
35 | def now(cls, tz=None): | ||
36 | if tz is None: | ||
37 | return cls.fromtimestamp(time.time()) | ||
38 | else: | ||
39 | return tz.fromutc(cls.utcfromtimestamp(time.time()).replace(tzinfo=tz)) | ||
40 | |||
41 | cls.time_patch = mock.patch.multiple(time, time=cls.fake_time, sleep=cls.fake_sleep) | ||
42 | cls.datetime_patch = mock.patch.multiple(datetime, datetime=fake_datetime) | ||
43 | cls.time_patch.start() | ||
44 | cls.datetime_patch.start() | ||
45 | |||
46 | @classmethod | ||
47 | def stop(cls): | ||
48 | cls.delta = {} | ||
49 | cls.delta_init = 0 | ||
50 | |||
51 | @classmethod | ||
52 | def fake_time(cls): | ||
53 | cls.delta.setdefault(threading.current_thread(), cls.delta_init) | ||
54 | return cls.true_time() - cls.delta[threading.current_thread()] | ||
55 | |||
56 | @classmethod | ||
57 | def fake_sleep(cls, duration): | ||
58 | cls.delta.setdefault(threading.current_thread(), cls.delta_init) | ||
59 | cls.delta[threading.current_thread()] -= float(duration) | ||
60 | cls.true_sleep(min(float(duration), 0.1)) | ||
61 | |||
62 | TimeMock.start() | ||
63 | import main | ||
64 | |||
65 | class FileMock: | ||
66 | def __init__(self, log_files, quiet, tester): | ||
67 | self.tester = tester | ||
68 | self.log_files = [] | ||
69 | if log_files is not None and len(log_files) > 0: | ||
70 | self.read_log_files(log_files) | ||
71 | self.quiet = quiet | ||
72 | self.patches = [ | ||
73 | mock.patch("market.open"), | ||
74 | mock.patch("os.makedirs"), | ||
75 | mock.patch("sys.stdout", new_callable=StringIO), | ||
76 | ] | ||
77 | self.mocks = [] | ||
78 | |||
79 | def start(self): | ||
80 | for patch in self.patches: | ||
81 | self.mocks.append(patch.start()) | ||
82 | self.stdout = self.mocks[-1] | ||
83 | |||
84 | def check_calls(self): | ||
85 | stdout = self.stdout.getvalue() | ||
86 | if self.quiet: | ||
87 | self.tester.assertEqual("", stdout) | ||
88 | else: | ||
89 | log = self.strip_log(stdout) | ||
90 | if len(self.log_files) != 0: | ||
91 | split_logs = log.split("\n") | ||
92 | self.tester.assertEqual(sum(len(f) for f in self.log_files), len(split_logs)) | ||
93 | try: | ||
94 | for log_file in self.log_files: | ||
95 | for line in log_file: | ||
96 | split_logs.pop(split_logs.index(line)) | ||
97 | except ValueError: | ||
98 | if not line.startswith("[Worker] "): | ||
99 | self.tester.fail("« {} » not found in log file {}".format(line, split_logs)) | ||
100 | # Le fichier de log est écrit | ||
101 | # Le rapport est écrit si pertinent | ||
102 | # Le rapport contient le bon nombre de lignes | ||
103 | |||
104 | def stop(self): | ||
105 | for patch in self.patches[::-1]: | ||
106 | patch.stop() | ||
107 | self.mocks.pop() | ||
108 | |||
109 | def strip_log(self, log): | ||
110 | log = log.replace("\n\n", "\n") | ||
111 | return re.sub(r"^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}: ", "", log, flags=re.MULTILINE) | ||
112 | |||
113 | def read_log_files(self, log_files): | ||
114 | for log_file in log_files: | ||
115 | with open(log_file, "r") as f: | ||
116 | log = self.strip_log(f.read()).split("\n") | ||
117 | if len(log[-1]) == 0: | ||
118 | log.pop() | ||
119 | self.log_files.append(log) | ||
120 | |||
121 | class DatabaseMock: | ||
122 | def __init__(self, tester, reports, report_db): | ||
123 | self.tester = tester | ||
124 | self.reports = reports | ||
125 | self.report_db = report_db | ||
126 | self.rows = [] | ||
127 | self.total_report_lines= 0 | ||
128 | self.requests = [] | ||
129 | for user_id, market_id, http_requests, report_lines in self.reports.values(): | ||
130 | self.rows.append( (market_id, { "key": "key", "secret": "secret" }, user_id) ) | ||
131 | self.total_report_lines += len(report_lines) | ||
132 | |||
133 | def start(self): | ||
134 | connect_mock = mock.Mock() | ||
135 | self.cursor = mock.MagicMock() | ||
136 | connect_mock.cursor.return_value = self.cursor | ||
137 | def _execute(request, *args): | ||
138 | self.requests.append(request) | ||
139 | self.cursor.execute.side_effect = _execute | ||
140 | self.cursor.__iter__.return_value = self.rows | ||
141 | |||
142 | self.db_patch = mock.patch("psycopg2.connect") | ||
143 | self.db_mock = self.db_patch.start() | ||
144 | self.db_mock.return_value = connect_mock | ||
145 | |||
146 | def check_calls(self): | ||
147 | if self.report_db: | ||
148 | self.tester.assertEqual(1 + len(self.rows), self.db_mock.call_count) | ||
149 | self.tester.assertEqual(1 + len(self.rows) + self.total_report_lines, self.cursor.execute.call_count) | ||
150 | else: | ||
151 | self.tester.assertEqual(1, self.db_mock.call_count) | ||
152 | self.tester.assertEqual(1, self.cursor.execute.call_count) | ||
153 | |||
154 | def stop(self): | ||
155 | self.db_patch.stop() | ||
156 | |||
157 | class RequestsMock: | ||
158 | def __init__(self, tester): | ||
159 | self.tester = tester | ||
160 | self.reports = tester.requests_by_market() | ||
161 | |||
162 | self.last_https = {} | ||
163 | self.error_calls = [] | ||
164 | self.mocker = requests_mock.Mocker() | ||
165 | def not_stubbed(*args): | ||
166 | self.error_calls.append([args[0].method, args[0].url]) | ||
167 | raise requests_mock.exceptions.MockException("Not stubbed URL") | ||
168 | self.mocker.register_uri(requests_mock.ANY, requests_mock.ANY, | ||
169 | text=not_stubbed) | ||
170 | |||
171 | self.mocks = {} | ||
172 | |||
173 | for market_id, elements in self.reports.items(): | ||
174 | for element in elements: | ||
175 | method = element["method"] | ||
176 | url = element["url"] | ||
177 | self.mocks \ | ||
178 | .setdefault((method, url), {}) \ | ||
179 | .setdefault(market_id, []) \ | ||
180 | .append(element) | ||
181 | |||
182 | for ((method, url), elements) in self.mocks.items(): | ||
183 | self.mocker.register_uri(method, url, text=functools.partial(callback, self, elements), complete_qs=True) | ||
184 | |||
185 | def start(self): | ||
186 | self.mocker.start() | ||
187 | |||
188 | def stop(self): | ||
189 | self.mocker.stop() | ||
190 | |||
191 | lazy_calls = [ | ||
192 | "https://cryptoportfolio.io/wp-content/uploads/portfolio/json/cryptoportfolio.json", | ||
193 | "https://poloniex.com/public?command=returnTicker", | ||
194 | ] | ||
195 | def check_calls(self): | ||
196 | self.tester.assertEqual([], self.error_calls) | ||
197 | for (method, url), elements in self.mocks.items(): | ||
198 | for market_id, element in elements.items(): | ||
199 | if url not in self.lazy_calls: | ||
200 | self.tester.assertEqual(0, len(element), "Missing calls to {} {}, market_id {}".format(method, url, market_id)) | ||
201 | |||
202 | def clean_body(self, body): | ||
203 | if body is None: | ||
204 | return None | ||
205 | if isinstance(body, bytes): | ||
206 | body = body.decode() | ||
207 | body = re.sub(r"&nonce=\d*$", "", body) | ||
208 | body = re.sub(r"nonce=\d*&?", "", body) | ||
209 | return body | ||
210 | |||
211 | def callback(self, elements, request, context): | ||
212 | try: | ||
213 | element = elements[request.headers.get("X-market-id")].pop(0) | ||
214 | except (IndexError, KeyError): | ||
215 | self.error_calls.append([request.method, request.url, request.headers.get("X-market-id")]) | ||
216 | raise RuntimeError("Unexpected call") | ||
217 | if element["response"] is None and element["response_same_as"] is not None: | ||
218 | element["response"] = self.last_https[element["response_same_as"]] | ||
219 | elif element["response"] is not None: | ||
220 | self.last_https[element["date"]] = element["response"] | ||
221 | |||
222 | time.sleep(element.get("duration", 0)) | ||
223 | |||
224 | assert self.clean_body(request.body) == \ | ||
225 | self.clean_body(element["body"]), "Body does not match" | ||
226 | context.status_code = element["status"] | ||
227 | if "error" in element: | ||
228 | if element["error"] == "SSLError": | ||
229 | raise SSLError(element["error_message"]) | ||
230 | else: | ||
231 | raise getattr(requests.exceptions, element["error"])(element["error_message"]) | ||
232 | return element["response"] | ||
233 | |||
234 | class GlobalVariablesMock: | ||
235 | def start(self): | ||
236 | import market | ||
237 | import store | ||
238 | |||
239 | self.patchers = [ | ||
240 | mock.patch.multiple(market.Portfolio, | ||
241 | data=store.LockedVar(None), | ||
242 | liquidities=store.LockedVar({}), | ||
243 | last_date=store.LockedVar(None), | ||
244 | report=store.LockedVar(store.ReportStore(None, no_http_dup=True)), | ||
245 | worker=None, | ||
246 | worker_tag="", | ||
247 | worker_notify=None, | ||
248 | worker_started=False, | ||
249 | callback=None) | ||
250 | ] | ||
251 | for patcher in self.patchers: | ||
252 | patcher.start() | ||
253 | |||
254 | def stop(self): | ||
255 | pass | ||
256 | |||
257 | |||
258 | class AcceptanceTestCase(): | ||
259 | def parse_file(self, report_file): | ||
260 | with open(report_file, "rb") as f: | ||
261 | json_content = json.load(f, parse_float=Decimal) | ||
262 | config, user, date, market_id = self.parse_config(json_content) | ||
263 | http_requests = self.parse_requests(json_content) | ||
264 | |||
265 | return config, user, date, market_id, http_requests, json_content | ||
266 | |||
267 | def parse_requests(self, json_content): | ||
268 | http_requests = [] | ||
269 | for element in json_content: | ||
270 | if element["type"] != "http_request": | ||
271 | continue | ||
272 | http_requests.append(element) | ||
273 | return http_requests | ||
274 | |||
275 | def parse_config(self, json_content): | ||
276 | market_info = None | ||
277 | for element in json_content: | ||
278 | if element["type"] != "market": | ||
279 | continue | ||
280 | market_info = element | ||
281 | assert market_info is not None, "Couldn't find market element" | ||
282 | |||
283 | args = market_info["args"] | ||
284 | config = [] | ||
285 | for arg in ["before", "after", "quiet", "debug"]: | ||
286 | if args.get(arg, False): | ||
287 | config.append("--{}".format(arg)) | ||
288 | for arg in ["parallel", "report_db"]: | ||
289 | if not args.get(arg, False): | ||
290 | config.append("--no-{}".format(arg.replace("_", "-"))) | ||
291 | for action in (args.get("action", []) or []): | ||
292 | config.extend(["--action", action]) | ||
293 | if args.get("report_path") is not None: | ||
294 | config.extend(["--report-path", args.get("report_path")]) | ||
295 | if args.get("user") is not None: | ||
296 | config.extend(["--user", args.get("user")]) | ||
297 | config.extend(["--config", ""]) | ||
298 | |||
299 | user = market_info["user_id"] | ||
300 | date = datetime.datetime.strptime(market_info["date"], "%Y-%m-%dT%H:%M:%S.%f") | ||
301 | market_id = market_info["market_id"] | ||
302 | return config, user, date, market_id | ||
303 | |||
304 | def requests_by_market(self): | ||
305 | r = { | ||
306 | None: [] | ||
307 | } | ||
308 | got_common = False | ||
309 | for user_id, market_id, http_requests, report_lines in self.reports.values(): | ||
310 | r[str(market_id)] = [] | ||
311 | for http_request in http_requests: | ||
312 | if http_request["market_id"] is None: | ||
313 | if not got_common: | ||
314 | r[None].append(http_request) | ||
315 | else: | ||
316 | r[str(market_id)].append(http_request) | ||
317 | got_common = True | ||
318 | return r | ||
319 | |||
320 | def setUp(self): | ||
321 | if not hasattr(self, "files"): | ||
322 | raise "This class expects to be inherited with a class defining 'files' variable" | ||
323 | if not hasattr(self, "log_files"): | ||
324 | self.log_files = [] | ||
325 | |||
326 | self.reports = {} | ||
327 | self.start_date = datetime.datetime.now() | ||
328 | self.config = [] | ||
329 | for f in self.files: | ||
330 | self.config, user_id, date, market_id, http_requests, report_lines = self.parse_file(f) | ||
331 | if date < self.start_date: | ||
332 | self.start_date = date | ||
333 | self.reports[f] = [user_id, market_id, http_requests, report_lines] | ||
334 | |||
335 | self.database_mock = DatabaseMock(self, self.reports, "--no-report-db" not in self.config) | ||
336 | self.requests_mock = RequestsMock(self) | ||
337 | self.file_mock = FileMock(self.log_files, "--quiet" in self.config, self) | ||
338 | self.global_variables_mock = GlobalVariablesMock() | ||
339 | |||
340 | self.database_mock.start() | ||
341 | self.requests_mock.start() | ||
342 | self.file_mock.start() | ||
343 | self.global_variables_mock.start() | ||
344 | TimeMock.travel(self.start_date) | ||
345 | |||
346 | def base_test(self): | ||
347 | main.main(self.config) | ||
348 | self.requests_mock.check_calls() | ||
349 | self.database_mock.check_calls() | ||
350 | self.file_mock.check_calls() | ||
351 | |||
352 | def tearDown(self): | ||
353 | TimeMock.stop() | ||
354 | self.global_variables_mock.stop() | ||
355 | self.file_mock.stop() | ||
356 | self.requests_mock.stop() | ||
357 | self.database_mock.stop() | ||
358 | |||