]> git.immae.eu Git - perso/Immae/Projets/Cryptomonnaies/Cryptoportfolio/Trader.git/blob - test_acceptance.py
Add acceptance tests
[perso/Immae/Projets/Cryptomonnaies/Cryptoportfolio/Trader.git] / test_acceptance.py
1 import requests
2 import requests_mock
3 import sys, os
4 import time, datetime
5 import unittest
6 from unittest import mock
7 from ssl import SSLError
8 from decimal import Decimal
9 import simplejson as json
10 import psycopg2
11 from io import StringIO
12
class FileMock:
    """Neutralise file-system and stdout side effects during a test.

    ``market.open`` and ``os.makedirs`` are replaced by mocks and
    ``sys.stdout`` is redirected into a ``StringIO`` buffer.
    """

    @classmethod
    def start(cls):
        """Create the three patchers and activate them."""
        cls.file_mock = mock.patch("market.open")
        cls.os_mock = mock.patch("os.makedirs")
        cls.stdout_mock = mock.patch('sys.stdout', new_callable=StringIO)
        # Activation order: stdout, then os.makedirs, then market.open.
        for patcher in (cls.stdout_mock, cls.os_mock, cls.file_mock):
            patcher.start()

    @classmethod
    def check_calls(cls, tester):
        """Verify the recorded file-system calls.

        Not implemented yet: this currently performs no check at all.
        """
        # TODO: assert on the calls recorded by cls.file_mock / cls.os_mock.
        return None

    @classmethod
    def stop(cls):
        """Deactivate the patchers installed by ``start``."""
        for patcher in (cls.file_mock, cls.stdout_mock, cls.os_mock):
            patcher.stop()
33
class DatabaseMock:
    """Stand-in for ``psycopg2.connect`` that records and counts DB usage.

    All state lives in class attributes; ``stop`` resets it.
    """

    rows = []               # rows yielded when the fake cursor is iterated
    report_db = False       # value given to start(); selects the expectations
    db_patch = None         # patcher for psycopg2.connect
    cursor = None           # MagicMock returned by connection.cursor()
    total_report_lines = 0  # sum of len(report_lines) over all reports
    db_mock = None          # the mock replacing psycopg2.connect
    requests = []           # first positional argument of each cursor.execute call

    @classmethod
    def start(cls, reports, report_db):
        """Patch ``psycopg2.connect`` for the reports of the current test.

        :param reports: mapping whose values are ``(user_id, market_id,
            http_requests, report_lines)``; each value contributes one cursor
            row ``(market_id, {"key": "key", "secret": "secret"}, user_id)``.
        :param report_db: whether DB reporting is expected; only consulted by
            ``check_calls``.
        """
        cls.report_db = report_db
        cls.rows = [
            (market_id, {"key": "key", "secret": "secret"}, user_id)
            for user_id, market_id, _http_requests, _report_lines in reports.values()
        ]
        cls.total_report_lines = sum(len(entry[3]) for entry in reports.values())

        def _record_query(query, *args):
            # Keep only the first argument; any extra positional ones are ignored.
            cls.requests.append(query)

        cursor = mock.MagicMock()
        cursor.execute.side_effect = _record_query
        cursor.__iter__.return_value = cls.rows
        cls.cursor = cursor

        connection = mock.Mock()
        connection.cursor.return_value = cursor

        cls.db_patch = mock.patch("psycopg2.connect")
        cls.db_mock = cls.db_patch.start()
        cls.db_mock.return_value = connection

    @classmethod
    def check_calls(cls, tester):
        """Assert how many connections were opened and queries executed.

        With ``report_db`` set, one extra connection per row is expected, and
        one extra query per row and per report line. Otherwise exactly one
        connection and one query are expected.
        """
        if cls.report_db:
            expected_connections = 1 + len(cls.rows)
            expected_queries = expected_connections + cls.total_report_lines
        else:
            expected_connections = expected_queries = 1
        tester.assertEqual(expected_connections, cls.db_mock.call_count)
        tester.assertEqual(expected_queries, cls.cursor.execute.call_count)

    @classmethod
    def stop(cls):
        """Remove the patch and put the class attributes back to defaults."""
        cls.db_patch.stop()
        cls.db_mock, cls.cursor = None, None
        cls.total_report_lines = 0
        cls.rows, cls.report_db = [], False
        cls.requests = []
