]>
Commit | Line | Data |
---|---|---|
1d72880c IB |
1 | import requests |
2 | import requests_mock | |
3 | import sys, os | |
4 | import time, datetime | |
1d72880c IB |
5 | from unittest import mock |
6 | from ssl import SSLError | |
7 | from decimal import Decimal | |
8 | import simplejson as json | |
9 | import psycopg2 | |
10 | from io import StringIO | |
a0dcf4e0 IB |
11 | import re |
12 | import functools | |
3080f31d | 13 | import threading |
a0dcf4e0 | 14 | |
3080f31d IB |
class TimeMock:
    """Process-wide fake clock used by the acceptance tests.

    ``start()`` patches ``time.time``, ``time.sleep`` and ``datetime.datetime``
    so that each thread sees ``real time - delta[thread]``. ``fake_sleep``
    only waits for at most 0.1s of real time but moves the calling thread's
    clock forward by the full duration. ``travel()`` shifts the clock of every
    thread so that "now" corresponds to a given start date.
    """

    # Per-thread offset (in seconds) subtracted from the real clock.
    delta = {}
    # Offset a thread receives the first time it reads the clock after a reset.
    delta_init = 0
    # Real implementations, captured at class creation, before any patching.
    true_time = time.time
    true_sleep = time.sleep
    # Patchers created by start(); this class never stops them.
    time_patch = None
    datetime_patch = None

    @classmethod
    def travel(cls, start_date):
        """Make "now" equal to ``start_date`` for every thread.

        The offset is computed from the *real* clock. Reading
        ``datetime.datetime.now()`` here would go through the patched
        ``fake_time``, which caches the old offset for the calling thread,
        and that thread would then never travel.

        Args:
            start_date: naive datetime that the fake clock should start from.
        """
        real_now = datetime.datetime.fromtimestamp(cls.true_time())
        cls.delta_init = (real_now - start_date).total_seconds()
        # Reset after computing the offset so that every thread, including
        # this one, picks up the new delta_init on its next clock read.
        cls.delta = {}

    @classmethod
    def start(cls):
        """Reset the offsets and install the time/datetime patches."""
        cls.delta = {}
        cls.delta_init = 0

        class fake_datetime(datetime.datetime):
            # Route now() through time.time() so that it follows the
            # patched (fake) clock.
            @classmethod
            def now(cls, tz=None):
                if tz is None:
                    return cls.fromtimestamp(time.time())
                else:
                    return tz.fromutc(cls.utcfromtimestamp(time.time()).replace(tzinfo=tz))

        cls.time_patch = mock.patch.multiple(time, time=cls.fake_time, sleep=cls.fake_sleep)
        cls.datetime_patch = mock.patch.multiple(datetime, datetime=fake_datetime)
        cls.time_patch.start()
        cls.datetime_patch.start()

    @classmethod
    def stop(cls):
        """Reset all offsets to zero (the patches themselves stay installed)."""
        cls.delta = {}
        cls.delta_init = 0

    @classmethod
    def fake_time(cls):
        """Replacement for time.time(): real time minus this thread's offset."""
        cls.delta.setdefault(threading.current_thread(), cls.delta_init)
        return cls.true_time() - cls.delta[threading.current_thread()]

    @classmethod
    def fake_sleep(cls, duration):
        """Replacement for time.sleep(): advance this thread's clock by
        ``duration`` but block for at most 0.1s of real time."""
        cls.delta.setdefault(threading.current_thread(), cls.delta_init)
        cls.delta[threading.current_thread()] -= float(duration)
        cls.true_sleep(min(float(duration), 0.1))
