]>
Commit | Line | Data |
---|---|---|
1 | import requests | |
2 | import requests_mock | |
3 | import sys, os | |
4 | import time, datetime | |
5 | import unittest | |
6 | from unittest import mock | |
7 | from ssl import SSLError | |
8 | from decimal import Decimal | |
9 | import simplejson as json | |
10 | import psycopg2 | |
11 | from io import StringIO | |
12 | ||
class FileMock:
    """Silence file-system and console side effects of the code under test.

    While active, ``market.open`` and ``os.makedirs`` are replaced by mocks
    and ``sys.stdout`` is redirected into a fresh in-memory ``StringIO``.
    """

    @classmethod
    def start(cls):
        """Create the three patchers and activate them (stdout, makedirs, open)."""
        cls.file_mock = mock.patch("market.open")
        cls.os_mock = mock.patch("os.makedirs")
        cls.stdout_mock = mock.patch('sys.stdout', new_callable=StringIO)
        for patcher in (cls.stdout_mock, cls.os_mock, cls.file_mock):
            patcher.start()

    @classmethod
    def check_calls(cls, tester):
        """Placeholder: no assertions are made on file activity yet.

        TODO: verify the calls recorded by the ``open``/``makedirs`` mocks.
        """
        return None

    @classmethod
    def stop(cls):
        """Deactivate the patchers installed by :meth:`start`."""
        for patcher in (cls.file_mock, cls.stdout_mock, cls.os_mock):
            patcher.stop()
33 | ||
class DatabaseMock:
    """Replace ``psycopg2.connect`` with a mock connection for acceptance tests.

    Class-level state (shared by all tests, reset by :meth:`start`/:meth:`stop`):

    * ``rows`` -- what the mocked cursor yields when iterated: one
      ``(market_id, {"key": "key", "secret": "secret"}, user_id)`` tuple per report.
    * ``report_db`` -- whether database reporting is expected for this run.
    * ``total_report_lines`` -- number of report lines across all reports.
    * ``requests`` -- first argument of every ``cursor.execute`` call, in order.
    """

    rows = []
    report_db = False
    db_patch = None
    cursor = None
    total_report_lines = 0
    db_mock = None
    requests = []

    @classmethod
    def start(cls, reports, report_db):
        """Install the ``psycopg2.connect`` patch.

        :param reports: mapping whose values are
            ``(user_id, market_id, http_requests, report_lines)`` sequences.
        :param report_db: truthy if the run is expected to report to the database.
        """
        cls.report_db = report_db
        cls.rows = []
        cls.total_report_lines = 0
        # Start from a clean query log even if a previous test skipped stop().
        cls.requests = []
        for user_id, market_id, _http_requests, report_lines in reports.values():
            cls.rows.append((market_id, {"key": "key", "secret": "secret"}, user_id))
            cls.total_report_lines += len(report_lines)

        connect_mock = mock.Mock()
        cls.cursor = mock.MagicMock()
        connect_mock.cursor.return_value = cls.cursor

        def _execute(request, *args, **kwargs):
            # Accept keyword parameters too (e.g. ``vars=...``) so the mock
            # does not raise a spurious TypeError; only the query is logged.
            cls.requests.append(request)

        cls.cursor.execute.side_effect = _execute
        # Iterating the cursor yields the fake rows built above.
        cls.cursor.__iter__.return_value = cls.rows

        cls.db_patch = mock.patch("psycopg2.connect")
        cls.db_mock = cls.db_patch.start()
        cls.db_mock.return_value = connect_mock

    @classmethod
    def check_calls(cls, tester):
        """Assert how many connections were opened and statements executed.

        With reporting enabled: ``1 + len(rows)`` connections and
        ``1 + len(rows) + total_report_lines`` statements. Otherwise exactly one
        connection and one statement.
        """
        if cls.report_db:
            tester.assertEqual(1 + len(cls.rows), cls.db_mock.call_count)
            tester.assertEqual(1 + len(cls.rows) + cls.total_report_lines, cls.cursor.execute.call_count)
        else:
            tester.assertEqual(1, cls.db_mock.call_count)
            tester.assertEqual(1, cls.cursor.execute.call_count)

    @classmethod
    def stop(cls):
        """Remove the patch and reset every piece of class-level state."""
        cls.db_patch.stop()
        cls.db_mock = None
        cls.cursor = None
        cls.total_report_lines = 0
        cls.rows = []
        cls.report_db = False
        cls.requests = []
