6 from unittest
import mock
7 from ssl
import SSLError
8 from decimal
import Decimal
9 import simplejson
as json
11 from io
import StringIO
19 def __init__(self
, log_files
, quiet
, tester
):
22 if log_files
is not None and len(log_files
) > 0:
23 self
.read_log_files(log_files
)
26 mock
.patch("market.open"),
27 mock
.patch("os.makedirs"),
28 mock
.patch("sys.stdout", new_callable
=StringIO
),
33 for patch
in self
.patches
:
34 self
.mocks
.append(patch
.start())
35 self
.stdout
= self
.mocks
[-1]
37 def check_calls(self
):
38 stdout
= self
.stdout
.getvalue()
40 self
.tester
.assertEqual("", stdout
)
42 log
= self
.strip_log(stdout
)
43 if len(self
.log_files
) != 0:
44 split_logs
= log
.split("\n")
45 self
.tester
.assertEqual(sum(len(f
) for f
in self
.log_files
), len(split_logs
))
47 for log_file
in self
.log_files
:
49 split_logs
.pop(split_logs
.index(line
))
51 if not line
.startswith("[Worker] "):
52 self
.tester
.fail("« {} » not found in log file {}".format(line
, split_logs
))
53 # Le fichier de log est écrit
54 # Le fichier de log est printed uniquement si non quiet
55 # Le rapport est écrit si pertinent
56 # Le rapport contient le bon nombre de lignes
59 for patch
in self
.patches
[::-1]:
63 def strip_log(self
, log
):
64 log
= log
.replace("\n\n", "\n")
65 return re
.sub(r
"^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}: ", "", log
, flags
=re
.MULTILINE
)
67 def read_log_files(self
, log_files
):
68 for log_file
in log_files
:
69 with open(log_file
, "r") as f
:
70 log
= self
.strip_log(f
.read()).split("\n")
73 self
.log_files
.append(log
)
76 def __init__(self
, tester
, reports
, report_db
):
78 self
.reports
= reports
79 self
.report_db
= report_db
81 self
.total_report_lines
= 0
83 for user_id
, market_id
, http_requests
, report_lines
in self
.reports
.values():
84 self
.rows
.append( (market_id
, { "key": "key", "secret": "secret" }
, user_id
) )
85 self
.total_report_lines
+= len(report_lines
)
88 connect_mock
= mock
.Mock()
89 self
.cursor
= mock
.MagicMock()
90 connect_mock
.cursor
.return_value
= self
.cursor
91 def _execute(request
, *args
):
92 self
.requests
.append(request
)
93 self
.cursor
.execute
.side_effect
= _execute
94 self
.cursor
.__iter
__.return_value
= self
.rows
96 self
.db_patch
= mock
.patch("psycopg2.connect")
97 self
.db_mock
= self
.db_patch
.start()
98 self
.db_mock
.return_value
= connect_mock
100 def check_calls(self
):
102 self
.tester
.assertEqual(1 + len(self
.rows
), self
.db_mock
.call_count
)
103 self
.tester
.assertEqual(1 + len(self
.rows
) + self
.total_report_lines
, self
.cursor
.execute
.call_count
)
105 self
.tester
.assertEqual(1, self
.db_mock
.call_count
)
106 self
.tester
.assertEqual(1, self
.cursor
.execute
.call_count
)
112 def __init__(self
, tester
):
114 self
.reports
= tester
.requests_by_market()
117 self
.error_calls
= []
118 self
.mocker
= requests_mock
.Mocker()
119 def not_stubbed(*args
):
120 self
.error_calls
.append([args
[0].method
, args
[0].url
])
121 raise requests_mock
.exceptions
.MockException("Not stubbed URL")
122 self
.mocker
.register_uri(requests_mock
.ANY
, requests_mock
.ANY
,
127 for market_id
, elements
in self
.reports
.items():
128 for element
in elements
:
129 method
= element
["method"]
132 .setdefault((method
, url
), {}) \
133 .setdefault(market_id
, []) \
136 for ((method
, url
), elements
) in self
.mocks
.items():
137 self
.mocker
.register_uri(method
, url
, text
=functools
.partial(callback
, self
, elements
), complete_qs
=True)
145 def check_calls(self
):
146 self
.tester
.assertEqual([], self
.error_calls
)
147 for (method
, url
), elements
in self
.mocks
.items():
148 for market_id
, element
in elements
.items():
149 self
.tester
.assertEqual(0, len(element
), "Missing calls to {} {}, market_id {}".format(method
, url
, market_id
))
151 def clean_body(self
, body
):
154 if isinstance(body
, bytes):
156 body
= re
.sub(r
"&nonce=\d*$", "", body
)
157 body
= re
.sub(r
"nonce=\d*&?", "", body
)
160 def callback(self
, elements
, request
, context
):
162 element
= elements
[request
.headers
.get("X-market-id")].pop(0)
163 except (IndexError, KeyError):
164 self
.error_calls
.append([request
.method
, request
.url
, request
.headers
.get("X-market-id")])
165 raise RuntimeError("Unexpected call")
166 if element
["response"] is None and element
["response_same_as"] is not None:
167 element
["response"] = self
.last_https
[element
["response_same_as"]]
168 elif element
["response"] is not None:
169 self
.last_https
[element
["date"]] = element
["response"]
171 assert self
.clean_body(request
.body
) == \
172 self
.clean_body(element
["body"]), "Body does not match"
173 context
.status_code
= element
["status"]
174 if "error" in element
:
175 if element
["error"] == "SSLError":
176 raise SSLError(element
["error_message"])
178 raise getattr(requests
.exceptions
, element
["error"])(element
["error_message"])
179 return element
["response"]
181 class GlobalVariablesMock
:
187 mock
.patch
.multiple(market
.Portfolio
,
188 data
=store
.LockedVar(None),
189 liquidities
=store
.LockedVar({}),
190 last_date
=store
.LockedVar(None),
191 report
=store
.LockedVar(store
.ReportStore(None, no_http_dup
=True)),
195 worker_started
=False,
198 for patcher
in self
.patchers
:
207 true_time
= time
.time
208 true_sleep
= time
.sleep
210 datetime_patch
= None
213 def start(cls
, start_date
):
214 cls
.delta
= (datetime
.datetime
.now() - start_date
).total_seconds()
216 class fake_datetime(datetime
.datetime
):
218 def now(cls
, tz
=None):
220 return cls
.fromtimestamp(time
.time())
222 return tz
.fromutc(cls
.utcfromtimestamp(time
.time()).replace(tzinfo
=tz
))
224 cls
.time_patch
= mock
.patch
.multiple(time
, time
=cls
.fake_time
, sleep
=cls
.fake_sleep
)
225 cls
.datetime_patch
= mock
.patch
.multiple(datetime
, datetime
=fake_datetime
)
226 cls
.time_patch
.start()
227 cls
.datetime_patch
.start()
232 cls
.datetime_patch
.stop()
233 cls
.time_patch
.stop()
237 return cls
.true_time() - cls
.delta
240 def fake_sleep(cls
, duration
):
241 cls
.delta
-= duration
244 class AcceptanceTestCase():
245 def parse_file(self
, report_file
):
246 with open(report_file
, "rb") as f
:
247 json_content
= json
.load(f
, parse_float
=Decimal
)
248 config
, user
, date
, market_id
= self
.parse_config(json_content
)
249 http_requests
= self
.parse_requests(json_content
)
251 return config
, user
, date
, market_id
, http_requests
, json_content
253 def parse_requests(self
, json_content
):
255 for element
in json_content
:
256 if element
["type"] != "http_request":
258 http_requests
.append(element
)
261 def parse_config(self
, json_content
):
263 for element
in json_content
:
264 if element
["type"] != "market":
266 market_info
= element
267 assert market_info
is not None, "Couldn't find market element"
269 args
= market_info
["args"]
271 for arg
in ["before", "after", "quiet", "debug"]:
272 if args
.get(arg
, False):
273 config
.append("--{}".format(arg
))
274 for arg
in ["parallel", "report_db"]:
275 if not args
.get(arg
, False):
276 config
.append("--no-{}".format(arg
.replace("_", "-")))
277 for action
in (args
.get("action", []) or []):
278 config
.extend(["--action", action
])
279 if args
.get("report_path") is not None:
280 config
.extend(["--report-path", args
.get("report_path")])
281 if args
.get("user") is not None:
282 config
.extend(["--user", args
.get("user")])
283 config
.extend(["--config", ""])
285 user
= market_info
["user_id"]
286 date
= datetime
.datetime
.strptime(market_info
["date"], "%Y-%m-%dT%H:%M:%S.%f")
287 market_id
= market_info
["market_id"]
288 return config
, user
, date
, market_id
290 def requests_by_market(self
):
295 for user_id
, market_id
, http_requests
, report_lines
in self
.reports
.values():
296 r
[str(market_id
)] = []
297 for http_request
in http_requests
:
298 if http_request
["market_id"] is None:
300 r
[None].append(http_request
)
302 r
[str(market_id
)].append(http_request
)
307 if not hasattr(self
, "files"):
308 raise "This class expects to be inherited with a class defining 'files' variable"
309 if not hasattr(self
, "log_files"):
313 self
.start_date
= datetime
.datetime
.now()
316 self
.config
, user_id
, date
, market_id
, http_requests
, report_lines
= self
.parse_file(f
)
317 if date
< self
.start_date
:
318 self
.start_date
= date
319 self
.reports
[f
] = [user_id
, market_id
, http_requests
, report_lines
]
321 self
.database_mock
= DatabaseMock(self
, self
.reports
, "--no-report-db" not in self
.config
)
322 self
.requests_mock
= RequestsMock(self
)
323 self
.file_mock
= FileMock(self
.log_files
, "--quiet" in self
.config
, self
)
324 self
.global_variables_mock
= GlobalVariablesMock()
326 self
.database_mock
.start()
327 self
.requests_mock
.start()
328 self
.file_mock
.start()
329 self
.global_variables_mock
.start()
330 TimeMock
.start(self
.start_date
)
333 main
.main(self
.config
)
334 self
.requests_mock
.check_calls()
335 self
.database_mock
.check_calls()
336 self
.file_mock
.check_calls()
340 self
.global_variables_mock
.stop()
341 self
.file_mock
.stop()
342 self
.requests_mock
.stop()
343 self
.database_mock
.stop()
345 for dirfile
in glob
.glob("tests/acceptance/**/*/", recursive
=True):
346 json_files
= glob
.glob("{}/*.json".format(dirfile
))
347 log_files
= glob
.glob("{}/*.log".format(dirfile
))
348 if len(json_files
) > 0:
349 name
= dirfile
.replace("tests/acceptance/", "").replace("/", "_")[0:-1]
350 cname
= "".join(list(map(lambda x
: x
.capitalize(), name
.split("_"))))
352 globals()[cname
] = type(cname
,
353 (AcceptanceTestCase
,unittest
.TestCase
), {
354 "log_files": log_files
,
356 "test_{}".format(name
): AcceptanceTestCase
.base_test
359 if __name__
== '__main__':