]> git.immae.eu Git - perso/Immae/Projets/Cryptomonnaies/Cryptoportfolio/Trader.git/blob - test_acceptance.py
Store duration in http requests
[perso/Immae/Projets/Cryptomonnaies/Cryptoportfolio/Trader.git] / test_acceptance.py
1 import requests
2 import requests_mock
3 import sys, os
4 import time, datetime
5 import unittest
6 from unittest import mock
7 from ssl import SSLError
8 from decimal import Decimal
9 import simplejson as json
10 import psycopg2
11 from io import StringIO
12 import re
13 import functools
14 import glob
15
16 import main
17
class FileMock:
    """Patch the file-system/stdout side effects of ``main`` and check its log output.

    ``market.open`` and ``os.makedirs`` are replaced by mocks and
    ``sys.stdout`` is captured in a ``StringIO`` so the printed log can be
    compared with the expected ``*.log`` files.
    """

    def __init__(self, log_files, quiet, tester):
        """
        :param log_files: paths of the expected log files (may be None or empty)
        :param quiet: True when the run is expected to print nothing
        :param tester: the running TestCase, used for assertions
        """
        self.tester = tester
        self.log_files = []
        if log_files is not None and len(log_files) > 0:
            self.read_log_files(log_files)
        self.quiet = quiet
        self.patches = [
            mock.patch("market.open"),
            mock.patch("os.makedirs"),
            # Must stay last: start() takes the last started mock as stdout.
            mock.patch("sys.stdout", new_callable=StringIO),
        ]
        self.mocks = []

    def start(self):
        """Start every patch and keep the captured stdout as ``self.stdout``."""
        for patch in self.patches:
            self.mocks.append(patch.start())
        self.stdout = self.mocks[-1]

    def check_calls(self):
        """Assert that the captured stdout matches the expected log files.

        In quiet mode nothing may have been printed. Otherwise, when expected
        log files were given, the number of printed lines must equal the number
        of expected lines and every expected line must appear (in any order).
        Expected lines starting with ``[Worker] `` may be missing; each missing
        one is skipped on its own so the remaining lines are still checked.
        """
        stdout = self.stdout.getvalue()
        if self.quiet:
            self.tester.assertEqual("", stdout)
            return
        if len(self.log_files) == 0:
            return
        split_logs = self.strip_log(stdout).split("\n")
        self.tester.assertEqual(sum(len(f) for f in self.log_files), len(split_logs))
        for log_file in self.log_files:
            for line in log_file:
                try:
                    split_logs.remove(line)
                except ValueError:
                    # Only worker lines are tolerated when missing; keep
                    # checking the following lines either way.
                    if not line.startswith("[Worker] "):
                        self.tester.fail("« {} » not found in log file {}".format(line, split_logs))
        # Not checked yet:
        # - the log file is written;
        # - the log file is printed only when not quiet;
        # - the report is written when relevant;
        # - the report contains the right number of lines.

    def stop(self):
        """Stop the patches in the reverse order of start()."""
        for patch in self.patches[::-1]:
            patch.stop()
            self.mocks.pop()

    def strip_log(self, log):
        """Normalize a log for comparison.

        Replaces each pair of consecutive newlines with a single one and drops
        the ``YYYY-MM-DD HH:MM:SS: `` prefix at the start of every line.
        """
        log = log.replace("\n\n", "\n")
        return re.sub(r"^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}: ", "", log, flags=re.MULTILINE)

    def read_log_files(self, log_files):
        """Append each expected log file to ``self.log_files`` as a list of stripped lines.

        The empty last element produced by a trailing newline is dropped.
        """
        for log_file in log_files:
            with open(log_file, "r") as f:
                log = self.strip_log(f.read()).split("\n")
            if len(log[-1]) == 0:
                log.pop()
            self.log_files.append(log)
