]> git.immae.eu Git - perso/Immae/Projets/Cryptomonnaies/Cryptoportfolio/Trader.git/blame - test_acceptance.py
Store duration in http requests
[perso/Immae/Projets/Cryptomonnaies/Cryptoportfolio/Trader.git] / test_acceptance.py
CommitLineData
1d72880c
IB
1import requests
2import requests_mock
3import sys, os
4import time, datetime
5import unittest
6from unittest import mock
7from ssl import SSLError
8from decimal import Decimal
9import simplejson as json
10import psycopg2
11from io import StringIO
a0dcf4e0
IB
12import re
13import functools
14import glob
15
16import main
1d72880c
IB
17
class FileMock:
    """Capture stdout and neutralise file-system calls during an acceptance test.

    Patches ``market.open`` and ``os.makedirs`` and replaces ``sys.stdout`` with
    a ``StringIO`` so the printed log can be compared with expected log files.

    Args (to ``__init__``):
        log_files: paths of expected log files; None or empty skips the log
            comparison.
        quiet: when true, nothing at all may be printed to stdout.
        tester: the running test case, used for ``assertEqual``/``fail``.
    """

    def __init__(self, log_files, quiet, tester):
        self.tester = tester
        self.log_files = []
        if log_files:
            self.read_log_files(log_files)
        self.quiet = quiet
        self.patches = [
            mock.patch("market.open"),
            mock.patch("os.makedirs"),
            # Keep last: start() exposes the last started mock as self.stdout.
            mock.patch("sys.stdout", new_callable=StringIO),
        ]
        self.mocks = []

    def start(self):
        """Start every patch and keep the replacement stdout in ``self.stdout``."""
        for patch in self.patches:
            self.mocks.append(patch.start())
        self.stdout = self.mocks[-1]

    def check_calls(self):
        """Assert that what was printed matches the expected log files.

        In quiet mode stdout must be empty. Otherwise, after timestamps are
        stripped, stdout must have as many lines as all expected log files
        together and must contain every expected line. Expected lines starting
        with ``"[Worker] "`` may be absent. Any other absent line fails the test.
        """
        # Original notes (translated from French), presumably further checks to add:
        # - the log file is written
        # - the log file is printed only when not quiet
        # - the report is written when relevant
        # - the report contains the right number of lines
        stdout = self.stdout.getvalue()
        if self.quiet:
            self.tester.assertEqual("", stdout)
            return
        if not self.log_files:
            return
        split_logs = self.strip_log(stdout).split("\n")
        self.tester.assertEqual(sum(len(f) for f in self.log_files), len(split_logs))
        for log_file in self.log_files:
            for line in log_file:
                try:
                    split_logs.remove(line)
                except ValueError:
                    # Handled per line: a tolerated missing worker line must not
                    # stop the remaining expected lines from being checked.
                    if not line.startswith("[Worker] "):
                        self.tester.fail("« {} » not found in log file {}".format(line, split_logs))

    def stop(self):
        """Stop the patches in reverse start order."""
        for patch in reversed(self.patches):
            patch.stop()
            self.mocks.pop()

    def strip_log(self, log):
        """Collapse doubled newlines and drop leading ``YYYY-MM-DD HH:MM:SS: `` timestamps."""
        log = log.replace("\n\n", "\n")
        return re.sub(r"^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}: ", "", log, flags=re.MULTILINE)

    def read_log_files(self, log_files):
        """Load each expected log file as a list of timestamp-stripped lines."""
        for log_file in log_files:
            with open(log_file, "r") as f:
                log = self.strip_log(f.read()).split("\n")
                # A trailing newline leaves an empty last element; drop it.
                if len(log[-1]) == 0:
                    log.pop()
                self.log_files.append(log)
