]> git.immae.eu Git - perso/Immae/Projets/Cryptomonnaies/Cryptoportfolio/Trader.git/blame - tests/acceptance.py
Merge branch 'dev'
[perso/Immae/Projets/Cryptomonnaies/Cryptoportfolio/Trader.git] / tests / acceptance.py
CommitLineData
1d72880c
IB
1import requests
2import requests_mock
3import sys, os
4import time, datetime
1d72880c
IB
5from unittest import mock
6from ssl import SSLError
7from decimal import Decimal
8import simplejson as json
9import psycopg2
10from io import StringIO
a0dcf4e0
IB
11import re
12import functools
3080f31d 13import threading
a0dcf4e0 14
3080f31d
IB
class TimeMock:
    """Fake clock used in place of ``time.time``, ``time.sleep`` and
    ``datetime.datetime.now``.

    Each thread has its own offset, in seconds, which is subtracted from the
    real clock. A faked sleep moves that offset by the requested duration and
    really sleeps for at most 0.1 second.
    """

    # Offset in seconds for each thread, keyed by thread object.
    delta = {}
    # Offset handed to a thread the first time it reads or sleeps.
    delta_init = 0
    # Real implementations, captured when this class body is executed.
    true_time = time.time
    true_sleep = time.sleep
    # Patcher objects, created by start().
    time_patch = None
    datetime_patch = None

    @classmethod
    def travel(cls, start_date):
        """Forget all per-thread offsets and set the initial offset to the
        distance between ``datetime.datetime.now()`` and ``start_date``.
        """
        # NOTE(review): the order of these two statements matters once the
        # patches are active. now() then goes through fake_time(), which
        # registers the calling thread with the *previous* delta_init right
        # after the reset -- confirm this is intended before reordering.
        cls.delta = {}
        elapsed = datetime.datetime.now() - start_date
        cls.delta_init = elapsed.total_seconds()

    @classmethod
    def start(cls):
        """Reset the offsets, then create and start the patches on the
        ``time`` and ``datetime`` modules.
        """
        cls.delta = {}
        cls.delta_init = 0

        class fake_datetime(datetime.datetime):
            # datetime.datetime whose now() is computed from time.time().
            @classmethod
            def now(cls, tz=None):
                stamp = time.time()
                if tz is not None:
                    return tz.fromutc(cls.utcfromtimestamp(stamp).replace(tzinfo=tz))
                return cls.fromtimestamp(stamp)

        time_patcher = mock.patch.multiple(time, time=cls.fake_time, sleep=cls.fake_sleep)
        datetime_patcher = mock.patch.multiple(datetime, datetime=fake_datetime)
        cls.time_patch = time_patcher
        cls.datetime_patch = datetime_patcher
        time_patcher.start()
        datetime_patcher.start()

    @classmethod
    def stop(cls):
        """Reset the offsets. The patches themselves are left in place."""
        cls.delta = {}
        cls.delta_init = 0

    @classmethod
    def fake_time(cls):
        """Return the real time minus the calling thread's offset."""
        me = threading.current_thread()
        offset = cls.delta.setdefault(me, cls.delta_init)
        return cls.true_time() - offset

    @classmethod
    def fake_sleep(cls, duration):
        """Advance the calling thread's clock by ``duration`` seconds while
        really sleeping for at most 0.1 second.
        """
        me = threading.current_thread()
        offset = cls.delta.setdefault(me, cls.delta_init)
        seconds = float(duration)
        # A smaller offset means a later fake time.
        cls.delta[me] = offset - seconds
        cls.true_sleep(min(seconds, 0.1))
