git.immae.eu Git - perso/Immae/Projets/Cryptomonnaies/Cryptoportfolio/Trader.git/blob - tests/acceptance.py
Don’t raise when some market is disabled
[perso/Immae/Projets/Cryptomonnaies/Cryptoportfolio/Trader.git] / tests / acceptance.py
1 import requests
2 import requests_mock
3 import sys, os
4 import time, datetime
5 from unittest import mock
6 from ssl import SSLError
7 from decimal import Decimal
8 import simplejson as json
9 import psycopg2
10 from io import StringIO
11 import re
12 import functools
13 import threading
14
class TimeMock:
    """Process-wide fake clock for the acceptance tests.

    ``start()`` patches ``time.time``, ``time.sleep`` and ``datetime.datetime``.
    Each thread then sees ``true_time() - offset``, where the thread's offset is
    initialised from ``delta_init`` the first time it reads the clock.
    ``fake_sleep`` advances the calling thread's fake clock by the full duration
    while really sleeping at most 0.1 s. ``travel()`` moves the fake clock back
    to a given date.
    """

    # Per-thread offset (seconds) subtracted from the real clock.
    delta = {}
    # Offset assigned to a thread the first time it reads the fake clock.
    delta_init = 0
    # Real implementations, captured at class creation (before any patching).
    true_time = time.time
    true_sleep = time.sleep
    time_patch = None
    datetime_patch = None

    @classmethod
    def travel(cls, start_date):
        """Make the fake clock read ``start_date`` (plus elapsed real time) on every thread.

        The offset is computed from the *real* clock before ``delta`` is reset.
        Going through the patched ``datetime.datetime.now()`` after the reset
        would re-register the calling thread with the previous ``delta_init``,
        and that thread would then never travel.
        """
        real_now = datetime.datetime.fromtimestamp(cls.true_time())
        cls.delta_init = (real_now - start_date).total_seconds()
        cls.delta = {}

    @classmethod
    def start(cls):
        """Reset the offsets and install the time/datetime patches."""
        cls.delta = {}
        cls.delta_init = 0

        class fake_datetime(datetime.datetime):
            @classmethod
            def now(cls, tz=None):
                # Route now() through the (patched) time.time().
                if tz is None:
                    return cls.fromtimestamp(time.time())
                else:
                    return tz.fromutc(cls.utcfromtimestamp(time.time()).replace(tzinfo=tz))

        cls.time_patch = mock.patch.multiple(time, time=cls.fake_time, sleep=cls.fake_sleep)
        cls.datetime_patch = mock.patch.multiple(datetime, datetime=fake_datetime)
        cls.time_patch.start()
        cls.datetime_patch.start()

    @classmethod
    def stop(cls):
        """Reset the offsets.

        NOTE: this does not remove the patches installed by ``start()``; they
        stay active.
        """
        cls.delta = {}
        cls.delta_init = 0

    @classmethod
    def fake_time(cls):
        """Return the real time minus the calling thread's offset."""
        thread = threading.current_thread()
        cls.delta.setdefault(thread, cls.delta_init)
        return cls.true_time() - cls.delta[thread]

    @classmethod
    def fake_sleep(cls, duration):
        """Advance the calling thread's fake clock by ``duration`` seconds.

        Really sleeps for at most 0.1 s.
        """
        thread = threading.current_thread()
        cls.delta.setdefault(thread, cls.delta_init)
        cls.delta[thread] -= float(duration)
        cls.true_sleep(min(float(duration), 0.1))