74
class DatabaseMock:
    """Replace ``psycopg2.connect`` with a mock and check how it was used.

    Each report contributes one row ``(market_id, credentials, user_id)`` that
    the mocked cursor yields when iterated.
    """

    def __init__(self, tester, reports, report_db):
        self.tester = tester
        self.reports = reports
        self.report_db = report_db
        credentials = lambda: {"key": "key", "secret": "secret"}
        self.rows = [
            (market_id, credentials(), user_id)
            for user_id, market_id, _http_requests, _lines in reports.values()
        ]
        self.total_report_lines = sum(
            len(lines) for _user_id, _market_id, _http_requests, lines in reports.values()
        )
        self.requests = []

    def start(self):
        """Patch ``psycopg2.connect`` to hand out a cursor that records queries."""
        connection = mock.Mock()
        self.cursor = mock.MagicMock()
        connection.cursor.return_value = self.cursor

        def record_query(request, *params):
            self.requests.append(request)

        self.cursor.execute.side_effect = record_query
        self.cursor.__iter__.return_value = self.rows

        self.db_patch = mock.patch("psycopg2.connect")
        self.db_mock = self.db_patch.start()
        self.db_mock.return_value = connection

    def check_calls(self):
        """Assert the number of connections and executed statements."""
        if not self.report_db:
            expected_connects = expected_executes = 1
        else:
            expected_connects = 1 + len(self.rows)
            expected_executes = expected_connects + self.total_report_lines
        self.tester.assertEqual(expected_connects, self.db_mock.call_count)
        self.tester.assertEqual(expected_executes, self.cursor.execute.call_count)

    def stop(self):
        """Undo the ``psycopg2.connect`` patch."""
        self.db_patch.stop()
1d72880c
IB
110
class RequestsMock:
    """Replay the HTTP exchanges recorded in the reports through ``requests_mock``.

    Every recorded (method, url) pair is registered once. Its responses are
    queued per market id and consumed by the module-level ``callback``. Any
    other URL is logged in ``error_calls`` and rejected.
    """

    def __init__(self, tester):
        self.tester = tester
        self.reports = tester.requests_by_market()

        self.last_https = {}
        self.error_calls = []
        self.mocker = requests_mock.Mocker()

        def not_stubbed(*args):
            request = args[0]
            self.error_calls.append([request.method, request.url])
            raise requests_mock.exceptions.MockException("Not stubbed URL")

        self.mocker.register_uri(requests_mock.ANY, requests_mock.ANY, text=not_stubbed)

        # {(method, url): {market_id: [recorded element, ...]}}
        self.mocks = {}
        for market_id, recorded in self.reports.items():
            for element in recorded:
                endpoint = (element["method"], element["url"])
                by_market = self.mocks.setdefault(endpoint, {})
                by_market.setdefault(market_id, []).append(element)

        for endpoint, by_market in self.mocks.items():
            method, url = endpoint
            handler = functools.partial(callback, self, by_market)
            self.mocker.register_uri(method, url, text=handler, complete_qs=True)

    def start(self):
        """Activate the HTTP interception."""
        self.mocker.start()

    def stop(self):
        """Deactivate the HTTP interception."""
        self.mocker.stop()

    def check_calls(self):
        """Assert that no unexpected URL was hit and every queued response was used."""
        self.tester.assertEqual([], self.error_calls)
        for (method, url), by_market in self.mocks.items():
            for market_id, pending in by_market.items():
                message = "Missing calls to {} {}, market_id {}".format(method, url, market_id)
                self.tester.assertEqual(0, len(pending), message)

    def clean_body(self, body):
        """Return *body* as text with any ``nonce`` parameter removed (None stays None)."""
        if body is None:
            return None
        text = body.decode() if isinstance(body, bytes) else body
        for pattern in (r"&nonce=\d*$", r"nonce=\d*&?"):
            text = re.sub(pattern, "", text)
        return text