82
class RequestsMock:
    """Serve recorded HTTP exchanges to every ``requests`` session.

    Sessions created while patched share one ``requests_mock`` adapter.
    Each recorded exchange is consumed once, in order, per ``(method, url)``
    and per value of the ``X-market-id`` request header.
    """

    adapter = None      # requests_mock.Adapter handed to every session
    mocks = {}          # {(method, url): {market_id: [pending exchanges]}}
    last_https = {}     # responses already served, keyed by element["date"]
    request_patch = []  # active patchers replacing Session

    @classmethod
    def start(cls, reports):
        """Register the recorded exchanges and patch ``Session``.

        :param reports: mapping ``market_id -> [http_request elements]``;
            each element needs at least ``method`` and ``url``.
        """
        cls.adapter = requests_mock.Adapter()
        # Catch-all that raises for URLs without a recorded exchange.
        cls.adapter.register_uri(
            requests_mock.ANY, requests_mock.ANY,
            exc=requests_mock.exceptions.MockException("Not stubbed URL"))
        original_session = requests.Session

        cls.mocks = {}
        for market_id, elements in reports.items():
            for element in elements:
                key = (element["method"], element["url"])
                cls.mocks.setdefault(key, {}).setdefault(market_id, []).append(element)

        for (method, url), per_market in cls.mocks.items():
            cls.adapter.register_uri(
                method, url, text=cls.callback_func(per_market), complete_qs=True)

        def _session():
            # A genuine Session whose adapter lookup always yields the mock.
            session = original_session()
            session.get_adapter = lambda url: cls.adapter
            return session

        cls.request_patch = [
            mock.patch.object(requests.sessions, "Session", new=_session),
            mock.patch.object(requests, "Session", new=_session),
        ]
        for patcher in cls.request_patch:
            patcher.start()

    @classmethod
    def stop(cls):
        """Undo the ``Session`` patches and forget all recorded state."""
        for patcher in cls.request_patch:
            patcher.stop()
        cls.request_patch = []
        cls.last_https = {}
        cls.mocks = {}
        cls.adapter = None

    @classmethod
    def check_calls(cls, tester):
        """Assert that every recorded exchange has been consumed."""
        for (method, url), per_market in cls.mocks.items():
            for market_id, pending in per_market.items():
                message = "Missing calls to {} {}, market_id {}".format(method, url, market_id)
                tester.assertEqual(0, len(pending), message)

    @classmethod
    def clean_body(cls, body):
        """Normalise a request body for comparison by removing the nonce.

        ``bytes`` are decoded first; ``None`` is returned unchanged.
        """
        import re
        if body is None:
            return None
        text = body.decode() if isinstance(body, bytes) else body
        # A trailing "&nonce=..." goes first, then any "nonce=...&" left over.
        text = re.sub(r"&nonce=\d*$", "", text)
        return re.sub(r"nonce=\d*&?", "", text)

    @classmethod
    def callback_func(cls, elements):
        """Build a requests_mock ``text`` callback replaying ``elements``.

        :param elements: ``{market_id: [exchange, ...]}`` for a single
            ``(method, url)``; the ``X-market-id`` header selects the queue.
        """
        def callback(request, context):
            market = request.headers.get("X-market-id")
            try:
                element = elements[market].pop(0)
            except (IndexError, KeyError):
                raise RuntimeError("Unexpected call")

            if element["response"] is not None:
                # Remember it so later exchanges can reuse it by date.
                cls.last_https[element["date"]] = element["response"]
            elif element["response_same_as"] is not None:
                element["response"] = cls.last_https[element["response_same_as"]]

            assert cls.clean_body(request.body) == \
                cls.clean_body(element["body"]), "Body does not match"
            context.status_code = element["status"]

            if "error" in element:
                error_name = element["error"]
                if error_name == "SSLError":
                    raise SSLError(element["error_message"])
                raise getattr(requests.exceptions, error_name)(element["error_message"])
            return element["response"]
        return callback