61
# Install the fake clock before `import main`. Presumably this lets main (and
# whatever it imports) see the patched time/datetime from import time on.
# Keep this ordering. Note that TimeMock.stop() only resets the offsets: the
# patches installed here remain active for the rest of the process.
TimeMock.start()
import main
64
class FileMock:
    """Patch file-system side effects and capture stdout for one test run.

    Patches ``market.open``, ``os.makedirs`` and ``sys.stdout`` (replaced by a
    ``StringIO``). After the run, ``check_calls`` compares the captured output
    with the expected log files.
    """

    def __init__(self, log_files, quiet, tester):
        self.tester = tester
        # One list of (timestamp-stripped) expected lines per log file.
        self.log_files = []
        if log_files:
            self.read_log_files(log_files)
        self.quiet = quiet
        self.patches = [
            mock.patch("market.open"),
            mock.patch("os.makedirs"),
            mock.patch("sys.stdout", new_callable=StringIO),
        ]
        self.mocks = []

    def start(self):
        """Start all patches; the last one is the captured stdout."""
        for patch in self.patches:
            self.mocks.append(patch.start())
        self.stdout = self.mocks[-1]

    def check_calls(self):
        """Assert that the captured stdout matches the expectations.

        In quiet mode nothing must have been printed. Otherwise, when expected
        log files were given, the output must have as many lines as all log
        files combined and must contain every expected line. Matching ignores
        order and consumes each line once. Expected lines starting with
        "[Worker] " may be missing; every other expected line is still checked.
        """
        stdout = self.stdout.getvalue()
        if self.quiet:
            self.tester.assertEqual("", stdout)
            return
        if len(self.log_files) == 0:
            return

        split_logs = self.strip_log(stdout).split("\n")
        self.tester.assertEqual(sum(len(f) for f in self.log_files), len(split_logs))
        for log_file in self.log_files:
            for line in log_file:
                # The try covers only the lookup: a tolerated missing worker
                # line must not end the comparison of the remaining lines.
                try:
                    split_logs.remove(line)
                except ValueError:
                    if not line.startswith("[Worker] "):
                        self.tester.fail("« {} » not found in log file {}".format(line, split_logs))
        # TODO: check that the log file is written
        # TODO: check that the report is written when relevant
        # TODO: check that the report has the right number of lines

    def stop(self):
        """Stop the patches in reverse order of starting."""
        for patch in self.patches[::-1]:
            patch.stop()
            self.mocks.pop()

    def strip_log(self, log):
        """Collapse doubled newlines and drop leading "YYYY-MM-DD HH:MM:SS: " stamps."""
        log = log.replace("\n\n", "\n")
        return re.sub(r"^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}: ", "", log, flags=re.MULTILINE)

    def read_log_files(self, log_files):
        """Load each expected log file as a list of stripped lines (no trailing empty line)."""
        for log_file in log_files:
            with open(log_file, "r") as f:
                log = self.strip_log(f.read()).split("\n")
            if len(log[-1]) == 0:
                log.pop()
            self.log_files.append(log)
120
class DatabaseMock:
    """Stand-in for ``psycopg2`` during an acceptance run.

    ``reports`` maps a report file to ``[user_id, market_id, http_requests,
    report_lines]``. Each entry contributes one ``(market_id, credentials,
    user_id)`` row with dummy credentials. Every SQL string passed to the mocked
    cursor's ``execute`` is recorded in ``self.requests``.
    """

    def __init__(self, tester, reports, report_db):
        self.tester = tester
        self.reports = reports
        self.report_db = report_db
        self.requests = []
        entries = list(self.reports.values())
        self.rows = [
            (market_id, {"key": "key", "secret": "secret"}, user_id)
            for user_id, market_id, _requests, _lines in entries
        ]
        self.total_report_lines = sum(
            len(lines) for _user, _market, _requests, lines in entries
        )

    def start(self):
        """Patch ``psycopg2.connect`` with a connection whose cursor records queries."""
        def record_query(request, *args):
            self.requests.append(request)

        self.cursor = mock.MagicMock()
        self.cursor.execute.side_effect = record_query
        # Iterating over the cursor yields the prepared rows.
        self.cursor.__iter__.return_value = self.rows

        connection = mock.Mock()
        connection.cursor.return_value = self.cursor

        self.db_patch = mock.patch("psycopg2.connect")
        self.db_mock = self.db_patch.start()
        self.db_mock.return_value = connection

    def check_calls(self):
        """Assert the expected numbers of connections and executed queries."""
        if self.report_db:
            expected_connections = 1 + len(self.rows)
            expected_executes = expected_connections + self.total_report_lines
        else:
            expected_connections = expected_executes = 1
        self.tester.assertEqual(expected_connections, self.db_mock.call_count)
        self.tester.assertEqual(expected_executes, self.cursor.execute.call_count)

    def stop(self):
        """Undo the ``psycopg2.connect`` patch."""
        self.db_patch.stop()