82 | ||
class RequestsMock:
    """Route every ``requests`` session through a ``requests_mock`` adapter.

    Recorded HTTP exchanges are registered per ``(method, url)``. For each
    pair, the expected calls are queued per market id (matched against the
    ``X-market-id`` request header) and consumed front to back.
    """

    adapter = None
    mocks = {}
    last_https = {}
    request_patch = []

    @classmethod
    def start(cls, reports):
        """Register *reports* on a fresh adapter and patch ``requests.Session``.

        :param reports: mapping ``market_id -> list of recorded request dicts``
            (each holding at least ``"method"`` and ``"url"``).
        """
        cls.adapter = adapter = requests_mock.Adapter()
        # Catch-all: any URL not registered below fails loudly.
        adapter.register_uri(requests_mock.ANY, requests_mock.ANY,
                exc=requests_mock.exceptions.MockException("Not stubbed URL"))
        original_session = requests.Session

        cls.mocks = grouped = {}
        for market_id, recorded in reports.items():
            for entry in recorded:
                key = (entry["method"], entry["url"])
                per_market = grouped.setdefault(key, {})
                per_market.setdefault(market_id, []).append(entry)

        for (method, url), per_market in grouped.items():
            adapter.register_uri(method, url,
                                 text=cls.callback_func(per_market),
                                 complete_qs=True)

        def _session():
            # A genuine Session whose adapter lookup always returns the
            # adapter currently stored on the class.
            session = original_session()
            session.get_adapter = lambda url: cls.adapter
            return session

        cls.request_patch = [
            mock.patch.object(target, "Session", new=_session)
            for target in (requests.sessions, requests)
        ]
        for patcher in cls.request_patch:
            patcher.start()

    @classmethod
    def stop(cls):
        """Undo the Session patches and reset all class-level state."""
        for patcher in cls.request_patch:
            patcher.stop()
        cls.request_patch, cls.last_https, cls.mocks, cls.adapter = [], {}, {}, None

    @classmethod
    def check_calls(cls, tester):
        """Fail *tester* for every recorded request that was never consumed."""
        for (method, url), per_market in cls.mocks.items():
            for market_id, pending in per_market.items():
                message = "Missing calls to {} {}, market_id {}".format(method, url, market_id)
                tester.assertEqual(0, len(pending), message)

    @classmethod
    def clean_body(cls, body):
        """Normalise a request body by removing ``nonce=<digits>`` parameters.

        ``bytes`` are decoded first; ``None`` is returned unchanged.
        """
        import re
        if body is None:
            return body
        text = body.decode() if isinstance(body, bytes) else body
        without_trailing = re.sub(r"&nonce=\d*$", "", text)
        return re.sub(r"nonce=\d*&?", "", without_trailing)

    @classmethod
    def callback_func(cls, elements):
        """Build a ``requests_mock`` text callback that replays *elements*.

        :param elements: mapping ``market_id -> list of recorded exchanges``
            for a single ``(method, url)``.
        """
        def callback(request, context):
            try:
                queue = elements[request.headers.get("X-market-id")]
                expected = queue.pop(0)
            except (KeyError, IndexError):
                raise RuntimeError("Unexpected call")

            recorded = expected["response"]
            if recorded is not None:
                # Remember this response so later entries can reuse it.
                cls.last_https[expected["date"]] = recorded
            elif expected["response_same_as"] is not None:
                expected["response"] = cls.last_https[expected["response_same_as"]]

            assert cls.clean_body(request.body) == cls.clean_body(expected["body"]), \
                "Body does not match"
            context.status_code = expected["status"]
            if "error" not in expected:
                return expected["response"]

            error_name = expected["error"]
            if error_name == "SSLError":
                raise SSLError(expected["error_message"])
            raise getattr(requests.exceptions, error_name)(expected["error_message"])
        return callback