74
class DatabaseMock:
    """Replace ``psycopg2.connect`` and count the queries issued by ``main``.

    The fake cursor iterates over one row ``(market_id, credentials, user_id)``
    per report. Every query string passed to ``execute`` is recorded in
    ``self.requests``.
    """

    def __init__(self, tester, reports, report_db):
        """
        :param tester: the running TestCase, used for assertions
        :param reports: mapping of report file -> [user_id, market_id, http_requests, report_lines]
        :param report_db: whether report lines are expected to be stored in the database
        """
        self.tester = tester
        self.reports = reports
        self.report_db = report_db
        self.rows = [
            (market_id, {"key": "key", "secret": "secret"}, user_id)
            for user_id, market_id, _requests, _lines in reports.values()
        ]
        self.total_report_lines = sum(
            len(lines) for _user, _market, _requests, lines in reports.values()
        )
        self.requests = []

    def start(self):
        """Patch ``psycopg2.connect`` with a connection serving the fake rows."""
        cursor = mock.MagicMock()
        self.cursor = cursor
        connection = mock.Mock()
        connection.cursor.return_value = cursor

        def record_query(request, *args):
            self.requests.append(request)

        cursor.execute.side_effect = record_query
        cursor.__iter__.return_value = self.rows

        self.db_patch = mock.patch("psycopg2.connect")
        self.db_mock = self.db_patch.start()
        self.db_mock.return_value = connection

    def check_calls(self):
        """Assert the number of connections and executed queries.

        With report_db: one connection plus one per row, and as many
        ``execute`` calls plus one per report line. Without: exactly one of each.
        """
        if not self.report_db:
            expected_connects = expected_executes = 1
        else:
            expected_connects = 1 + len(self.rows)
            expected_executes = expected_connects + self.total_report_lines
        self.tester.assertEqual(expected_connects, self.db_mock.call_count)
        self.tester.assertEqual(expected_executes, self.cursor.execute.call_count)

    def stop(self):
        """Undo the ``psycopg2.connect`` patch."""
        self.db_patch.stop()
110
class RequestsMock:
    """Serve recorded HTTP exchanges through ``requests_mock``.

    Recorded requests (from ``tester.requests_by_market()``) are indexed by
    ``(method, url)`` and then by market id. Each matching call is answered by
    the module-level ``callback``, which consumes the next recorded element.
    Calls to URLs without a recording are logged in ``error_calls`` and rejected.
    """

    def __init__(self, tester):
        self.tester = tester
        self.reports = tester.requests_by_market()

        self.last_https = {}
        self.error_calls = []
        self.mocker = requests_mock.Mocker()

        def not_stubbed(*args):
            request = args[0]
            self.error_calls.append([request.method, request.url])
            raise requests_mock.exceptions.MockException("Not stubbed URL")

        # Catch-all handler, registered before the specific URLs below
        # (presumably relying on requests_mock preferring later registrations).
        self.mocker.register_uri(requests_mock.ANY, requests_mock.ANY, text=not_stubbed)

        self.mocks = {}
        for market_id, recorded in self.reports.items():
            for element in recorded:
                key = (element["method"], element["url"])
                by_market = self.mocks.setdefault(key, {})
                by_market.setdefault(market_id, []).append(element)

        for (method, url), by_market in self.mocks.items():
            handler = functools.partial(callback, self, by_market)
            self.mocker.register_uri(method, url, text=handler, complete_qs=True)

    def start(self):
        """Start intercepting HTTP requests."""
        self.mocker.start()

    def stop(self):
        """Stop intercepting HTTP requests."""
        self.mocker.stop()

    def check_calls(self):
        """Assert no unexpected call happened and every recorded call was consumed."""
        self.tester.assertEqual([], self.error_calls)
        for (method, url), by_market in self.mocks.items():
            for market_id, pending in by_market.items():
                self.tester.assertEqual(0, len(pending), "Missing calls to {} {}, market_id {}".format(method, url, market_id))

    def clean_body(self, body):
        """Return the request body as text with its ``nonce`` parameter removed (None stays None)."""
        if body is None:
            return None
        text = body.decode() if isinstance(body, bytes) else body
        text = re.sub(r"&nonce=\d*$", "", text)
        return re.sub(r"nonce=\d*&?", "", text)
159
def callback(self, elements, request, context):
    """Answer one intercepted HTTP request from the recorded exchanges.

    Bound with ``functools.partial`` and registered as a ``requests_mock``
    text callback by RequestsMock.

    :param self: the owning RequestsMock (uses ``last_https``, ``error_calls``
        and ``clean_body``)
    :param elements: pending recorded elements for this (method, url), keyed
        by the ``X-market-id`` request header; the first one is consumed
    :param request: the intercepted request
    :param context: the response context; its ``status_code`` is set
    :returns: the recorded response text (possibly copied from an earlier
        response referenced by ``response_same_as``)
    :raises RuntimeError: when no recorded element is pending for this market
    :raises AssertionError: when the request body differs from the recording
    """
    market_header = request.headers.get("X-market-id")
    try:
        element = elements[market_header].pop(0)
    except (IndexError, KeyError):
        self.error_calls.append([request.method, request.url, market_header])
        raise RuntimeError("Unexpected call")

    response = element["response"]
    if response is not None:
        # Remember this response so later elements can refer to it by date.
        self.last_https[element["date"]] = response
    elif element["response_same_as"] is not None:
        element["response"] = self.last_https[element["response_same_as"]]

    assert self.clean_body(request.body) == self.clean_body(element["body"]), "Body does not match"
    context.status_code = element["status"]
    if "error" in element:
        error_name = element["error"]
        if error_name == "SSLError":
            raise SSLError(element["error_message"])
        raise getattr(requests.exceptions, error_name)(element["error_message"])
    return element["response"]