61
# Install the time/datetime patches as soon as this module is imported; this
# runs before the `import main` statement that follows.
62TimeMock.start()
a0dcf4e0 63import main
1d72880c
IB
64
class FileMock:
    """Patch file creation and stdout, then compare stdout with expected logs.

    ``market.open`` and ``os.makedirs`` are replaced by mocks, and
    ``sys.stdout`` by a ``StringIO`` whose content ``check_calls`` inspects.
    """

    def __init__(self, log_files, quiet, tester):
        """Prepare the patches (without starting them) and load expected logs.

        log_files: paths of the expected log files; ``None`` or empty for none.
        quiet: when true, stdout is expected to stay empty.
        tester: object providing ``assertEqual`` and ``fail``.
        """
        self.tester = tester
        self.log_files = []
        if log_files:
            self.read_log_files(log_files)
        self.quiet = quiet
        self.patches = [
            mock.patch("market.open"),
            mock.patch("os.makedirs"),
            mock.patch("sys.stdout", new_callable=StringIO),
        ]
        self.mocks = []

    def start(self):
        """Start every patch and keep the objects they install."""
        for patch in self.patches:
            self.mocks.append(patch.start())
        # The stdout patch is the last entry of self.patches.
        self.stdout = self.mocks[-1]

    def check_calls(self):
        """Compare what was written to stdout with the expected log files.

        In quiet mode stdout must be empty. Otherwise, when expected logs were
        given, the number of lines must match and every expected line must be
        present in the output, in any order. An expected line starting with
        "[Worker] " is allowed to be missing.
        """
        stdout = self.stdout.getvalue()
        if self.quiet:
            self.tester.assertEqual("", stdout)
            return
        if not self.log_files:
            return

        split_logs = self.strip_log(stdout).split("\n")
        self.tester.assertEqual(sum(len(f) for f in self.log_files), len(split_logs))
        for log_file in self.log_files:
            for line in log_file:
                # The try is per line: a tolerated missing "[Worker] " line
                # must not stop the remaining expected lines being checked.
                try:
                    split_logs.remove(line)
                except ValueError:
                    if not line.startswith("[Worker] "):
                        self.tester.fail("« {} » not found in log file {}".format(line, split_logs))
        # Not checked yet (notes kept from the original):
        # - the log file is written
        # - the report is written when relevant
        # - the report contains the right number of lines

    def stop(self):
        """Stop the patches in reverse start order."""
        for patch in reversed(self.patches):
            patch.stop()
            self.mocks.pop()

    def strip_log(self, log):
        """Collapse each pair of consecutive newlines into one and remove the
        leading "YYYY-MM-DD HH:MM:SS: " timestamp of every line.
        """
        log = log.replace("\n\n", "\n")
        return re.sub(r"^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}: ", "", log, flags=re.MULTILINE)

    def read_log_files(self, log_files):
        """Read each expected log file into a list of stripped lines."""
        for log_file in log_files:
            with open(log_file, "r") as f:
                log = self.strip_log(f.read()).split("\n")
                # Drop the empty string left by a trailing newline.
                if len(log[-1]) == 0:
                    log.pop()
                self.log_files.append(log)
120
class DatabaseMock:
    """Replace ``psycopg2.connect`` with a mock and check how it was used."""

    def __init__(self, tester, reports, report_db):
        """Build the rows the fake cursor will yield.

        tester: object providing ``assertEqual``.
        reports: mapping whose values unpack as
            (user_id, market_id, http_requests, report_lines).
        report_db: whether the extra connections and statements counted in
            ``check_calls`` are expected.
        """
        self.tester = tester
        self.reports = reports
        self.report_db = report_db
        self.requests = []
        entries = list(self.reports.values())
        # One (market_id, credentials, user_id) row per report.
        self.rows = [
            (market_id, {"key": "key", "secret": "secret"}, user_id)
            for user_id, market_id, _requests, _lines in entries
        ]
        self.total_report_lines = sum(
            len(report_lines) for _user, _market, _requests, report_lines in entries
        )

    def start(self):
        """Patch ``psycopg2.connect`` so it returns a mock connection."""
        def record_request(request, *args):
            self.requests.append(request)

        cursor = mock.MagicMock()
        cursor.execute.side_effect = record_request
        # Iterating over the cursor yields the prepared rows.
        cursor.__iter__.return_value = self.rows
        self.cursor = cursor

        connection = mock.Mock()
        connection.cursor.return_value = cursor

        self.db_patch = mock.patch("psycopg2.connect")
        self.db_mock = self.db_patch.start()
        self.db_mock.return_value = connection

    def check_calls(self):
        """Assert the number of connections and of executed statements."""
        expected_connections = 1
        expected_statements = 1
        if self.report_db:
            # One more connection and statement per row, plus one statement
            # per report line.
            expected_connections += len(self.rows)
            expected_statements += len(self.rows) + self.total_report_lines
        self.tester.assertEqual(expected_connections, self.db_mock.call_count)
        self.tester.assertEqual(expected_statements, self.cursor.execute.call_count)

    def stop(self):
        """Remove the ``psycopg2.connect`` patch."""
        self.db_patch.stop()