159
a0dcf4e0
IB
160def callback(self, elements, request, context):
161 try:
162 element = elements[request.headers.get("X-market-id")].pop(0)
163 except (IndexError, KeyError):
164 self.error_calls.append([request.method, request.url, request.headers.get("X-market-id")])
165 raise RuntimeError("Unexpected call")
166 if element["response"] is None and element["response_same_as"] is not None:
167 element["response"] = self.last_https[element["response_same_as"]]
168 elif element["response"] is not None:
169 self.last_https[element["date"]] = element["response"]
170
171 assert self.clean_body(request.body) == \
172 self.clean_body(element["body"]), "Body does not match"
173 context.status_code = element["status"]
174 if "error" in element:
175 if element["error"] == "SSLError":
176 raise SSLError(element["error_message"])
177 else:
178 raise getattr(requests.exceptions, element["error"])(element["error_message"])
179 return element["response"]
180
class GlobalVariablesMock:
    """Give ``market.Portfolio`` fresh class-level state for the duration of a test."""

    def start(self):
        """Patch the shared ``market.Portfolio`` attributes with fresh values."""
        import market
        import store

        self.patchers = [
            mock.patch.multiple(market.Portfolio,
                    data=store.LockedVar(None),
                    liquidities=store.LockedVar({}),
                    last_date=store.LockedVar(None),
                    report=store.LockedVar(store.ReportStore(None, no_http_dup=True)),
                    worker=None,
                    worker_tag="",
                    worker_notify=None,
                    worker_started=False,
                    callback=None)
        ]
        for patcher in self.patchers:
            patcher.start()

    def stop(self):
        """Undo the patches applied by start(), most recent first.

        Restores the original ``market.Portfolio`` attributes so state does not
        leak past the test. Safe to call when start() never ran or when the
        patches were already undone.
        """
        for patcher in reversed(getattr(self, "patchers", [])):
            patcher.stop()
        self.patchers = []
203
1d72880c
IB
204
class TimeMock:
    """Shift the clock so the code under test runs as if it started at a past date.

    ``time.time``, ``time.sleep`` and ``datetime.datetime`` are patched.
    Reported time is the real time minus ``delta`` seconds. Sleeping advances
    the simulated clock by the requested duration but only really waits 0.2 s.
    """
    delta = 0
    true_time = time.time
    true_sleep = time.sleep
    time_patch = None
    datetime_patch = None

    @classmethod
    def start(cls, start_date):
        """Install the patches so that "now" is *start_date*."""
        # Offset, in seconds, between the real now and the simulated start.
        cls.delta = (datetime.datetime.now() - start_date).total_seconds()

        class fake_datetime(datetime.datetime):
            @classmethod
            def now(cls, tz=None):
                timestamp = time.time()
                if tz is not None:
                    naive_utc = cls.utcfromtimestamp(timestamp)
                    return tz.fromutc(naive_utc.replace(tzinfo=tz))
                return cls.fromtimestamp(timestamp)

        cls.time_patch = mock.patch.multiple(time, time=cls.fake_time, sleep=cls.fake_sleep)
        cls.datetime_patch = mock.patch.multiple(datetime, datetime=fake_datetime)
        for patcher in (cls.time_patch, cls.datetime_patch):
            patcher.start()

    @classmethod
    def stop(cls):
        """Remove the patches and reset the offset."""
        cls.delta = 0
        for patcher in (cls.datetime_patch, cls.time_patch):
            patcher.stop()

    @classmethod
    def fake_time(cls):
        """Replacement for ``time.time``: real time shifted back by ``delta``."""
        return cls.true_time() - cls.delta

    @classmethod
    def fake_sleep(cls, duration):
        """Replacement for ``time.sleep``: advance the fake clock, wait only briefly."""
        cls.delta -= duration
        cls.true_sleep(0.2)