180
class GlobalVariablesMock:
    """Give ``market.Portfolio`` fresh class-level state for the duration of a test."""

    def start(self):
        """Patch ``market.Portfolio``'s shared class attributes with fresh values."""
        import market
        import store

        self.patchers = [
            mock.patch.multiple(
                market.Portfolio,
                data=store.LockedVar(None),
                liquidities=store.LockedVar({}),
                last_date=store.LockedVar(None),
                report=store.LockedVar(store.ReportStore(None, no_http_dup=True)),
                worker=None,
                worker_tag="",
                worker_notify=None,
                worker_started=False,
                callback=None,
            ),
        ]
        for patcher in self.patchers:
            patcher.start()

    def stop(self):
        """Undo the patches applied by start(), most recent first.

        Without this the patches would stay active after the test and stack
        up across tests. Safe to call when start() was never called.
        """
        for patcher in reversed(getattr(self, "patchers", [])):
            patcher.stop()
        self.patchers = []
203
204
class TimeMock:
    """Shift ``time`` and ``datetime`` so the test runs at a recorded date.

    After ``start(start_date)``, ``time.time()`` and ``datetime.datetime.now()``
    report a clock that read ``start_date`` at the moment start() was called.
    ``time.sleep(d)`` moves that clock forward by ``d`` while only sleeping
    0.2 s of real time.
    """

    # Seconds subtracted from the real clock to produce the fake one.
    delta = 0
    # Real implementations, captured before any patching happens.
    true_time = time.time
    true_sleep = time.sleep
    time_patch = None
    datetime_patch = None

    @classmethod
    def start(cls, start_date):
        """Patch ``time.time``, ``time.sleep`` and ``datetime.datetime``."""
        cls.delta = (datetime.datetime.now() - start_date).total_seconds()

        class fake_datetime(datetime.datetime):
            @classmethod
            def now(cls, tz=None):
                current = time.time()
                if tz is not None:
                    return tz.fromutc(cls.utcfromtimestamp(current).replace(tzinfo=tz))
                return cls.fromtimestamp(current)

        time_patch = mock.patch.multiple(time, time=cls.fake_time, sleep=cls.fake_sleep)
        datetime_patch = mock.patch.multiple(datetime, datetime=fake_datetime)
        cls.time_patch, cls.datetime_patch = time_patch, datetime_patch
        time_patch.start()
        datetime_patch.start()

    @classmethod
    def stop(cls):
        """Reset the offset and undo the patches in reverse order."""
        cls.delta = 0
        cls.datetime_patch.stop()
        cls.time_patch.stop()

    @classmethod
    def fake_time(cls):
        """Replacement for ``time.time``: the real time shifted by ``delta``."""
        real_now = cls.true_time()
        return real_now - cls.delta

    @classmethod
    def fake_sleep(cls, duration):
        """Replacement for ``time.sleep``: advance the fake clock, barely sleep."""
        cls.delta = cls.delta - duration
        # A short real pause, presumably to let other threads progress.
        cls.true_sleep(0.2)