61 | ||
# Install the fake clock before the ``import main`` that follows, presumably
# so that anything ``main`` binds at import time already sees the patched
# ``time``/``datetime``. Nothing in this file stops these patches;
# TimeMock.stop() only resets the offsets between tests.
TimeMock.start()
a0dcf4e0 | 63 | import main |
1d72880c IB |
64 | |
class FileMock:
    """Patch the file/stdout side effects of ``main`` and check its output.

    Args:
        log_files: paths of expected log files to compare stdout against;
            None or empty skips the comparison.
        quiet: when true, nothing at all may be printed to stdout.
        tester: object providing ``assertEqual`` and ``fail``.
    """

    def __init__(self, log_files, quiet, tester):
        self.tester = tester
        self.log_files = []
        if log_files:
            self.read_log_files(log_files)
        self.quiet = quiet
        # The stdout patch must stay last: start() takes the last mock as
        # self.stdout.
        self.patches = [
            mock.patch("market.open"),
            mock.patch("os.makedirs"),
            mock.patch("sys.stdout", new_callable=StringIO),
        ]
        self.mocks = []

    def start(self):
        """Start every patch and keep a handle on the captured stdout."""
        for patch in self.patches:
            self.mocks.append(patch.start())
        self.stdout = self.mocks[-1]

    def check_calls(self):
        """Compare what was printed to stdout with the expected log files.

        In quiet mode stdout must be empty. Otherwise, when expected log files
        were given, the stripped stdout must have as many lines as the logs
        and contain every expected line (in any order). Expected lines
        starting with "[Worker] " are tolerated when missing.
        """
        stdout = self.stdout.getvalue()
        if self.quiet:
            self.tester.assertEqual("", stdout)
            return
        if not self.log_files:
            return
        split_logs = self.strip_log(stdout).split("\n")
        self.tester.assertEqual(sum(len(f) for f in self.log_files), len(split_logs))
        for log_file in self.log_files:
            for line in log_file:
                # The try is per line: a tolerated missing worker line must
                # not stop the check of the remaining expected lines.
                try:
                    split_logs.remove(line)
                except ValueError:
                    if not line.startswith("[Worker] "):
                        self.tester.fail("« {} » not found in log file {}".format(line, split_logs))
        # NOTE(review): checks described by the author but not implemented:
        #   - the log file is written
        #   - the report is written when relevant
        #   - the report contains the right number of lines

    def stop(self):
        """Stop the patches in reverse order of start()."""
        for patch in self.patches[::-1]:
            patch.stop()
            self.mocks.pop()

    def strip_log(self, log):
        """Collapse doubled newlines (single pass) and remove leading
        ``YYYY-MM-DD HH:MM:SS: `` timestamps from every line."""
        log = log.replace("\n\n", "\n")
        return re.sub(r"^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}: ", "", log, flags=re.MULTILINE)

    def read_log_files(self, log_files):
        """Load each expected log file as a list of stripped lines, dropping
        the empty line produced by a trailing newline."""
        for log_file in log_files:
            with open(log_file, "r") as f:
                log = self.strip_log(f.read()).split("\n")
                if len(log[-1]) == 0:
                    log.pop()
                self.log_files.append(log)
120 | ||
class DatabaseMock:
    """Stand-in for ``psycopg2``: records executed SQL and serves one
    ``(market_id, config, user_id)`` row per report.

    Args:
        tester: object providing ``assertEqual``.
        reports: mapping of report -> ``[user_id, market_id, http_requests,
            report_lines]``.
        report_db: whether ``main`` is expected to write reports to the DB.
    """

    def __init__(self, tester, reports, report_db):
        self.tester = tester
        self.reports = reports
        self.report_db = report_db
        self.rows = [
            (market_id, {"key": "key", "secret": "secret"}, user_id)
            for user_id, market_id, _http_requests, _report_lines in reports.values()
        ]
        self.total_report_lines = sum(
            len(report_lines)
            for _user_id, _market_id, _http_requests, report_lines in reports.values()
        )
        self.requests = []

    def start(self):
        """Patch ``psycopg2.connect`` with a fake connection and cursor."""
        connection = mock.Mock()
        self.cursor = mock.MagicMock()
        connection.cursor.return_value = self.cursor

        def record_execute(request, *args):
            # Keep the SQL text only; extra arguments are ignored.
            self.requests.append(request)

        self.cursor.execute.side_effect = record_execute
        # Iterating the cursor yields the prepared rows.
        self.cursor.__iter__.return_value = self.rows

        self.db_patch = mock.patch("psycopg2.connect")
        self.db_mock = self.db_patch.start()
        self.db_mock.return_value = connection

    def check_calls(self):
        """Check the number of connections opened and queries executed."""
        if not self.report_db:
            self.tester.assertEqual(1, self.db_mock.call_count)
            self.tester.assertEqual(1, self.cursor.execute.call_count)
            return
        # One connection/query, plus one per row; report lines add queries.
        expected = 1 + len(self.rows)
        self.tester.assertEqual(expected, self.db_mock.call_count)
        self.tester.assertEqual(expected + self.total_report_lines, self.cursor.execute.call_count)

    def stop(self):
        """Remove the ``psycopg2.connect`` patch."""
        self.db_patch.stop()