156
class RequestsMock:
    """Replay recorded HTTP exchanges through ``requests_mock``.

    ``self.mocks`` maps ``(method, url)`` to ``{market_id: [recorded exchange,
    ...]}``. Each registered URL is answered by the module-level ``callback``,
    which consumes the queue for the requesting market. URLs that were never
    recorded are logged in ``self.error_calls`` and fail.
    """

    # Recorded responses for these URLs may legitimately go unused, so
    # leftovers are not reported by check_calls().
    lazy_calls = [
        "https://cryptoportfolio.io/wp-content/uploads/portfolio/json/cryptoportfolio.json",
        "https://poloniex.com/public?command=returnTicker",
    ]

    def __init__(self, tester):
        self.tester = tester
        self.reports = tester.requests_by_market()
        self.last_https = {}
        self.error_calls = []

        self.mocker = requests_mock.Mocker()

        def not_stubbed(*args):
            request = args[0]
            self.error_calls.append([request.method, request.url])
            raise requests_mock.exceptions.MockException("Not stubbed URL")

        # The catch-all is registered before the specific matchers below.
        self.mocker.register_uri(requests_mock.ANY, requests_mock.ANY, text=not_stubbed)

        self.mocks = {}
        for market_id, recorded in self.reports.items():
            for exchange in recorded:
                key = (exchange["method"], exchange["url"])
                per_market = self.mocks.setdefault(key, {})
                per_market.setdefault(market_id, []).append(exchange)

        for (method, url), queues in self.mocks.items():
            responder = functools.partial(callback, self, queues)
            self.mocker.register_uri(method, url, text=responder, complete_qs=True)

    def start(self):
        """Activate the HTTP mocker."""
        self.mocker.start()

    def stop(self):
        """Deactivate the HTTP mocker."""
        self.mocker.stop()

    def check_calls(self):
        """Assert no unexpected calls happened and every non-lazy recording was consumed."""
        self.tester.assertEqual([], self.error_calls)
        for (method, url), queues in self.mocks.items():
            if url in self.lazy_calls:
                continue
            for market_id, remaining in queues.items():
                self.tester.assertEqual(
                    0,
                    len(remaining),
                    "Missing calls to {} {}, market_id {}".format(method, url, market_id),
                )

    def clean_body(self, body):
        """Return ``body`` as text with any ``nonce=<digits>`` parameter removed (None stays None)."""
        if body is None:
            return None
        text = body.decode() if isinstance(body, bytes) else body
        text = re.sub(r"&nonce=\d*$", "", text)
        return re.sub(r"nonce=\d*&?", "", text)
def callback(self, elements, request, context):
    """Answer one mocked HTTP request with the next recorded exchange.

    ``self`` is the ``RequestsMock`` instance and ``elements`` maps a market id
    (the ``X-market-id`` request header, or None) to its queue of recordings.
    A response recorded with ``response`` None and ``response_same_as`` set
    reuses the response stored in ``self.last_https`` under that key. Otherwise
    the response is stored there under its ``date``. The handler sleeps for the
    recorded ``duration``, checks the request body (nonce-insensitive), sets
    the status code and either raises the recorded error or returns the
    response text.

    Raises RuntimeError when no recording is left for this market.
    """
    market_header = request.headers.get("X-market-id")
    try:
        element = elements[market_header].pop(0)
    except (IndexError, KeyError):
        self.error_calls.append([request.method, request.url, market_header])
        raise RuntimeError("Unexpected call")

    if element["response"] is not None:
        self.last_https[element["date"]] = element["response"]
    elif element["response_same_as"] is not None:
        element["response"] = self.last_https[element["response_same_as"]]

    time.sleep(element.get("duration", 0))

    actual_body = self.clean_body(request.body)
    expected_body = self.clean_body(element["body"])
    assert actual_body == expected_body, "Body does not match"

    context.status_code = element["status"]
    if "error" in element:
        error_name = element["error"]
        if error_name == "SSLError":
            raise SSLError(element["error_message"])
        raise getattr(requests.exceptions, error_name)(element["error_message"])
    return element["response"]
233
class GlobalVariablesMock:
    """Give ``market.Portfolio``'s class-level state fresh values for one test."""

    def start(self):
        """Patch Portfolio's shared attributes with fresh, empty values."""
        import market
        import store

        self.patchers = [
            mock.patch.multiple(market.Portfolio,
                    data=store.LockedVar(None),
                    liquidities=store.LockedVar({}),
                    last_date=store.LockedVar(None),
                    report=store.LockedVar(store.ReportStore(None, no_http_dup=True)),
                    worker=None,
                    worker_tag="",
                    worker_notify=None,
                    worker_started=False,
                    callback=None)
        ]
        for patcher in self.patchers:
            patcher.start()

    def stop(self):
        """Undo the patches installed by ``start()``, most recent first.

        Previously this was a no-op, so every test's patches stayed active
        and piled up. Safe to call when ``start()`` was never called, and safe
        to call twice.
        """
        for patcher in reversed(getattr(self, "patchers", [])):
            patcher.stop()
        self.patchers = []