168
class TimeMock:
    """Shift ``time.time`` and ``datetime.datetime.now`` back to a given date.

    ``time.sleep`` is replaced by a version that returns immediately and
    advances the fake clock by the requested duration instead.
    """

    delta = 0              # seconds subtracted from the real clock
    true_time = time.time  # real clock, captured before any patching
    time_patch = None
    datetime_patch = None

    @classmethod
    def start(cls, start_date):
        """Make the clock read ``start_date`` now and keep ticking from there.

        :param start_date: naive ``datetime`` (it is subtracted from the naive
            ``datetime.datetime.now()``).
        """
        cls.delta = (datetime.datetime.now() - start_date).total_seconds()

        class fake_datetime(datetime.datetime):
            @classmethod
            def now(cls, tz=None):
                # Read through time.time so the shifted clock is used.
                # fromtimestamp(t, tz) performs the same UTC conversion plus
                # tz.fromutc() as before, without the utcfromtimestamp() call
                # that is deprecated since Python 3.12.
                return cls.fromtimestamp(time.time(), tz)

        cls.time_patch = mock.patch.multiple(time, time=cls.fake_time, sleep=cls.fake_sleep)
        cls.datetime_patch = mock.patch.multiple(datetime, datetime=fake_datetime)
        cls.time_patch.start()
        cls.datetime_patch.start()

    @classmethod
    def stop(cls):
        """Undo ``start`` and reset the clock offset."""
        cls.delta = 0
        cls.datetime_patch.stop()
        cls.time_patch.stop()

    @classmethod
    def fake_time(cls):
        """Replacement for ``time.time``: the real clock minus ``delta``."""
        return cls.true_time() - cls.delta

    @classmethod
    def fake_sleep(cls, duration):
        """Replacement for ``time.sleep``: advance the fake clock instantly."""
        cls.delta -= duration