1d72880c IB |
156 | |
class RequestsMock:
    """Replay recorded HTTP exchanges through ``requests_mock``.

    Recorded requests are registered per ``(method, url)`` and per market id.
    Any other URL is appended to ``error_calls`` and rejected.
    """

    # URLs whose recorded calls may legitimately remain unused, so leftovers
    # for them are not reported by check_calls().
    lazy_calls = [
        "https://cryptoportfolio.io/wp-content/uploads/portfolio/json/cryptoportfolio.json",
        "https://poloniex.com/public?command=returnTicker",
    ]

    def __init__(self, tester):
        self.tester = tester
        self.reports = tester.requests_by_market()

        self.last_https = {}
        self.error_calls = []
        self.mocker = requests_mock.Mocker()

        def not_stubbed(*args):
            request = args[0]
            self.error_calls.append([request.method, request.url])
            raise requests_mock.exceptions.MockException("Not stubbed URL")

        self.mocker.register_uri(requests_mock.ANY, requests_mock.ANY, text=not_stubbed)

        # (method, url) -> market_id -> pending recorded exchanges
        self.mocks = {}
        for market_id, recorded in self.reports.items():
            for exchange in recorded:
                key = (exchange["method"], exchange["url"])
                per_market = self.mocks.setdefault(key, {})
                per_market.setdefault(market_id, []).append(exchange)

        for (method, url), per_market in self.mocks.items():
            handler = functools.partial(callback, self, per_market)
            self.mocker.register_uri(method, url, text=handler, complete_qs=True)

    def start(self):
        """Start intercepting HTTP requests."""
        self.mocker.start()

    def stop(self):
        """Stop intercepting HTTP requests."""
        self.mocker.stop()

    def check_calls(self):
        """Assert that no unexpected call happened and that every recorded
        exchange (outside ``lazy_calls``) was consumed."""
        self.tester.assertEqual([], self.error_calls)
        for (method, url), per_market in self.mocks.items():
            if url in self.lazy_calls:
                continue
            for market_id, pending in per_market.items():
                self.tester.assertEqual(0, len(pending), "Missing calls to {} {}, market_id {}".format(method, url, market_id))

    def clean_body(self, body):
        """Normalize a request body for comparison: decode bytes and remove
        the ``nonce`` parameter."""
        if body is None:
            return None
        text = body.decode() if isinstance(body, bytes) else body
        text = re.sub(r"&nonce=\d*$", "", text)
        return re.sub(r"nonce=\d*&?", "", text)
210 | ||
a0dcf4e0 IB |
211 | def callback(self, elements, request, context): |
212 | try: | |
213 | element = elements[request.headers.get("X-market-id")].pop(0) | |
214 | except (IndexError, KeyError): | |
215 | self.error_calls.append([request.method, request.url, request.headers.get("X-market-id")]) | |
216 | raise RuntimeError("Unexpected call") | |
217 | if element["response"] is None and element["response_same_as"] is not None: | |
218 | element["response"] = self.last_https[element["response_same_as"]] | |
219 | elif element["response"] is not None: | |
220 | self.last_https[element["date"]] = element["response"] | |
221 | ||
3080f31d IB |
222 | time.sleep(element.get("duration", 0)) |
223 | ||
a0dcf4e0 IB |
224 | assert self.clean_body(request.body) == \ |
225 | self.clean_body(element["body"]), "Body does not match" | |
226 | context.status_code = element["status"] | |
227 | if "error" in element: | |
228 | if element["error"] == "SSLError": | |
229 | raise SSLError(element["error_message"]) | |
230 | else: | |
231 | raise getattr(requests.exceptions, element["error"])(element["error_message"]) | |
232 | return element["response"] | |
233 | ||
class GlobalVariablesMock:
    """Give ``market.Portfolio`` fresh class-level state for each test.

    ``start()`` replaces the shared Portfolio attributes with new empty
    values. ``stop()`` restores the originals so that patches do not pile up
    from one test to the next.
    """

    def __init__(self):
        self.patchers = []

    def start(self):
        """Patch the Portfolio class attributes with fresh, empty state."""
        import market
        import store

        self.patchers = [
            mock.patch.multiple(market.Portfolio,
                    data=store.LockedVar(None),
                    liquidities=store.LockedVar({}),
                    last_date=store.LockedVar(None),
                    report=store.LockedVar(store.ReportStore(None, no_http_dup=True)),
                    worker=None,
                    worker_tag="",
                    worker_notify=None,
                    worker_started=False,
                    callback=None)
        ]
        for patcher in self.patchers:
            patcher.start()

    def stop(self):
        """Stop the patchers created by start(), in reverse order.

        Without this, every test would add a new patch on top of the previous
        ones and the original Portfolio attributes would never be restored.
        This is a no-op if start() was not called, and safe to call twice.
        """
        while self.patchers:
            self.patchers.pop().stop()