168 | ||
class TimeMock:
    """Shift the wall clock seen by the code under test.

    ``time.time`` and ``time.sleep`` are replaced by :meth:`fake_time` and
    :meth:`fake_sleep`. ``datetime.datetime`` is replaced by a subclass whose
    ``now()`` is derived from the patched ``time.time``. A fake sleep advances
    the clock instantly instead of blocking.
    """

    # Seconds subtracted from the real clock (real time - fake time).
    delta = 0
    # Real time.time, captured when the class body runs (before any patch).
    true_time = time.time
    time_patch = None
    datetime_patch = None

    @classmethod
    def start(cls, start_date):
        """Make the fake clock read *start_date* (a naive local datetime) now."""
        cls.delta = (datetime.datetime.now() - start_date).total_seconds()

        class fake_datetime(datetime.datetime):
            @classmethod
            def now(cls, tz=None):
                # fromtimestamp(ts, tz) is equivalent to the former
                # tz.fromutc(utcfromtimestamp(ts).replace(tzinfo=tz)) but does
                # not use utcfromtimestamp, which is deprecated since 3.12.
                return cls.fromtimestamp(time.time(), tz)

        cls.time_patch = mock.patch.multiple(time, time=cls.fake_time, sleep=cls.fake_sleep)
        cls.datetime_patch = mock.patch.multiple(datetime, datetime=fake_datetime)
        cls.time_patch.start()
        cls.datetime_patch.start()

    @classmethod
    def stop(cls):
        """Reset the offset and remove both patches (no-op for missing ones)."""
        cls.delta = 0
        if cls.datetime_patch is not None:
            cls.datetime_patch.stop()
            cls.datetime_patch = None
        if cls.time_patch is not None:
            cls.time_patch.stop()
            cls.time_patch = None

    @classmethod
    def fake_time(cls):
        """Replacement for ``time.time``: the real clock minus ``delta``."""
        return cls.true_time() - cls.delta

    @classmethod
    def fake_sleep(cls, duration):
        """Replacement for ``time.sleep``: advance the fake clock by *duration*."""
        cls.delta -= duration