256
257
class AcceptanceTestCase:
    """Mixin that replays recorded report files through ``main.main``.

    Concrete test classes must define ``files`` (paths of report JSON files)
    and may define ``log_files`` (expected log output). ``self`` is handed to
    the mocks as their ``tester``, and they call ``assertEqual``/``fail`` on
    it. The concrete class therefore presumably also derives from
    ``unittest.TestCase``.
    """

    def parse_file(self, report_file):
        """Load one report file.

        Returns ``(config, user, date, market_id, http_requests, json_content)``.
        Floats in the JSON are parsed as ``Decimal``.
        """
        with open(report_file, "rb") as f:
            json_content = json.load(f, parse_float=Decimal)
        config, user, date, market_id = self.parse_config(json_content)
        http_requests = self.parse_requests(json_content)

        return config, user, date, market_id, http_requests, json_content

    def parse_requests(self, json_content):
        """Return the report entries of type ``http_request``, in order."""
        return [element for element in json_content if element["type"] == "http_request"]

    def parse_config(self, json_content):
        """Rebuild the command-line arguments from the report's last ``market`` entry.

        Returns ``(config, user_id, date, market_id)``.
        """
        market_info = None
        for element in json_content:
            if element["type"] == "market":
                market_info = element
        assert market_info is not None, "Couldn't find market element"

        args = market_info["args"]
        config = []
        for arg in ["before", "after", "quiet", "debug"]:
            if args.get(arg, False):
                config.append("--{}".format(arg))
        # A falsy or missing value becomes an explicit --no-... flag.
        for arg in ["parallel", "report_db"]:
            if not args.get(arg, False):
                config.append("--no-{}".format(arg.replace("_", "-")))
        for action in (args.get("action", []) or []):
            config.extend(["--action", action])
        if args.get("report_path") is not None:
            config.extend(["--report-path", args.get("report_path")])
        if args.get("user") is not None:
            config.extend(["--user", args.get("user")])
        config.extend(["--config", ""])

        user = market_info["user_id"]
        date = datetime.datetime.strptime(market_info["date"], "%Y-%m-%dT%H:%M:%S.%f")
        market_id = market_info["market_id"]
        return config, user, date, market_id

    def requests_by_market(self):
        """Group the recorded HTTP requests by market id (as ``str``).

        Requests whose ``market_id`` is None go under the ``None`` key and are
        taken from the first report only. Presumably every report recorded the
        same shared calls, so they must be served once.
        """
        r = {
            None: []
        }
        got_common = False
        for user_id, market_id, http_requests, report_lines in self.reports.values():
            r[str(market_id)] = []
            for http_request in http_requests:
                if http_request["market_id"] is None:
                    if not got_common:
                        r[None].append(http_request)
                else:
                    r[str(market_id)].append(http_request)
            got_common = True
        return r

    def setUp(self):
        """Parse the report files, start all mocks and move the fake clock back.

        Raises TypeError when the concrete class does not define ``files``.
        """
        if not hasattr(self, "files"):
            # Raising a bare string is itself a TypeError in Python 3 and would
            # hide this message, so raise a real exception instead.
            raise TypeError("This class expects to be inherited with a class defining 'files' variable")
        if not hasattr(self, "log_files"):
            self.log_files = []

        self.reports = {}
        self.start_date = datetime.datetime.now()
        self.config = []
        for f in self.files:
            # The command line of the last file wins; the clock starts at the earliest date.
            self.config, user_id, date, market_id, http_requests, report_lines = self.parse_file(f)
            if date < self.start_date:
                self.start_date = date
            self.reports[f] = [user_id, market_id, http_requests, report_lines]

        self.database_mock = DatabaseMock(self, self.reports, "--no-report-db" not in self.config)
        self.requests_mock = RequestsMock(self)
        self.file_mock = FileMock(self.log_files, "--quiet" in self.config, self)
        self.global_variables_mock = GlobalVariablesMock()

        self.database_mock.start()
        self.requests_mock.start()
        self.file_mock.start()
        self.global_variables_mock.start()
        TimeMock.travel(self.start_date)

    def base_test(self):
        """Run ``main.main`` with the rebuilt config, then verify every mock."""
        main.main(self.config)
        self.requests_mock.check_calls()
        self.database_mock.check_calls()
        self.file_mock.check_calls()

    def tearDown(self):
        """Reset the fake clock and stop the mocks in reverse order of starting."""
        TimeMock.stop()
        self.global_variables_mock.stop()
        self.file_mock.stop()
        self.requests_mock.stop()
        self.database_mock.stop()