256 | ||
1d72880c | 257 | |
1d72880c IB |
class AcceptanceTestCase:
    """Mixin that runs ``main.main`` against recorded report files.

    Subclasses must define ``files`` (report JSON paths) and may define
    ``log_files`` (expected stdout logs). ``self`` is passed to the mocks as
    the tester, so it must provide ``assertEqual`` and ``fail``, presumably by
    also inheriting from ``unittest.TestCase``.
    """

    def parse_file(self, report_file):
        """Load one JSON report file (floats parsed as Decimal).

        Returns ``(config, user, date, market_id, http_requests, json_content)``.
        """
        with open(report_file, "rb") as f:
            json_content = json.load(f, parse_float=Decimal)
        config, user, date, market_id = self.parse_config(json_content)
        http_requests = self.parse_requests(json_content)

        return config, user, date, market_id, http_requests, json_content

    def parse_requests(self, json_content):
        """Return the "http_request" elements of a report, in order."""
        return [element for element in json_content if element["type"] == "http_request"]

    def parse_config(self, json_content):
        """Rebuild the command-line arguments from the report's "market" element.

        If several "market" elements exist, the last one is used.
        Returns ``(config, user_id, date, market_id)``.
        """
        market_info = None
        for element in json_content:
            if element["type"] == "market":
                market_info = element
        assert market_info is not None, "Couldn't find market element"

        args = market_info["args"]
        config = []
        # Flags emitted only when enabled in the recorded run.
        for arg in ["before", "after", "quiet", "debug"]:
            if args.get(arg, False):
                config.append("--{}".format(arg))
        # Options emitted only in their --no- form when disabled (or absent).
        for arg in ["parallel", "report_db"]:
            if not args.get(arg, False):
                config.append("--no-{}".format(arg.replace("_", "-")))
        # "action" may be missing or explicitly null.
        for action in (args.get("action", []) or []):
            config.extend(["--action", action])
        if args.get("report_path") is not None:
            config.extend(["--report-path", args.get("report_path")])
        if args.get("user") is not None:
            config.extend(["--user", args.get("user")])
        config.extend(["--config", ""])

        user = market_info["user_id"]
        date = datetime.datetime.strptime(market_info["date"], "%Y-%m-%dT%H:%M:%S.%f")
        market_id = market_info["market_id"]
        return config, user, date, market_id

    def requests_by_market(self):
        """Group the recorded HTTP requests by market.

        Keys are ``str(market_id)`` for market-specific requests, and ``None``
        for requests without a market id. The latter are taken from the first
        report only.
        """
        r = {
            None: []
        }
        got_common = False
        for user_id, market_id, http_requests, report_lines in self.reports.values():
            r[str(market_id)] = []
            for http_request in http_requests:
                if http_request["market_id"] is None:
                    if not got_common:
                        r[None].append(http_request)
                else:
                    r[str(market_id)].append(http_request)
            got_common = True
        return r

    def setUp(self):
        """Parse the report files, start every mock and set the fake clock.

        Raises:
            NotImplementedError: if the subclass does not define ``files``.
        """
        if not hasattr(self, "files"):
            # Raising a bare string is itself a TypeError in Python 3, which
            # would hide this message.
            raise NotImplementedError("This class expects to be inherited with a class defining 'files' variable")
        if not hasattr(self, "log_files"):
            self.log_files = []

        self.reports = {}
        self.start_date = datetime.datetime.now()
        # Each file overwrites self.config: the last file's arguments are used.
        self.config = []
        for f in self.files:
            self.config, user_id, date, market_id, http_requests, report_lines = self.parse_file(f)
            # The fake clock starts at the earliest report date.
            if date < self.start_date:
                self.start_date = date
            self.reports[f] = [user_id, market_id, http_requests, report_lines]

        self.database_mock = DatabaseMock(self, self.reports, "--no-report-db" not in self.config)
        self.requests_mock = RequestsMock(self)
        self.file_mock = FileMock(self.log_files, "--quiet" in self.config, self)
        self.global_variables_mock = GlobalVariablesMock()

        self.database_mock.start()
        self.requests_mock.start()
        self.file_mock.start()
        self.global_variables_mock.start()
        TimeMock.travel(self.start_date)

    def base_test(self):
        """Run main with the parsed arguments, then verify every mock."""
        main.main(self.config)
        self.requests_mock.check_calls()
        self.database_mock.check_calls()
        self.file_mock.check_calls()

    def tearDown(self):
        """Reset the fake clock and stop the mocks started in setUp()."""
        TimeMock.stop()
        self.global_variables_mock.stop()
        self.file_mock.stop()
        self.requests_mock.stop()
        self.database_mock.stop()
1d72880c | 358 |