243
class AcceptanceTestCase():
    """Mixin that replays a recorded run of ``main`` under fully mocked I/O.

    Subclasses (generated at module level together with ``unittest.TestCase``)
    must define ``files``, the JSON reports of a recorded run. They may also
    define ``log_files``, the expected log output. ``setUp`` parses the
    reports, installs the mocks and moves the clock to the earliest recorded
    date (never later than now). ``base_test`` runs ``main`` and checks each mock.
    """

    def parse_file(self, report_file):
        """Load one JSON report (floats parsed as Decimal).

        :returns: (config, user_id, date, market_id, http_requests, json_content)
        """
        with open(report_file, "rb") as f:
            json_content = json.load(f, parse_float=Decimal)
        config, user, date, market_id = self.parse_config(json_content)
        http_requests = self.parse_requests(json_content)

        return config, user, date, market_id, http_requests, json_content

    def parse_requests(self, json_content):
        """Return the report elements of type ``http_request``, in order."""
        return [element for element in json_content if element["type"] == "http_request"]

    def parse_config(self, json_content):
        """Rebuild the command line of the recorded run from its ``market`` element.

        The last ``market`` element wins.

        :returns: (argv list for ``main.main``, user_id, start datetime, market_id)
        :raises AssertionError: when the report has no ``market`` element
        """
        market_info = None
        for element in json_content:
            if element["type"] != "market":
                continue
            market_info = element
        assert market_info is not None, "Couldn't find market element"

        args = market_info["args"]
        config = []
        for arg in ["before", "after", "quiet", "debug"]:
            if args.get(arg, False):
                config.append("--{}".format(arg))
        # These options default to on, so only their negation is passed.
        for arg in ["parallel", "report_db"]:
            if not args.get(arg, False):
                config.append("--no-{}".format(arg.replace("_", "-")))
        for action in (args.get("action", []) or []):
            config.extend(["--action", action])
        if args.get("report_path") is not None:
            config.extend(["--report-path", args.get("report_path")])
        if args.get("user") is not None:
            config.extend(["--user", args.get("user")])
        config.extend(["--config", ""])

        user = market_info["user_id"]
        date = datetime.datetime.strptime(market_info["date"], "%Y-%m-%dT%H:%M:%S.%f")
        market_id = market_info["market_id"]
        return config, user, date, market_id

    def requests_by_market(self):
        """Group the recorded HTTP requests by market id (as a string).

        Requests with ``market_id`` None are shared by all markets. They are
        taken from the first report only and stored under the key None.
        """
        r = {
            None: []
        }
        got_common = False
        for user_id, market_id, http_requests, report_lines in self.reports.values():
            r[str(market_id)] = []
            for http_request in http_requests:
                if http_request["market_id"] is None:
                    if not got_common:
                        r[None].append(http_request)
                else:
                    r[str(market_id)].append(http_request)
            got_common = True
        return r

    def setUp(self):
        """Parse the reports and install every mock.

        :raises TypeError: when the subclass does not define ``files``
        """
        if not hasattr(self, "files"):
            # Raising a bare string is itself a TypeError in Python 3 and hid
            # this message; raise a real exception carrying it instead.
            raise TypeError("This class expects to be inherited with a class defining 'files' variable")
        if not hasattr(self, "log_files"):
            self.log_files = []

        self.reports = {}
        self.start_date = datetime.datetime.now()
        self.config = []
        for f in self.files:
            # The command line of the last report is the one used.
            self.config, user_id, date, market_id, http_requests, report_lines = self.parse_file(f)
            if date < self.start_date:
                self.start_date = date
            self.reports[f] = [user_id, market_id, http_requests, report_lines]

        self.database_mock = DatabaseMock(self, self.reports, "--no-report-db" not in self.config)
        self.requests_mock = RequestsMock(self)
        self.file_mock = FileMock(self.log_files, "--quiet" in self.config, self)
        self.global_variables_mock = GlobalVariablesMock()

        self.database_mock.start()
        self.requests_mock.start()
        self.file_mock.start()
        self.global_variables_mock.start()
        TimeMock.start(self.start_date)

    def base_test(self):
        """Replay the recorded run and check database, HTTP and log expectations."""
        main.main(self.config)
        self.requests_mock.check_calls()
        self.database_mock.check_calls()
        self.file_mock.check_calls()

    def tearDown(self):
        """Remove the mocks in the reverse order of setUp."""
        TimeMock.stop()
        self.global_variables_mock.stop()
        self.file_mock.stop()
        self.requests_mock.stop()
        self.database_mock.stop()
344
# Generate one TestCase class per directory under tests/acceptance that holds
# recorded JSON reports; the class name is the CamelCased relative path.
for dirfile in glob.glob("tests/acceptance/**/*/", recursive=True):
    # glob's ordering depends on the file system. Sort so that the report
    # whose command line is kept (the last one) and the one providing common
    # requests (the first one) are the same on every machine.
    json_files = sorted(glob.glob("{}/*.json".format(dirfile)))
    log_files = sorted(glob.glob("{}/*.log".format(dirfile)))
    if len(json_files) > 0:
        name = dirfile.replace("tests/acceptance/", "").replace("/", "_")[0:-1]
        cname = "".join(part.capitalize() for part in name.split("_"))

        globals()[cname] = type(cname,
                (AcceptanceTestCase, unittest.TestCase), {
                    "log_files": log_files,
                    "files": json_files,
                    "test_{}".format(name): AcceptanceTestCase.base_test
                    })

if __name__ == '__main__':
    unittest.main()
361