1d72880c
IB
156
class RequestsMock:
    """Serve the recorded HTTP exchanges through ``requests_mock``."""

    # URLs for which check_calls() accepts recorded calls left unconsumed.
    lazy_calls = [
        "https://cryptoportfolio.io/wp-content/uploads/portfolio/json/cryptoportfolio.json",
        "https://poloniex.com/public?command=returnTicker",
    ]

    def __init__(self, tester):
        """Register one handler per recorded (method, url) pair.

        tester: object providing ``requests_by_market()`` and ``assertEqual``.
        """
        self.tester = tester
        self.reports = tester.requests_by_market()

        self.last_https = {}
        self.error_calls = []
        self.mocker = requests_mock.Mocker()

        def not_stubbed(*args):
            # Catch-all handler: record the request, then fail it.
            request = args[0]
            self.error_calls.append([request.method, request.url])
            raise requests_mock.exceptions.MockException("Not stubbed URL")

        self.mocker.register_uri(requests_mock.ANY, requests_mock.ANY, text=not_stubbed)

        # (method, url) -> market id -> recorded exchanges, in recorded order.
        self.mocks = {}
        for market_id, recorded in self.reports.items():
            for element in recorded:
                key = (element["method"], element["url"])
                per_market = self.mocks.setdefault(key, {})
                per_market.setdefault(market_id, []).append(element)

        for (method, url), per_market in self.mocks.items():
            handler = functools.partial(callback, self, per_market)
            self.mocker.register_uri(method, url, text=handler, complete_qs=True)

    def start(self):
        """Activate the mocker."""
        self.mocker.start()

    def stop(self):
        """Deactivate the mocker."""
        self.mocker.stop()

    def check_calls(self):
        """Assert that no unexpected request was made and that every recorded
        exchange was consumed, except for the URLs in ``lazy_calls``.
        """
        self.tester.assertEqual([], self.error_calls)
        for (method, url), per_market in self.mocks.items():
            if url in self.lazy_calls:
                continue
            for market_id, remaining in per_market.items():
                self.tester.assertEqual(0, len(remaining), "Missing calls to {} {}, market_id {}".format(method, url, market_id))

    def clean_body(self, body):
        """Return ``body`` as text with its ``nonce`` parameter removed.

        ``None`` is returned unchanged; bytes are decoded first.
        """
        if body is None:
            return None
        text = body.decode() if isinstance(body, bytes) else body
        # A trailing "&nonce=..." goes first, then a nonce placed elsewhere
        # together with the "&" that follows it.
        without_trailing = re.sub(r"&nonce=\d*$", "", text)
        return re.sub(r"nonce=\d*&?", "", without_trailing)
210
a0dcf4e0
IB
def callback(self, elements, request, context):
    """Answer one mocked HTTP request with the next recorded exchange.

    self: object providing ``error_calls``, ``last_https`` and ``clean_body``.
    elements: recorded exchanges for this URL, as lists keyed by the value of
        the request's "X-market-id" header.
    request: the intercepted request (headers, body, method, url are read).
    context: response context; its ``status_code`` is set from the record.

    Returns the recorded response text. Raises RuntimeError when nothing is
    recorded for this request, AssertionError when the request body differs
    from the recorded one, or the recorded error when the record has one.
    """
    market_id = request.headers.get("X-market-id")
    try:
        element = elements[market_id].pop(0)
    except (IndexError, KeyError):
        self.error_calls.append([request.method, request.url, market_id])
        raise RuntimeError("Unexpected call")

    response = element["response"]
    if response is not None:
        # Remember the response under its date so that later records can
        # refer to it.
        self.last_https[element["date"]] = response
    elif element["response_same_as"] is not None:
        element["response"] = self.last_https[element["response_same_as"]]

    time.sleep(element.get("duration", 0))

    received = self.clean_body(request.body)
    expected = self.clean_body(element["body"])
    assert received == expected, "Body does not match"

    context.status_code = element["status"]
    if "error" not in element:
        return element["response"]

    error_name = element["error"]
    if error_name == "SSLError":
        error_class = SSLError
    else:
        error_class = getattr(requests.exceptions, error_name)
    raise error_class(element["error_message"])
233
class GlobalVariablesMock:
    """Give ``market.Portfolio`` fresh class-level attributes while active."""

    def start(self):
        """Patch the class attributes of ``market.Portfolio`` with new values."""
        # Imported here rather than at module level, as in the original.
        import market
        import store

        patchers = [
            mock.patch.multiple(market.Portfolio,
                data=store.LockedVar(None),
                liquidities=store.LockedVar({}),
                last_date=store.LockedVar(None),
                report=store.LockedVar(store.ReportStore(None, no_http_dup=True)),
                worker=None,
                worker_tag="",
                worker_notify=None,
                worker_started=False,
                callback=None)
        ]
        # Only patchers that actually started are remembered, so that stop()
        # never has to deal with one that did not.
        self.patchers = []
        for patcher in patchers:
            patcher.start()
            self.patchers.append(patcher)

    def stop(self):
        """Undo the patches made by ``start``.

        Safe to call when ``start`` was never called, and safe to call twice.
        """
        for patcher in reversed(getattr(self, "patchers", [])):
            patcher.stop()
        self.patchers = []