1d72880c
IB
243
class AcceptanceTestCase:
    """Mixin that replays recorded JSON reports through ``main.main``.

    Subclasses must define ``files`` (paths of JSON reports) and may define
    ``log_files`` (paths of expected log output).
    """

    def parse_file(self, report_file):
        """Load a JSON report.

        Returns:
            (config, user_id, date, market_id, http_requests, json_content)
        """
        with open(report_file, "rb") as f:
            json_content = json.load(f, parse_float=Decimal)
        config, user, date, market_id = self.parse_config(json_content)
        http_requests = self.parse_requests(json_content)

        return config, user, date, market_id, http_requests, json_content

    def parse_requests(self, json_content):
        """Return the ``http_request`` elements of a report, in order."""
        return [element for element in json_content if element["type"] == "http_request"]

    @staticmethod
    def _parse_date(value):
        """Parse a report timestamp, with or without fractional seconds."""
        try:
            return datetime.datetime.strptime(value, "%Y-%m-%dT%H:%M:%S.%f")
        except ValueError:
            # datetime.isoformat() emits no fraction when microsecond == 0.
            return datetime.datetime.strptime(value, "%Y-%m-%dT%H:%M:%S")

    def parse_config(self, json_content):
        """Rebuild command-line arguments from the report's last ``market`` element.

        Returns:
            (config, user_id, date, market_id)
        """
        market_info = None
        for element in json_content:
            if element["type"] == "market":
                market_info = element
        assert market_info is not None, "Couldn't find market element"

        args = market_info["args"]
        config = []
        for arg in ["before", "after", "quiet", "debug"]:
            if args.get(arg, False):
                config.append("--{}".format(arg))
        for arg in ["parallel", "report_db"]:
            if not args.get(arg, False):
                config.append("--no-{}".format(arg.replace("_", "-")))
        for action in (args.get("action", []) or []):
            config.extend(["--action", action])
        if args.get("report_path") is not None:
            config.extend(["--report-path", args.get("report_path")])
        if args.get("user") is not None:
            config.extend(["--user", args.get("user")])
        config.extend(["--config", ""])

        user = market_info["user_id"]
        date = self._parse_date(market_info["date"])
        market_id = market_info["market_id"]
        return config, user, date, market_id

    def requests_by_market(self):
        """Group the recorded HTTP requests by market id (as a string).

        Requests whose ``market_id`` is None are shared between markets. They
        are stored under the ``None`` key, and only the first report's copies
        are kept.
        """
        r = {
            None: []
        }
        got_common = False
        for _user_id, market_id, http_requests, _report_lines in self.reports.values():
            r[str(market_id)] = []
            for http_request in http_requests:
                if http_request["market_id"] is None:
                    if not got_common:
                        r[None].append(http_request)
                else:
                    r[str(market_id)].append(http_request)
            got_common = True
        return r

    def setUp(self):
        """Parse the reports, then start all mocks and the shifted clock."""
        if not hasattr(self, "files"):
            # `raise "<str>"` is itself a TypeError in Python 3 that hides this
            # message; keep the exception type but make the message visible.
            raise TypeError("This class expects to be inherited with a class defining 'files' variable")
        if not hasattr(self, "log_files"):
            self.log_files = []

        self.reports = {}
        self.start_date = datetime.datetime.now()
        self.config = []
        for f in self.files:
            # The config of the last parsed report is the one used.
            self.config, user_id, date, market_id, http_requests, report_lines = self.parse_file(f)
            if date < self.start_date:
                self.start_date = date
            self.reports[f] = [user_id, market_id, http_requests, report_lines]

        self.database_mock = DatabaseMock(self, self.reports, "--no-report-db" not in self.config)
        self.requests_mock = RequestsMock(self)
        self.file_mock = FileMock(self.log_files, "--quiet" in self.config, self)
        self.global_variables_mock = GlobalVariablesMock()

        self.database_mock.start()
        self.requests_mock.start()
        self.file_mock.start()
        self.global_variables_mock.start()
        TimeMock.start(self.start_date)

    def base_test(self):
        """Run ``main.main`` with the rebuilt config and check every mock."""
        main.main(self.config)
        self.requests_mock.check_calls()
        self.database_mock.check_calls()
        self.file_mock.check_calls()

    def tearDown(self):
        """Stop the clock shift and the mocks."""
        TimeMock.stop()
        self.global_variables_mock.stop()
        self.file_mock.stop()
        self.requests_mock.stop()
        self.database_mock.stop()
1d72880c 344
1d72880c
IB
# Generate one unittest.TestCase subclass per directory under tests/acceptance
# that contains JSON reports. Each class gets a single test method replaying them.
for dirfile in glob.glob("tests/acceptance/**/*/", recursive=True):
    # glob() returns entries in arbitrary file-system order. setUp() keeps the
    # config of the last report and requests_by_market() keeps the shared
    # requests of the first one, so sort to make the outcome deterministic.
    json_files = sorted(glob.glob("{}/*.json".format(dirfile)))
    log_files = sorted(glob.glob("{}/*.log".format(dirfile)))
    if len(json_files) > 0:
        # e.g. "tests/acceptance/foo/bar_baz/" -> "foo_bar_baz" -> "FooBarBaz"
        name = dirfile.replace("tests/acceptance/", "").replace("/", "_")[0:-1]
        cname = "".join(part.capitalize() for part in name.split("_"))

        globals()[cname] = type(cname,
                (AcceptanceTestCase, unittest.TestCase), {
                    "log_files": log_files,
                    "files": json_files,
                    "test_{}".format(name): AcceptanceTestCase.base_test
                    })

if __name__ == '__main__':
    unittest.main()
361