aboutsummaryrefslogtreecommitdiff
path: root/test_acceptance.py
diff options
context:
space:
mode:
Diffstat (limited to 'test_acceptance.py')
-rw-r--r--test_acceptance.py321
1 files changed, 184 insertions, 137 deletions
diff --git a/test_acceptance.py b/test_acceptance.py
index 88a2dd4..3633928 100644
--- a/test_acceptance.py
+++ b/test_acceptance.py
@@ -9,166 +9,203 @@ from decimal import Decimal
9import simplejson as json 9import simplejson as json
10import psycopg2 10import psycopg2
11from io import StringIO 11from io import StringIO
12import re
13import functools
14import glob
15
16import main
12 17
13class FileMock: 18class FileMock:
14 @classmethod 19 def __init__(self, log_files, quiet, tester):
15 def start(cls): 20 self.tester = tester
16 cls.file_mock = mock.patch("market.open") 21 self.log_files = []
17 cls.os_mock = mock.patch("os.makedirs") 22 if log_files is not None and len(log_files) > 0:
18 cls.stdout_mock = mock.patch('sys.stdout', new_callable=StringIO) 23 self.read_log_files(log_files)
19 cls.stdout_mock.start() 24 self.quiet = quiet
20 cls.os_mock.start() 25 self.patches = [
21 cls.file_mock.start() 26 mock.patch("market.open"),
27 mock.patch("os.makedirs"),
28 mock.patch("sys.stdout", new_callable=StringIO),
29 ]
30 self.mocks = []
22 31
23 @classmethod 32 def start(self):
24 def check_calls(cls, tester): 33 for patch in self.patches:
25 pass 34 self.mocks.append(patch.start())
26 #raise NotImplementedError("Todo") 35 self.stdout = self.mocks[-1]
27 36
28 @classmethod 37 def check_calls(self):
29 def stop(cls): 38 stdout = self.stdout.getvalue()
30 cls.file_mock.stop() 39 if self.quiet:
31 cls.stdout_mock.stop() 40 self.tester.assertEqual("", stdout)
32 cls.os_mock.stop() 41 else:
42 log = self.strip_log(stdout)
43 if len(self.log_files) != 0:
44 split_logs = log.split("\n")
45 self.tester.assertEqual(sum(len(f) for f in self.log_files), len(split_logs))
46 try:
47 for log_file in self.log_files:
48 for line in log_file:
49 split_logs.pop(split_logs.index(line))
50 except ValueError:
51 if not line.startswith("[Worker] "):
52 self.tester.fail("« {} » not found in log file {}".format(line, split_logs))
53 # Le fichier de log est écrit
54 # Le fichier de log est printed uniquement si non quiet
55 # Le rapport est écrit si pertinent
56 # Le rapport contient le bon nombre de lignes
57
58 def stop(self):
59 for patch in self.patches[::-1]:
60 patch.stop()
61 self.mocks.pop()
33 62
34class DatabaseMock: 63 def strip_log(self, log):
35 rows = [] 64 log = log.replace("\n\n", "\n")
36 report_db = False 65 return re.sub(r"^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}: ", "", log, flags=re.MULTILINE)
37 db_patch = None
38 cursor = None
39 total_report_lines = 0
40 db_mock = None
41 requests = []
42 66
43 @classmethod 67 def read_log_files(self, log_files):
44 def start(cls, reports, report_db): 68 for log_file in log_files:
45 cls.report_db = report_db 69 with open(log_file, "r") as f:
46 cls.rows = [] 70 log = self.strip_log(f.read()).split("\n")
47 cls.total_report_lines= 0 71 if len(log[-1]) == 0:
48 for user_id, market_id, http_requests, report_lines in reports.values(): 72 log.pop()
49 cls.rows.append( (market_id, { "key": "key", "secret": "secret" }, user_id) ) 73 self.log_files.append(log)
50 cls.total_report_lines += len(report_lines) 74
75class DatabaseMock:
76 def __init__(self, tester, reports, report_db):
77 self.tester = tester
78 self.reports = reports
79 self.report_db = report_db
80 self.rows = []
81 self.total_report_lines= 0
82 self.requests = []
83 for user_id, market_id, http_requests, report_lines in self.reports.values():
84 self.rows.append( (market_id, { "key": "key", "secret": "secret" }, user_id) )
85 self.total_report_lines += len(report_lines)
51 86
87 def start(self):
52 connect_mock = mock.Mock() 88 connect_mock = mock.Mock()
53 cls.cursor = mock.MagicMock() 89 self.cursor = mock.MagicMock()
54 connect_mock.cursor.return_value = cls.cursor 90 connect_mock.cursor.return_value = self.cursor
55 def _execute(request, *args): 91 def _execute(request, *args):
56 cls.requests.append(request) 92 self.requests.append(request)
57 cls.cursor.execute.side_effect = _execute 93 self.cursor.execute.side_effect = _execute
58 cls.cursor.__iter__.return_value = cls.rows 94 self.cursor.__iter__.return_value = self.rows
59 95
60 cls.db_patch = mock.patch("psycopg2.connect") 96 self.db_patch = mock.patch("psycopg2.connect")
61 cls.db_mock = cls.db_patch.start() 97 self.db_mock = self.db_patch.start()
62 cls.db_mock.return_value = connect_mock 98 self.db_mock.return_value = connect_mock
63 99
64 @classmethod 100 def check_calls(self):
65 def check_calls(cls, tester): 101 if self.report_db:
66 if cls.report_db: 102 self.tester.assertEqual(1 + len(self.rows), self.db_mock.call_count)
67 tester.assertEqual(1 + len(cls.rows), cls.db_mock.call_count) 103 self.tester.assertEqual(1 + len(self.rows) + self.total_report_lines, self.cursor.execute.call_count)
68 tester.assertEqual(1 + len(cls.rows) + cls.total_report_lines, cls.cursor.execute.call_count)
69 else: 104 else:
70 tester.assertEqual(1, cls.db_mock.call_count) 105 self.tester.assertEqual(1, self.db_mock.call_count)
71 tester.assertEqual(1, cls.cursor.execute.call_count) 106 self.tester.assertEqual(1, self.cursor.execute.call_count)
72 107
73 @classmethod 108 def stop(self):
74 def stop(cls): 109 self.db_patch.stop()
75 cls.db_patch.stop()
76 cls.db_mock = None
77 cls.cursor = None
78 cls.total_report_lines = 0
79 cls.rows = []
80 cls.report_db = False
81 cls.requests = []
82 110
83class RequestsMock: 111class RequestsMock:
84 adapter = None 112 def __init__(self, tester):
85 mocks = {} 113 self.tester = tester
86 last_https = {} 114 self.reports = tester.requests_by_market()
87 request_patch = [] 115
88 116 self.last_https = {}
89 @classmethod 117 self.error_calls = []
90 def start(cls, reports): 118 self.mocker = requests_mock.Mocker()
91 cls.adapter = requests_mock.Adapter() 119 def not_stubbed(*args):
92 cls.adapter.register_uri(requests_mock.ANY, requests_mock.ANY, 120 self.error_calls.append([args[0].method, args[0].url])
93 exc=requests_mock.exceptions.MockException("Not stubbed URL")) 121 raise requests_mock.exceptions.MockException("Not stubbed URL")
94 true_session = requests.Session 122 self.mocker.register_uri(requests_mock.ANY, requests_mock.ANY,
95 123 text=not_stubbed)
96 cls.mocks = {} 124
97 125 self.mocks = {}
98 for market_id, elements in reports.items(): 126
127 for market_id, elements in self.reports.items():
99 for element in elements: 128 for element in elements:
100 method = element["method"] 129 method = element["method"]
101 url = element["url"] 130 url = element["url"]
102 cls.mocks \ 131 self.mocks \
103 .setdefault((method, url), {}) \ 132 .setdefault((method, url), {}) \
104 .setdefault(market_id, []) \ 133 .setdefault(market_id, []) \
105 .append(element) 134 .append(element)
106 135
107 for ((method, url), elements) in cls.mocks.items(): 136 for ((method, url), elements) in self.mocks.items():
108 cls.adapter.register_uri(method, url, text=cls.callback_func(elements), complete_qs=True) 137 self.mocker.register_uri(method, url, text=functools.partial(callback, self, elements), complete_qs=True)
109 def _session():
110 session = true_session()
111 session.get_adapter = lambda url: cls.adapter
112 return session
113 cls.request_patch = [
114 mock.patch.object(requests.sessions, "Session", new=_session),
115 mock.patch.object(requests, "Session", new=_session)
116 ]
117 for patch in cls.request_patch:
118 patch.start()
119 138
120 @classmethod 139 def start(self):
121 def stop(cls): 140 self.mocker.start()
122 for patch in cls.request_patch:
123 patch.stop()
124 cls.request_patch = []
125 cls.last_https = {}
126 cls.mocks = {}
127 cls.adapter = None
128 141
129 @classmethod 142 def stop(self):
130 def check_calls(cls, tester): 143 self.mocker.stop()
131 for (method, url), elements in cls.mocks.items(): 144
145 def check_calls(self):
146 self.tester.assertEqual([], self.error_calls)
147 for (method, url), elements in self.mocks.items():
132 for market_id, element in elements.items(): 148 for market_id, element in elements.items():
133 tester.assertEqual(0, len(element), "Missing calls to {} {}, market_id {}".format(method, url, market_id)) 149 self.tester.assertEqual(0, len(element), "Missing calls to {} {}, market_id {}".format(method, url, market_id))
134 150
135 @classmethod 151 def clean_body(self, body):
136 def clean_body(cls, body):
137 if body is None: 152 if body is None:
138 return None 153 return None
139 import re
140 if isinstance(body, bytes): 154 if isinstance(body, bytes):
141 body = body.decode() 155 body = body.decode()
142 body = re.sub(r"&nonce=\d*$", "", body) 156 body = re.sub(r"&nonce=\d*$", "", body)
143 body = re.sub(r"nonce=\d*&?", "", body) 157 body = re.sub(r"nonce=\d*&?", "", body)
144 return body 158 return body
145 159
146 @classmethod 160def callback(self, elements, request, context):
147 def callback_func(cls, elements): 161 try:
148 def callback(request, context): 162 element = elements[request.headers.get("X-market-id")].pop(0)
149 try: 163 except (IndexError, KeyError):
150 element = elements[request.headers.get("X-market-id")].pop(0) 164 self.error_calls.append([request.method, request.url, request.headers.get("X-market-id")])
151 except (IndexError, KeyError): 165 raise RuntimeError("Unexpected call")
152 raise RuntimeError("Unexpected call") 166 if element["response"] is None and element["response_same_as"] is not None:
153 if element["response"] is None and element["response_same_as"] is not None: 167 element["response"] = self.last_https[element["response_same_as"]]
154 element["response"] = cls.last_https[element["response_same_as"]] 168 elif element["response"] is not None:
155 elif element["response"] is not None: 169 self.last_https[element["date"]] = element["response"]
156 cls.last_https[element["date"]] = element["response"] 170
157 171 assert self.clean_body(request.body) == \
158 assert cls.clean_body(request.body) == \ 172 self.clean_body(element["body"]), "Body does not match"
159 cls.clean_body(element["body"]), "Body does not match" 173 context.status_code = element["status"]
160 context.status_code = element["status"] 174 if "error" in element:
161 if "error" in element: 175 if element["error"] == "SSLError":
162 if element["error"] == "SSLError": 176 raise SSLError(element["error_message"])
163 raise SSLError(element["error_message"]) 177 else:
164 else: 178 raise getattr(requests.exceptions, element["error"])(element["error_message"])
165 raise getattr(requests.exceptions, element["error"])(element["error_message"]) 179 return element["response"]
166 return element["response"] 180
167 return callback 181class GlobalVariablesMock:
182 def start(self):
183 import market
184 import store
185
186 self.patchers = [
187 mock.patch.multiple(market.Portfolio,
188 data=store.LockedVar(None),
189 liquidities=store.LockedVar({}),
190 last_date=store.LockedVar(None),
191 report=store.LockedVar(store.ReportStore(None, no_http_dup=True)),
192 worker=None,
193 worker_tag="",
194 worker_notify=None,
195 worker_started=False,
196 callback=None)
197 ]
198 for patcher in self.patchers:
199 patcher.start()
200
201 def stop(self):
202 pass
203
168 204
169class TimeMock: 205class TimeMock:
170 delta = 0 206 delta = 0
171 true_time = time.time 207 true_time = time.time
208 true_sleep = time.sleep
172 time_patch = None 209 time_patch = None
173 datetime_patch = None 210 datetime_patch = None
174 211
@@ -202,6 +239,7 @@ class TimeMock:
202 @classmethod 239 @classmethod
203 def fake_sleep(cls, duration): 240 def fake_sleep(cls, duration):
204 cls.delta -= duration 241 cls.delta -= duration
242 cls.true_sleep(0.2)
205 243
206class AcceptanceTestCase(): 244class AcceptanceTestCase():
207 def parse_file(self, report_file): 245 def parse_file(self, report_file):
@@ -236,7 +274,7 @@ class AcceptanceTestCase():
236 for arg in ["parallel", "report_db"]: 274 for arg in ["parallel", "report_db"]:
237 if not args.get(arg, False): 275 if not args.get(arg, False):
238 config.append("--no-{}".format(arg.replace("_", "-"))) 276 config.append("--no-{}".format(arg.replace("_", "-")))
239 for action in args.get("action", []): 277 for action in (args.get("action", []) or []):
240 config.extend(["--action", action]) 278 config.extend(["--action", action])
241 if args.get("report_path") is not None: 279 if args.get("report_path") is not None:
242 config.extend(["--report-path", args.get("report_path")]) 280 config.extend(["--report-path", args.get("report_path")])
@@ -267,7 +305,9 @@ class AcceptanceTestCase():
267 305
268 def setUp(self): 306 def setUp(self):
269 if not hasattr(self, "files"): 307 if not hasattr(self, "files"):
270 raise "This class expects to be inherited with a class defining self.files in setUp" 308 raise "This class expects to be inherited with a class defining 'files' variable"
309 if not hasattr(self, "log_files"):
310 self.log_files = []
271 311
272 self.reports = {} 312 self.reports = {}
273 self.start_date = datetime.datetime.now() 313 self.start_date = datetime.datetime.now()
@@ -278,33 +318,40 @@ class AcceptanceTestCase():
278 self.start_date = date 318 self.start_date = date
279 self.reports[f] = [user_id, market_id, http_requests, report_lines] 319 self.reports[f] = [user_id, market_id, http_requests, report_lines]
280 320
281 DatabaseMock.start(self.reports, "--no-report-db" not in self.config) 321 self.database_mock = DatabaseMock(self, self.reports, "--no-report-db" not in self.config)
282 RequestsMock.start(self.requests_by_market()) 322 self.requests_mock = RequestsMock(self)
283 FileMock.start() 323 self.file_mock = FileMock(self.log_files, "--quiet" in self.config, self)
324 self.global_variables_mock = GlobalVariablesMock()
325
326 self.database_mock.start()
327 self.requests_mock.start()
328 self.file_mock.start()
329 self.global_variables_mock.start()
284 TimeMock.start(self.start_date) 330 TimeMock.start(self.start_date)
285 331
286 def base_test(self): 332 def base_test(self):
287 import main
288 main.main(self.config) 333 main.main(self.config)
289 RequestsMock.check_calls(self) 334 self.requests_mock.check_calls()
290 DatabaseMock.check_calls(self) 335 self.database_mock.check_calls()
291 FileMock.check_calls(self) 336 self.file_mock.check_calls()
292 337
293 def tearDown(self): 338 def tearDown(self):
294 TimeMock.stop() 339 TimeMock.stop()
295 FileMock.stop() 340 self.global_variables_mock.stop()
296 RequestsMock.stop() 341 self.file_mock.stop()
297 DatabaseMock.stop() 342 self.requests_mock.stop()
343 self.database_mock.stop()
298 344
299import glob
300for dirfile in glob.glob("tests/acceptance/**/*/", recursive=True): 345for dirfile in glob.glob("tests/acceptance/**/*/", recursive=True):
301 json_files = glob.glob("{}/*.json".format(dirfile)) 346 json_files = glob.glob("{}/*.json".format(dirfile))
347 log_files = glob.glob("{}/*.log".format(dirfile))
302 if len(json_files) > 0: 348 if len(json_files) > 0:
303 name = dirfile.replace("tests/acceptance/", "").replace("/", "_")[0:-1] 349 name = dirfile.replace("tests/acceptance/", "").replace("/", "_")[0:-1]
304 cname = "".join(list(map(lambda x: x.capitalize(), name.split("_")))) 350 cname = "".join(list(map(lambda x: x.capitalize(), name.split("_"))))
305 351
306 globals()[cname] = type(cname, 352 globals()[cname] = type(cname,
307 (AcceptanceTestCase,unittest.TestCase), { 353 (AcceptanceTestCase,unittest.TestCase), {
354 "log_files": log_files,
308 "files": json_files, 355 "files": json_files,
309 "test_{}".format(name): AcceptanceTestCase.base_test 356 "test_{}".format(name): AcceptanceTestCase.base_test
310 }) 357 })