205 | ||
class AcceptanceTestCase:
    """Mixin that replays recorded market reports against ``main.main``.

    Concrete subclasses must also inherit ``unittest.TestCase``; ``self`` is
    passed to the mocks as the object providing ``assertEqual``. They must
    define ``files``, the list of JSON report files to replay. Each report is
    a JSON list of elements tagged by ``"type"``. It contains one
    ``"market"`` element describing the run, plus any number of
    ``"http_request"`` elements.
    """

    def parse_file(self, report_file):
        """Load *report_file* (floats parsed as ``Decimal``) and split it up.

        :returns: ``(config, user, date, market_id, http_requests, json_content)``.
            The first four come from :meth:`parse_config`, ``http_requests``
            comes from :meth:`parse_requests`, and ``json_content`` is the
            whole decoded report.
        """
        with open(report_file, "rb") as f:
            json_content = json.load(f, parse_float=Decimal)
        config, user, date, market_id = self.parse_config(json_content)
        http_requests = self.parse_requests(json_content)

        return config, user, date, market_id, http_requests, json_content

    def parse_requests(self, json_content):
        """Return the ``"http_request"`` elements of the report, in order."""
        return [element for element in json_content if element["type"] == "http_request"]

    def parse_config(self, json_content):
        """Rebuild the command line and run metadata from the ``"market"`` element.

        If several ``"market"`` elements exist, the last one is used.

        :returns: ``(config, user_id, date, market_id)``. ``config`` is a list
            of CLI arguments and ``date`` is parsed with
            ``%Y-%m-%dT%H:%M:%S.%f``.
        """
        market_info = None
        for element in json_content:
            if element["type"] != "market":
                continue
            market_info = element
        assert market_info is not None, "Couldn't find market element"

        args = market_info["args"]
        config = []
        # Flags that are off by default: emitted only when set.
        for arg in ["before", "after", "quiet", "debug"]:
            if args.get(arg, False):
                config.append("--{}".format(arg))
        # Flags that are on by default: emit the negation when unset.
        for arg in ["parallel", "report_db"]:
            if not args.get(arg, False):
                config.append("--no-{}".format(arg.replace("_", "-")))
        for action in args.get("action", []):
            config.extend(["--action", action])
        if args.get("report_path") is not None:
            config.extend(["--report-path", args.get("report_path")])
        if args.get("user") is not None:
            config.extend(["--user", args.get("user")])
        config.extend(["--config", ""])

        user = market_info["user_id"]
        date = datetime.datetime.strptime(market_info["date"], "%Y-%m-%dT%H:%M:%S.%f")
        market_id = market_info["market_id"]
        return config, user, date, market_id

    def requests_by_market(self):
        """Group the recorded HTTP requests of ``self.reports`` by market.

        Keys are ``str(market_id)`` for each report, plus ``None`` for requests
        whose ``market_id`` is ``None``. Those common requests are taken from
        the first report only, because every report recorded them.
        """
        r = {
            None: []
        }
        got_common = False
        for user_id, market_id, http_requests, report_lines in self.reports.values():
            r[str(market_id)] = []
            for http_request in http_requests:
                if http_request["market_id"] is None:
                    if not got_common:
                        r[None].append(http_request)
                else:
                    r[str(market_id)].append(http_request)
            got_common = True
        return r

    def setUp(self):
        """Parse every report in ``self.files`` and start all the mocks.

        :raises TypeError: if the subclass does not define ``files``.
        """
        if not hasattr(self, "files"):
            # Raising a bare string is itself an error in Python 3 (with an
            # unrelated message), so raise a real exception carrying the text.
            raise TypeError("This class expects to be inherited with a class defining self.files in setUp")

        self.reports = {}
        self.start_date = datetime.datetime.now()
        self.config = []
        for f in self.files:
            # Each file overwrites self.config: the last file's arguments win.
            self.config, user_id, date, market_id, http_requests, report_lines = self.parse_file(f)
            # Keep the earliest report date (or now) as the replay start date.
            if date < self.start_date:
                self.start_date = date
            self.reports[f] = [user_id, market_id, http_requests, report_lines]

        DatabaseMock.start(self.reports, "--no-report-db" not in self.config)
        RequestsMock.start(self.requests_by_market())
        FileMock.start()
        TimeMock.start(self.start_date)

    def base_test(self):
        """Run ``main.main`` with the replayed config and verify every mock."""
        import main
        main.main(self.config)
        RequestsMock.check_calls(self)
        DatabaseMock.check_calls(self)
        FileMock.check_calls(self)

    def tearDown(self):
        """Stop every mock, even if stopping an earlier one fails.

        All four are global patches, so a stop skipped here would leak into
        later tests. The first failure, if any, is re-raised afterwards.
        """
        errors = []
        for stop in (TimeMock.stop, FileMock.stop, RequestsMock.stop, DatabaseMock.stop):
            try:
                stop()
            except Exception as exc:  # keep going; re-raised below
                errors.append(exc)
        if errors:
            raise errors[0]
298 | ||
import glob

# Create one unittest.TestCase subclass per directory below tests/acceptance/
# that contains JSON reports. The class name is the CamelCased relative path,
# and its single test method replays all of the directory's reports.
# Both globs are sorted: glob order depends on the file system, and the replay
# depends on file order (setUp keeps the last file's config, and common
# requests are taken from the first report).
for dirfile in sorted(glob.glob("tests/acceptance/**/*/", recursive=True)):
    json_files = sorted(glob.glob(os.path.join(dirfile, "*.json")))
    if not json_files:
        continue
    name = dirfile.replace("tests/acceptance/", "").replace("/", "_")[0:-1]
    cname = "".join(part.capitalize() for part in name.split("_"))

    globals()[cname] = type(cname, (AcceptanceTestCase, unittest.TestCase), {
        "files": json_files,
        "test_{}".format(name): AcceptanceTestCase.base_test,
    })
311 | ||
if __name__ == "__main__":
    # Executed as a script: hand this module's test cases to the unittest CLI runner.
    unittest.main()
314 |