aboutsummaryrefslogtreecommitdiff
path: root/test_acceptance.py
diff options
context:
space:
mode:
Diffstat (limited to 'test_acceptance.py')
-rw-r--r--test_acceptance.py314
1 files changed, 314 insertions, 0 deletions
diff --git a/test_acceptance.py b/test_acceptance.py
new file mode 100644
index 0000000..88a2dd4
--- /dev/null
+++ b/test_acceptance.py
@@ -0,0 +1,314 @@
1import requests
2import requests_mock
3import sys, os
4import time, datetime
5import unittest
6from unittest import mock
7from ssl import SSLError
8from decimal import Decimal
9import simplejson as json
10import psycopg2
11from io import StringIO
12
13class FileMock:
14 @classmethod
15 def start(cls):
16 cls.file_mock = mock.patch("market.open")
17 cls.os_mock = mock.patch("os.makedirs")
18 cls.stdout_mock = mock.patch('sys.stdout', new_callable=StringIO)
19 cls.stdout_mock.start()
20 cls.os_mock.start()
21 cls.file_mock.start()
22
23 @classmethod
24 def check_calls(cls, tester):
25 pass
26 #raise NotImplementedError("Todo")
27
28 @classmethod
29 def stop(cls):
30 cls.file_mock.stop()
31 cls.stdout_mock.stop()
32 cls.os_mock.stop()
33
34class DatabaseMock:
35 rows = []
36 report_db = False
37 db_patch = None
38 cursor = None
39 total_report_lines = 0
40 db_mock = None
41 requests = []
42
43 @classmethod
44 def start(cls, reports, report_db):
45 cls.report_db = report_db
46 cls.rows = []
47 cls.total_report_lines= 0
48 for user_id, market_id, http_requests, report_lines in reports.values():
49 cls.rows.append( (market_id, { "key": "key", "secret": "secret" }, user_id) )
50 cls.total_report_lines += len(report_lines)
51
52 connect_mock = mock.Mock()
53 cls.cursor = mock.MagicMock()
54 connect_mock.cursor.return_value = cls.cursor
55 def _execute(request, *args):
56 cls.requests.append(request)
57 cls.cursor.execute.side_effect = _execute
58 cls.cursor.__iter__.return_value = cls.rows
59
60 cls.db_patch = mock.patch("psycopg2.connect")
61 cls.db_mock = cls.db_patch.start()
62 cls.db_mock.return_value = connect_mock
63
64 @classmethod
65 def check_calls(cls, tester):
66 if cls.report_db:
67 tester.assertEqual(1 + len(cls.rows), cls.db_mock.call_count)
68 tester.assertEqual(1 + len(cls.rows) + cls.total_report_lines, cls.cursor.execute.call_count)
69 else:
70 tester.assertEqual(1, cls.db_mock.call_count)
71 tester.assertEqual(1, cls.cursor.execute.call_count)
72
73 @classmethod
74 def stop(cls):
75 cls.db_patch.stop()
76 cls.db_mock = None
77 cls.cursor = None
78 cls.total_report_lines = 0
79 cls.rows = []
80 cls.report_db = False
81 cls.requests = []
82
83class RequestsMock:
84 adapter = None
85 mocks = {}
86 last_https = {}
87 request_patch = []
88
89 @classmethod
90 def start(cls, reports):
91 cls.adapter = requests_mock.Adapter()
92 cls.adapter.register_uri(requests_mock.ANY, requests_mock.ANY,
93 exc=requests_mock.exceptions.MockException("Not stubbed URL"))
94 true_session = requests.Session
95
96 cls.mocks = {}
97
98 for market_id, elements in reports.items():
99 for element in elements:
100 method = element["method"]
101 url = element["url"]
102 cls.mocks \
103 .setdefault((method, url), {}) \
104 .setdefault(market_id, []) \
105 .append(element)
106
107 for ((method, url), elements) in cls.mocks.items():
108 cls.adapter.register_uri(method, url, text=cls.callback_func(elements), complete_qs=True)
109 def _session():
110 session = true_session()
111 session.get_adapter = lambda url: cls.adapter
112 return session
113 cls.request_patch = [
114 mock.patch.object(requests.sessions, "Session", new=_session),
115 mock.patch.object(requests, "Session", new=_session)
116 ]
117 for patch in cls.request_patch:
118 patch.start()
119
120 @classmethod
121 def stop(cls):
122 for patch in cls.request_patch:
123 patch.stop()
124 cls.request_patch = []
125 cls.last_https = {}
126 cls.mocks = {}
127 cls.adapter = None
128
129 @classmethod
130 def check_calls(cls, tester):
131 for (method, url), elements in cls.mocks.items():
132 for market_id, element in elements.items():
133 tester.assertEqual(0, len(element), "Missing calls to {} {}, market_id {}".format(method, url, market_id))
134
135 @classmethod
136 def clean_body(cls, body):
137 if body is None:
138 return None
139 import re
140 if isinstance(body, bytes):
141 body = body.decode()
142 body = re.sub(r"&nonce=\d*$", "", body)
143 body = re.sub(r"nonce=\d*&?", "", body)
144 return body
145
146 @classmethod
147 def callback_func(cls, elements):
148 def callback(request, context):
149 try:
150 element = elements[request.headers.get("X-market-id")].pop(0)
151 except (IndexError, KeyError):
152 raise RuntimeError("Unexpected call")
153 if element["response"] is None and element["response_same_as"] is not None:
154 element["response"] = cls.last_https[element["response_same_as"]]
155 elif element["response"] is not None:
156 cls.last_https[element["date"]] = element["response"]
157
158 assert cls.clean_body(request.body) == \
159 cls.clean_body(element["body"]), "Body does not match"
160 context.status_code = element["status"]
161 if "error" in element:
162 if element["error"] == "SSLError":
163 raise SSLError(element["error_message"])
164 else:
165 raise getattr(requests.exceptions, element["error"])(element["error_message"])
166 return element["response"]
167 return callback
168
169class TimeMock:
170 delta = 0
171 true_time = time.time
172 time_patch = None
173 datetime_patch = None
174
175 @classmethod
176 def start(cls, start_date):
177 cls.delta = (datetime.datetime.now() - start_date).total_seconds()
178
179 class fake_datetime(datetime.datetime):
180 @classmethod
181 def now(cls, tz=None):
182 if tz is None:
183 return cls.fromtimestamp(time.time())
184 else:
185 return tz.fromutc(cls.utcfromtimestamp(time.time()).replace(tzinfo=tz))
186
187 cls.time_patch = mock.patch.multiple(time, time=cls.fake_time, sleep=cls.fake_sleep)
188 cls.datetime_patch = mock.patch.multiple(datetime, datetime=fake_datetime)
189 cls.time_patch.start()
190 cls.datetime_patch.start()
191
192 @classmethod
193 def stop(cls):
194 cls.delta = 0
195 cls.datetime_patch.stop()
196 cls.time_patch.stop()
197
198 @classmethod
199 def fake_time(cls):
200 return cls.true_time() - cls.delta
201
202 @classmethod
203 def fake_sleep(cls, duration):
204 cls.delta -= duration
205
206class AcceptanceTestCase():
207 def parse_file(self, report_file):
208 with open(report_file, "rb") as f:
209 json_content = json.load(f, parse_float=Decimal)
210 config, user, date, market_id = self.parse_config(json_content)
211 http_requests = self.parse_requests(json_content)
212
213 return config, user, date, market_id, http_requests, json_content
214
215 def parse_requests(self, json_content):
216 http_requests = []
217 for element in json_content:
218 if element["type"] != "http_request":
219 continue
220 http_requests.append(element)
221 return http_requests
222
223 def parse_config(self, json_content):
224 market_info = None
225 for element in json_content:
226 if element["type"] != "market":
227 continue
228 market_info = element
229 assert market_info is not None, "Couldn't find market element"
230
231 args = market_info["args"]
232 config = []
233 for arg in ["before", "after", "quiet", "debug"]:
234 if args.get(arg, False):
235 config.append("--{}".format(arg))
236 for arg in ["parallel", "report_db"]:
237 if not args.get(arg, False):
238 config.append("--no-{}".format(arg.replace("_", "-")))
239 for action in args.get("action", []):
240 config.extend(["--action", action])
241 if args.get("report_path") is not None:
242 config.extend(["--report-path", args.get("report_path")])
243 if args.get("user") is not None:
244 config.extend(["--user", args.get("user")])
245 config.extend(["--config", ""])
246
247 user = market_info["user_id"]
248 date = datetime.datetime.strptime(market_info["date"], "%Y-%m-%dT%H:%M:%S.%f")
249 market_id = market_info["market_id"]
250 return config, user, date, market_id
251
252 def requests_by_market(self):
253 r = {
254 None: []
255 }
256 got_common = False
257 for user_id, market_id, http_requests, report_lines in self.reports.values():
258 r[str(market_id)] = []
259 for http_request in http_requests:
260 if http_request["market_id"] is None:
261 if not got_common:
262 r[None].append(http_request)
263 else:
264 r[str(market_id)].append(http_request)
265 got_common = True
266 return r
267
268 def setUp(self):
269 if not hasattr(self, "files"):
270 raise "This class expects to be inherited with a class defining self.files in setUp"
271
272 self.reports = {}
273 self.start_date = datetime.datetime.now()
274 self.config = []
275 for f in self.files:
276 self.config, user_id, date, market_id, http_requests, report_lines = self.parse_file(f)
277 if date < self.start_date:
278 self.start_date = date
279 self.reports[f] = [user_id, market_id, http_requests, report_lines]
280
281 DatabaseMock.start(self.reports, "--no-report-db" not in self.config)
282 RequestsMock.start(self.requests_by_market())
283 FileMock.start()
284 TimeMock.start(self.start_date)
285
286 def base_test(self):
287 import main
288 main.main(self.config)
289 RequestsMock.check_calls(self)
290 DatabaseMock.check_calls(self)
291 FileMock.check_calls(self)
292
293 def tearDown(self):
294 TimeMock.stop()
295 FileMock.stop()
296 RequestsMock.stop()
297 DatabaseMock.stop()
298
299import glob
300for dirfile in glob.glob("tests/acceptance/**/*/", recursive=True):
301 json_files = glob.glob("{}/*.json".format(dirfile))
302 if len(json_files) > 0:
303 name = dirfile.replace("tests/acceptance/", "").replace("/", "_")[0:-1]
304 cname = "".join(list(map(lambda x: x.capitalize(), name.split("_"))))
305
306 globals()[cname] = type(cname,
307 (AcceptanceTestCase,unittest.TestCase), {
308 "files": json_files,
309 "test_{}".format(name): AcceptanceTestCase.base_test
310 })
311
312if __name__ == '__main__':
313 unittest.main()
314