256
1d72880c 257
1d72880c
IB
class AcceptanceTestCase:
    """Replay recorded report files against ``main.main``.

    Meant to be combined with a test-case class: ``setUp`` reads ``self.files``
    (required, paths of JSON report files) and ``self.log_files`` (optional,
    paths of expected log files), and the mocks call ``assertEqual`` and
    ``fail`` on this object.
    """

    def parse_file(self, report_file):
        """Load one JSON report file.

        Returns (config, user, date, market_id, http_requests, json_content).
        Floats are parsed as ``Decimal``.
        """
        with open(report_file, "rb") as f:
            json_content = json.load(f, parse_float=Decimal)
        config, user, date, market_id = self.parse_config(json_content)
        http_requests = self.parse_requests(json_content)

        return config, user, date, market_id, http_requests, json_content

    def parse_requests(self, json_content):
        """Return the elements whose "type" is "http_request", in order."""
        return [element for element in json_content if element["type"] == "http_request"]

    def parse_config(self, json_content):
        """Rebuild command-line arguments from the report's "market" element.

        When several "market" elements exist, the last one is used.
        Returns (config, user_id, date, market_id), where ``config`` is a list
        of command-line arguments and ``date`` a ``datetime``.
        Raises AssertionError when there is no "market" element.
        """
        market_info = None
        for element in json_content:
            if element["type"] == "market":
                market_info = element
        assert market_info is not None, "Couldn't find market element"

        args = market_info["args"]
        config = []
        # Flags that are passed when set.
        for arg in ["before", "after", "quiet", "debug"]:
            if args.get(arg, False):
                config.append("--{}".format(arg))
        # Options whose "--no-..." form is passed when they are unset.
        for arg in ["parallel", "report_db"]:
            if not args.get(arg, False):
                config.append("--no-{}".format(arg.replace("_", "-")))
        for action in (args.get("action", []) or []):
            config.extend(["--action", action])
        if args.get("report_path") is not None:
            config.extend(["--report-path", args.get("report_path")])
        if args.get("user") is not None:
            config.extend(["--user", args.get("user")])
        config.extend(["--config", ""])

        user = market_info["user_id"]
        date = datetime.datetime.strptime(market_info["date"], "%Y-%m-%dT%H:%M:%S.%f")
        market_id = market_info["market_id"]
        return config, user, date, market_id

    def requests_by_market(self):
        """Group the recorded HTTP requests by market id.

        Keys are ``str(market_id)``, plus ``None`` for requests recorded with
        no market id; those are taken from the first report only.
        """
        r = {
            None: []
        }
        got_common = False
        for user_id, market_id, http_requests, report_lines in self.reports.values():
            r[str(market_id)] = []
            for http_request in http_requests:
                if http_request["market_id"] is None:
                    if not got_common:
                        r[None].append(http_request)
                else:
                    r[str(market_id)].append(http_request)
            got_common = True
        return r

    def setUp(self):
        """Parse the report files, then create and start every mock.

        Raises TypeError when the class does not define ``files``.
        """
        if not hasattr(self, "files"):
            # A bare string cannot be raised in Python 3: the original
            # statement produced an unrelated TypeError. Raising TypeError
            # keeps the exception type and shows the intended message.
            raise TypeError("This class expects to be inherited with a class defining 'files' variable")
        if not hasattr(self, "log_files"):
            self.log_files = []

        self.reports = {}
        self.start_date = datetime.datetime.now()
        self.config = []
        for f in self.files:
            # self.config ends up being the configuration of the last file.
            self.config, user_id, date, market_id, http_requests, report_lines = self.parse_file(f)
            # Keep the earliest date found in the reports.
            if date < self.start_date:
                self.start_date = date
            self.reports[f] = [user_id, market_id, http_requests, report_lines]

        self.database_mock = DatabaseMock(self, self.reports, "--no-report-db" not in self.config)
        self.requests_mock = RequestsMock(self)
        self.file_mock = FileMock(self.log_files, "--quiet" in self.config, self)
        self.global_variables_mock = GlobalVariablesMock()

        self.database_mock.start()
        self.requests_mock.start()
        self.file_mock.start()
        self.global_variables_mock.start()
        TimeMock.travel(self.start_date)

    def base_test(self):
        """Run ``main.main`` with the parsed configuration and check the mocks."""
        main.main(self.config)
        self.requests_mock.check_calls()
        self.database_mock.check_calls()
        self.file_mock.check_calls()

    def tearDown(self):
        """Stop the mocks in the reverse of their start order."""
        TimeMock.stop()
        self.global_variables_mock.stop()
        self.file_mock.stop()
        self.requests_mock.stop()
        self.database_mock.stop()
1d72880c 358