diff options
Diffstat (limited to 'test_acceptance.py')
-rw-r--r-- | test_acceptance.py | 314 |
1 files changed, 314 insertions, 0 deletions
diff --git a/test_acceptance.py b/test_acceptance.py new file mode 100644 index 0000000..88a2dd4 --- /dev/null +++ b/test_acceptance.py | |||
@@ -0,0 +1,314 @@ | |||
1 | import requests | ||
2 | import requests_mock | ||
3 | import sys, os | ||
4 | import time, datetime | ||
5 | import unittest | ||
6 | from unittest import mock | ||
7 | from ssl import SSLError | ||
8 | from decimal import Decimal | ||
9 | import simplejson as json | ||
10 | import psycopg2 | ||
11 | from io import StringIO | ||
12 | |||
13 | class FileMock: | ||
14 | @classmethod | ||
15 | def start(cls): | ||
16 | cls.file_mock = mock.patch("market.open") | ||
17 | cls.os_mock = mock.patch("os.makedirs") | ||
18 | cls.stdout_mock = mock.patch('sys.stdout', new_callable=StringIO) | ||
19 | cls.stdout_mock.start() | ||
20 | cls.os_mock.start() | ||
21 | cls.file_mock.start() | ||
22 | |||
23 | @classmethod | ||
24 | def check_calls(cls, tester): | ||
25 | pass | ||
26 | #raise NotImplementedError("Todo") | ||
27 | |||
28 | @classmethod | ||
29 | def stop(cls): | ||
30 | cls.file_mock.stop() | ||
31 | cls.stdout_mock.stop() | ||
32 | cls.os_mock.stop() | ||
33 | |||
34 | class DatabaseMock: | ||
35 | rows = [] | ||
36 | report_db = False | ||
37 | db_patch = None | ||
38 | cursor = None | ||
39 | total_report_lines = 0 | ||
40 | db_mock = None | ||
41 | requests = [] | ||
42 | |||
43 | @classmethod | ||
44 | def start(cls, reports, report_db): | ||
45 | cls.report_db = report_db | ||
46 | cls.rows = [] | ||
47 | cls.total_report_lines= 0 | ||
48 | for user_id, market_id, http_requests, report_lines in reports.values(): | ||
49 | cls.rows.append( (market_id, { "key": "key", "secret": "secret" }, user_id) ) | ||
50 | cls.total_report_lines += len(report_lines) | ||
51 | |||
52 | connect_mock = mock.Mock() | ||
53 | cls.cursor = mock.MagicMock() | ||
54 | connect_mock.cursor.return_value = cls.cursor | ||
55 | def _execute(request, *args): | ||
56 | cls.requests.append(request) | ||
57 | cls.cursor.execute.side_effect = _execute | ||
58 | cls.cursor.__iter__.return_value = cls.rows | ||
59 | |||
60 | cls.db_patch = mock.patch("psycopg2.connect") | ||
61 | cls.db_mock = cls.db_patch.start() | ||
62 | cls.db_mock.return_value = connect_mock | ||
63 | |||
64 | @classmethod | ||
65 | def check_calls(cls, tester): | ||
66 | if cls.report_db: | ||
67 | tester.assertEqual(1 + len(cls.rows), cls.db_mock.call_count) | ||
68 | tester.assertEqual(1 + len(cls.rows) + cls.total_report_lines, cls.cursor.execute.call_count) | ||
69 | else: | ||
70 | tester.assertEqual(1, cls.db_mock.call_count) | ||
71 | tester.assertEqual(1, cls.cursor.execute.call_count) | ||
72 | |||
73 | @classmethod | ||
74 | def stop(cls): | ||
75 | cls.db_patch.stop() | ||
76 | cls.db_mock = None | ||
77 | cls.cursor = None | ||
78 | cls.total_report_lines = 0 | ||
79 | cls.rows = [] | ||
80 | cls.report_db = False | ||
81 | cls.requests = [] | ||
82 | |||
83 | class RequestsMock: | ||
84 | adapter = None | ||
85 | mocks = {} | ||
86 | last_https = {} | ||
87 | request_patch = [] | ||
88 | |||
89 | @classmethod | ||
90 | def start(cls, reports): | ||
91 | cls.adapter = requests_mock.Adapter() | ||
92 | cls.adapter.register_uri(requests_mock.ANY, requests_mock.ANY, | ||
93 | exc=requests_mock.exceptions.MockException("Not stubbed URL")) | ||
94 | true_session = requests.Session | ||
95 | |||
96 | cls.mocks = {} | ||
97 | |||
98 | for market_id, elements in reports.items(): | ||
99 | for element in elements: | ||
100 | method = element["method"] | ||
101 | url = element["url"] | ||
102 | cls.mocks \ | ||
103 | .setdefault((method, url), {}) \ | ||
104 | .setdefault(market_id, []) \ | ||
105 | .append(element) | ||
106 | |||
107 | for ((method, url), elements) in cls.mocks.items(): | ||
108 | cls.adapter.register_uri(method, url, text=cls.callback_func(elements), complete_qs=True) | ||
109 | def _session(): | ||
110 | session = true_session() | ||
111 | session.get_adapter = lambda url: cls.adapter | ||
112 | return session | ||
113 | cls.request_patch = [ | ||
114 | mock.patch.object(requests.sessions, "Session", new=_session), | ||
115 | mock.patch.object(requests, "Session", new=_session) | ||
116 | ] | ||
117 | for patch in cls.request_patch: | ||
118 | patch.start() | ||
119 | |||
120 | @classmethod | ||
121 | def stop(cls): | ||
122 | for patch in cls.request_patch: | ||
123 | patch.stop() | ||
124 | cls.request_patch = [] | ||
125 | cls.last_https = {} | ||
126 | cls.mocks = {} | ||
127 | cls.adapter = None | ||
128 | |||
129 | @classmethod | ||
130 | def check_calls(cls, tester): | ||
131 | for (method, url), elements in cls.mocks.items(): | ||
132 | for market_id, element in elements.items(): | ||
133 | tester.assertEqual(0, len(element), "Missing calls to {} {}, market_id {}".format(method, url, market_id)) | ||
134 | |||
135 | @classmethod | ||
136 | def clean_body(cls, body): | ||
137 | if body is None: | ||
138 | return None | ||
139 | import re | ||
140 | if isinstance(body, bytes): | ||
141 | body = body.decode() | ||
142 | body = re.sub(r"&nonce=\d*$", "", body) | ||
143 | body = re.sub(r"nonce=\d*&?", "", body) | ||
144 | return body | ||
145 | |||
146 | @classmethod | ||
147 | def callback_func(cls, elements): | ||
148 | def callback(request, context): | ||
149 | try: | ||
150 | element = elements[request.headers.get("X-market-id")].pop(0) | ||
151 | except (IndexError, KeyError): | ||
152 | raise RuntimeError("Unexpected call") | ||
153 | if element["response"] is None and element["response_same_as"] is not None: | ||
154 | element["response"] = cls.last_https[element["response_same_as"]] | ||
155 | elif element["response"] is not None: | ||
156 | cls.last_https[element["date"]] = element["response"] | ||
157 | |||
158 | assert cls.clean_body(request.body) == \ | ||
159 | cls.clean_body(element["body"]), "Body does not match" | ||
160 | context.status_code = element["status"] | ||
161 | if "error" in element: | ||
162 | if element["error"] == "SSLError": | ||
163 | raise SSLError(element["error_message"]) | ||
164 | else: | ||
165 | raise getattr(requests.exceptions, element["error"])(element["error_message"]) | ||
166 | return element["response"] | ||
167 | return callback | ||
168 | |||
169 | class TimeMock: | ||
170 | delta = 0 | ||
171 | true_time = time.time | ||
172 | time_patch = None | ||
173 | datetime_patch = None | ||
174 | |||
175 | @classmethod | ||
176 | def start(cls, start_date): | ||
177 | cls.delta = (datetime.datetime.now() - start_date).total_seconds() | ||
178 | |||
179 | class fake_datetime(datetime.datetime): | ||
180 | @classmethod | ||
181 | def now(cls, tz=None): | ||
182 | if tz is None: | ||
183 | return cls.fromtimestamp(time.time()) | ||
184 | else: | ||
185 | return tz.fromutc(cls.utcfromtimestamp(time.time()).replace(tzinfo=tz)) | ||
186 | |||
187 | cls.time_patch = mock.patch.multiple(time, time=cls.fake_time, sleep=cls.fake_sleep) | ||
188 | cls.datetime_patch = mock.patch.multiple(datetime, datetime=fake_datetime) | ||
189 | cls.time_patch.start() | ||
190 | cls.datetime_patch.start() | ||
191 | |||
192 | @classmethod | ||
193 | def stop(cls): | ||
194 | cls.delta = 0 | ||
195 | cls.datetime_patch.stop() | ||
196 | cls.time_patch.stop() | ||
197 | |||
198 | @classmethod | ||
199 | def fake_time(cls): | ||
200 | return cls.true_time() - cls.delta | ||
201 | |||
202 | @classmethod | ||
203 | def fake_sleep(cls, duration): | ||
204 | cls.delta -= duration | ||
205 | |||
206 | class AcceptanceTestCase(): | ||
207 | def parse_file(self, report_file): | ||
208 | with open(report_file, "rb") as f: | ||
209 | json_content = json.load(f, parse_float=Decimal) | ||
210 | config, user, date, market_id = self.parse_config(json_content) | ||
211 | http_requests = self.parse_requests(json_content) | ||
212 | |||
213 | return config, user, date, market_id, http_requests, json_content | ||
214 | |||
215 | def parse_requests(self, json_content): | ||
216 | http_requests = [] | ||
217 | for element in json_content: | ||
218 | if element["type"] != "http_request": | ||
219 | continue | ||
220 | http_requests.append(element) | ||
221 | return http_requests | ||
222 | |||
223 | def parse_config(self, json_content): | ||
224 | market_info = None | ||
225 | for element in json_content: | ||
226 | if element["type"] != "market": | ||
227 | continue | ||
228 | market_info = element | ||
229 | assert market_info is not None, "Couldn't find market element" | ||
230 | |||
231 | args = market_info["args"] | ||
232 | config = [] | ||
233 | for arg in ["before", "after", "quiet", "debug"]: | ||
234 | if args.get(arg, False): | ||
235 | config.append("--{}".format(arg)) | ||
236 | for arg in ["parallel", "report_db"]: | ||
237 | if not args.get(arg, False): | ||
238 | config.append("--no-{}".format(arg.replace("_", "-"))) | ||
239 | for action in args.get("action", []): | ||
240 | config.extend(["--action", action]) | ||
241 | if args.get("report_path") is not None: | ||
242 | config.extend(["--report-path", args.get("report_path")]) | ||
243 | if args.get("user") is not None: | ||
244 | config.extend(["--user", args.get("user")]) | ||
245 | config.extend(["--config", ""]) | ||
246 | |||
247 | user = market_info["user_id"] | ||
248 | date = datetime.datetime.strptime(market_info["date"], "%Y-%m-%dT%H:%M:%S.%f") | ||
249 | market_id = market_info["market_id"] | ||
250 | return config, user, date, market_id | ||
251 | |||
252 | def requests_by_market(self): | ||
253 | r = { | ||
254 | None: [] | ||
255 | } | ||
256 | got_common = False | ||
257 | for user_id, market_id, http_requests, report_lines in self.reports.values(): | ||
258 | r[str(market_id)] = [] | ||
259 | for http_request in http_requests: | ||
260 | if http_request["market_id"] is None: | ||
261 | if not got_common: | ||
262 | r[None].append(http_request) | ||
263 | else: | ||
264 | r[str(market_id)].append(http_request) | ||
265 | got_common = True | ||
266 | return r | ||
267 | |||
268 | def setUp(self): | ||
269 | if not hasattr(self, "files"): | ||
270 | raise "This class expects to be inherited with a class defining self.files in setUp" | ||
271 | |||
272 | self.reports = {} | ||
273 | self.start_date = datetime.datetime.now() | ||
274 | self.config = [] | ||
275 | for f in self.files: | ||
276 | self.config, user_id, date, market_id, http_requests, report_lines = self.parse_file(f) | ||
277 | if date < self.start_date: | ||
278 | self.start_date = date | ||
279 | self.reports[f] = [user_id, market_id, http_requests, report_lines] | ||
280 | |||
281 | DatabaseMock.start(self.reports, "--no-report-db" not in self.config) | ||
282 | RequestsMock.start(self.requests_by_market()) | ||
283 | FileMock.start() | ||
284 | TimeMock.start(self.start_date) | ||
285 | |||
286 | def base_test(self): | ||
287 | import main | ||
288 | main.main(self.config) | ||
289 | RequestsMock.check_calls(self) | ||
290 | DatabaseMock.check_calls(self) | ||
291 | FileMock.check_calls(self) | ||
292 | |||
293 | def tearDown(self): | ||
294 | TimeMock.stop() | ||
295 | FileMock.stop() | ||
296 | RequestsMock.stop() | ||
297 | DatabaseMock.stop() | ||
298 | |||
299 | import glob | ||
300 | for dirfile in glob.glob("tests/acceptance/**/*/", recursive=True): | ||
301 | json_files = glob.glob("{}/*.json".format(dirfile)) | ||
302 | if len(json_files) > 0: | ||
303 | name = dirfile.replace("tests/acceptance/", "").replace("/", "_")[0:-1] | ||
304 | cname = "".join(list(map(lambda x: x.capitalize(), name.split("_")))) | ||
305 | |||
306 | globals()[cname] = type(cname, | ||
307 | (AcceptanceTestCase,unittest.TestCase), { | ||
308 | "files": json_files, | ||
309 | "test_{}".format(name): AcceptanceTestCase.base_test | ||
310 | }) | ||
311 | |||
312 | if __name__ == '__main__': | ||
313 | unittest.main() | ||
314 | |||