205
class AcceptanceTestCase:
    """Mixin replaying recorded JSON reports through ``main.main``.

    Concrete classes combine it with ``unittest.TestCase`` and must provide a
    ``files`` attribute listing the JSON report files to replay.
    """

    def parse_file(self, report_file):
        """Load one JSON report file (floats are parsed as ``Decimal``).

        :returns: ``(config, user_id, date, market_id, http_requests,
            json_content)``, where ``json_content`` is the whole parsed list.
        """
        with open(report_file, "rb") as f:
            json_content = json.load(f, parse_float=Decimal)
        config, user, date, market_id = self.parse_config(json_content)
        http_requests = self.parse_requests(json_content)
        return config, user, date, market_id, http_requests, json_content

    def parse_requests(self, json_content):
        """Return the ``http_request`` elements of a report, in order."""
        return [element for element in json_content
                if element["type"] == "http_request"]

    def parse_config(self, json_content):
        """Rebuild the command line of the recorded run.

        The last ``market`` element of the report is used.

        :returns: ``(config, user_id, date, market_id)``. ``config`` is the
            argument list for ``main.main``; ``date`` is a naive ``datetime``
            parsed from the element's ``date`` field.
        """
        market_info = None
        for element in json_content:
            if element["type"] == "market":
                market_info = element
        assert market_info is not None, "Couldn't find market element"

        args = market_info["args"]
        config = []
        for arg in ["before", "after", "quiet", "debug"]:
            if args.get(arg, False):
                config.append("--{}".format(arg))
        for arg in ["parallel", "report_db"]:
            if not args.get(arg, False):
                config.append("--no-{}".format(arg.replace("_", "-")))
        # A null "action" would make this loop raise TypeError; treat it like
        # a missing key, the same way report_path/user treat None below.
        for action in args.get("action") or []:
            config.extend(["--action", action])
        if args.get("report_path") is not None:
            config.extend(["--report-path", args["report_path"]])
        if args.get("user") is not None:
            config.extend(["--user", args["user"]])
        config.extend(["--config", ""])

        user = market_info["user_id"]
        date = datetime.datetime.strptime(market_info["date"], "%Y-%m-%dT%H:%M:%S.%f")
        market_id = market_info["market_id"]
        return config, user, date, market_id

    def requests_by_market(self):
        """Group the recorded HTTP requests of ``self.reports`` by market.

        :returns: dict mapping ``None`` to the requests whose ``market_id`` is
            None, taken from the first report only. It also maps each
            report's ``str(market_id)`` to the other requests of that report.
        """
        r = {None: []}
        got_common = False
        for user_id, market_id, http_requests, report_lines in self.reports.values():
            r[str(market_id)] = []
            for http_request in http_requests:
                if http_request["market_id"] is None:
                    if not got_common:
                        r[None].append(http_request)
                else:
                    r[str(market_id)].append(http_request)
            # NOTE(review): presumably every report records the same shared
            # (market-less) requests, so only the first copy is kept.
            got_common = True
        return r

    def setUp(self):
        """Parse every report in ``self.files`` and start all the mocks.

        The command line is taken from the last file. The fake clock starts at
        the earliest of the report dates and the current time.

        :raises TypeError: if the concrete class defines no ``files``.
        """
        if not hasattr(self, "files"):
            raise TypeError(
                "This class expects to be inherited by a class defining "
                "a 'files' attribute")

        self.reports = {}
        self.start_date = datetime.datetime.now()
        self.config = []
        for f in self.files:
            self.config, user_id, date, market_id, http_requests, report_lines = self.parse_file(f)
            if date < self.start_date:
                self.start_date = date
            self.reports[f] = [user_id, market_id, http_requests, report_lines]

        DatabaseMock.start(self.reports, "--no-report-db" not in self.config)
        RequestsMock.start(self.requests_by_market())
        FileMock.start()
        TimeMock.start(self.start_date)

    def base_test(self):
        """Run ``main.main`` with the recorded config, then verify the mocks."""
        import main
        main.main(self.config)
        RequestsMock.check_calls(self)
        DatabaseMock.check_calls(self)
        FileMock.check_calls(self)

    def tearDown(self):
        """Stop the mocks started in ``setUp``."""
        TimeMock.stop()
        FileMock.stop()
        RequestsMock.stop()
        DatabaseMock.stop()
298
import glob


def _scenario_names(dirfile):
    """Return ``(name, class_name)`` for a scenario directory path.

    ``dirfile`` is a path such as ``"tests/acceptance/foo_bar/baz/"``, with
    the trailing slash that ``glob`` keeps. That example yields
    ``("foo_bar_baz", "FooBarBaz")``.
    """
    # Drop the common prefix, turn "/" into "_", then strip the "_" that the
    # trailing slash became.
    name = dirfile.replace("tests/acceptance/", "").replace("/", "_")[0:-1]
    class_name = "".join(part.capitalize() for part in name.split("_"))
    return name, class_name


# Generate one unittest.TestCase per directory under tests/acceptance that
# contains JSON reports. Its single test method replays those reports.
for dirfile in sorted(glob.glob("tests/acceptance/**/*/", recursive=True)):
    # Sorted: setUp keeps the last file's config and requests_by_market keeps
    # the first report's common requests, so the order must not depend on
    # the file system's listing order.
    json_files = sorted(glob.glob(os.path.join(dirfile, "*.json")))
    if not json_files:
        continue
    name, cname = _scenario_names(dirfile)
    globals()[cname] = type(cname, (AcceptanceTestCase, unittest.TestCase), {
        "files": json_files,
        "test_{}".format(name): AcceptanceTestCase.base_test,
    })

if __name__ == '__main__':
    unittest.main()
314