6 from unittest
import mock
7 from ssl
import SSLError
8 from decimal
import Decimal
9 import simplejson
as json
11 from io
import StringIO
16 cls
.file_mock
= mock
.patch("market.open")
17 cls
.os_mock
= mock
.patch("os.makedirs")
18 cls
.stdout_mock
= mock
.patch('sys.stdout', new_callable
=StringIO
)
19 cls
.stdout_mock
.start()
24 def check_calls(cls
, tester
):
26 #raise NotImplementedError("Todo")
31 cls
.stdout_mock
.stop()
39 total_report_lines
= 0
44 def start(cls
, reports
, report_db
):
45 cls
.report_db
= report_db
47 cls
.total_report_lines
= 0
48 for user_id
, market_id
, http_requests
, report_lines
in reports
.values():
49 cls
.rows
.append( (market_id
, { "key": "key", "secret": "secret" }
, user_id
) )
50 cls
.total_report_lines
+= len(report_lines
)
52 connect_mock
= mock
.Mock()
53 cls
.cursor
= mock
.MagicMock()
54 connect_mock
.cursor
.return_value
= cls
.cursor
55 def _execute(request
, *args
):
56 cls
.requests
.append(request
)
57 cls
.cursor
.execute
.side_effect
= _execute
58 cls
.cursor
.__iter
__.return_value
= cls
.rows
60 cls
.db_patch
= mock
.patch("psycopg2.connect")
61 cls
.db_mock
= cls
.db_patch
.start()
62 cls
.db_mock
.return_value
= connect_mock
def check_calls(cls, tester):
    """Assert that the patched database layer was called as often as expected.

    Meant to be bound as a classmethod of the database mock (the decorator
    and class header are outside this block).

    Args:
        cls: Object exposing ``report_db``, ``rows``, ``total_report_lines``,
            ``db_mock`` (the started ``psycopg2.connect`` patch) and
            ``cursor`` (the mocked cursor).
        tester: Object providing ``assertEqual(expected, actual)``.
    """
    # NOTE(review): the branch header lines were lost in the source this was
    # recovered from.  ``cls.report_db`` is the only flag stored by ``start``
    # and the two assertion pairs are mutually exclusive, so it is used as the
    # condition here -- confirm against the original file.
    if cls.report_db:
        # One call up front plus one per row; each report line adds one
        # further ``execute`` on top of that.
        expected_connections = 1 + len(cls.rows)
        expected_executes = 1 + len(cls.rows) + cls.total_report_lines
    else:
        expected_connections = 1
        expected_executes = 1

    tester.assertEqual(expected_connections, cls.db_mock.call_count)
    tester.assertEqual(expected_executes, cls.cursor.execute.call_count)
78 cls
.total_report_lines
= 0
90 def start(cls
, reports
):
91 cls
.adapter
= requests_mock
.Adapter()
92 cls
.adapter
.register_uri(requests_mock
.ANY
, requests_mock
.ANY
,
93 exc
=requests_mock
.exceptions
.MockException("Not stubbed URL"))
94 true_session
= requests
.Session
98 for market_id
, elements
in reports
.items():
99 for element
in elements
:
100 method
= element
["method"]
103 .setdefault((method
, url
), {}) \
104 .setdefault(market_id
, []) \
107 for ((method
, url
), elements
) in cls
.mocks
.items():
108 cls
.adapter
.register_uri(method
, url
, text
=cls
.callback_func(elements
), complete_qs
=True)
110 session
= true_session()
111 session
.get_adapter
= lambda url
: cls
.adapter
113 cls
.request_patch
= [
114 mock
.patch
.object(requests
.sessions
, "Session", new
=_session
),
115 mock
.patch
.object(requests
, "Session", new
=_session
)
117 for patch
in cls
.request_patch
:
122 for patch
in cls
.request_patch
:
124 cls
.request_patch
= []
def check_calls(cls, tester):
    """Assert that every registered HTTP expectation has been consumed.

    Meant to be bound as a classmethod of the requests mock (the decorator
    and class header are outside this block).

    ``cls.mocks`` maps ``(method, url)`` to a dict of ``market_id`` to the
    list of expected calls still pending.  The response callback pops an
    entry each time a matching request is served, so any list that is not
    empty here means an expected call never happened.

    Args:
        cls: Object exposing the ``mocks`` mapping described above.
        tester: Object providing ``assertEqual(expected, actual, msg)``.
    """
    for (method, url), per_market in cls.mocks.items():
        for market_id, pending in per_market.items():
            tester.assertEqual(
                0,
                len(pending),
                "Missing calls to {} {}, market_id {}".format(method, url, market_id),
            )
136 def clean_body(cls
, body
):
140 if isinstance(body
, bytes):
142 body
= re
.sub(r
"&nonce=\d*$", "", body
)
143 body
= re
.sub(r
"nonce=\d*&?", "", body
)
def callback_func(cls, elements):
    """Build the response callback registered for one ``(method, url)`` pair.

    Meant to be bound as a classmethod of the requests mock (the decorator
    and class header are outside this block); the result is passed as the
    ``text=`` argument of ``adapter.register_uri``.

    Args:
        cls: Object exposing ``last_https`` (dict of recorded responses keyed
            by date) and ``clean_body(body)``.
        elements: Dict mapping a market id (as sent in the ``X-market-id``
            header) to the ordered list of recorded requests still expected.

    Returns:
        A ``callback(request, context)`` function that replays the next
        recorded element for the requesting market.
    """
    def callback(request, context):
        # Consume the next expected call for this market; an unknown market
        # or an exhausted list means the code under test made a request the
        # recording does not contain.
        try:
            element = elements[request.headers.get("X-market-id")].pop(0)
        except (IndexError, KeyError):
            raise RuntimeError("Unexpected call")

        # A recording either carries its own response or points (by date) to
        # an earlier one that had the same response.
        if element["response"] is None and element["response_same_as"] is not None:
            element["response"] = cls.last_https[element["response_same_as"]]
        elif element["response"] is not None:
            cls.last_https[element["date"]] = element["response"]

        assert cls.clean_body(request.body) == \
            cls.clean_body(element["body"]), "Body does not match"
        context.status_code = element["status"]

        if "error" in element:
            # SSLError comes from the ``ssl`` module; any other recorded
            # error name is looked up in ``requests.exceptions``.
            if element["error"] == "SSLError":
                raise SSLError(element["error_message"])
            else:
                raise getattr(requests.exceptions, element["error"])(element["error_message"])
        return element["response"]

    return callback
171 true_time
= time
.time
173 datetime_patch
= None
176 def start(cls
, start_date
):
177 cls
.delta
= (datetime
.datetime
.now() - start_date
).total_seconds()
179 class fake_datetime(datetime
.datetime
):
181 def now(cls
, tz
=None):
183 return cls
.fromtimestamp(time
.time())
185 return tz
.fromutc(cls
.utcfromtimestamp(time
.time()).replace(tzinfo
=tz
))
187 cls
.time_patch
= mock
.patch
.multiple(time
, time
=cls
.fake_time
, sleep
=cls
.fake_sleep
)
188 cls
.datetime_patch
= mock
.patch
.multiple(datetime
, datetime
=fake_datetime
)
189 cls
.time_patch
.start()
190 cls
.datetime_patch
.start()
195 cls
.datetime_patch
.stop()
196 cls
.time_patch
.stop()
200 return cls
.true_time() - cls
.delta
def fake_sleep(cls, duration):
    """Replacement for ``time.sleep`` that returns immediately.

    Meant to be bound as a classmethod of the time mock (the decorator and
    class header are outside this block).  Instead of blocking, it lowers
    ``cls.delta`` by ``duration``; the fake clock is computed as
    ``true_time() - delta``, so it moves forward by that many seconds.

    Args:
        cls: Object exposing a numeric ``delta`` attribute.
        duration: Number of seconds the caller asked to sleep.
    """
    cls.delta -= duration
206 class AcceptanceTestCase():
def parse_file(self, report_file):
    """Load one recorded JSON report and split it into its parts.

    Args:
        report_file: Path of the JSON report to read.

    Returns:
        A 6-tuple ``(config, user, date, market_id, http_requests,
        json_content)``: the first four come from ``self.parse_config``,
        ``http_requests`` from ``self.parse_requests``, and ``json_content``
        is the full decoded document.
    """
    # Floats are decoded as Decimal so recorded amounts keep their exact
    # textual value.
    with open(report_file, "rb") as f:
        json_content = json.load(f, parse_float=Decimal)

    config, user, date, market_id = self.parse_config(json_content)
    http_requests = self.parse_requests(json_content)

    return config, user, date, market_id, http_requests, json_content
def parse_requests(self, json_content):
    """Return the recorded HTTP requests contained in a report.

    Args:
        json_content: Iterable of report elements (dicts with a ``"type"``
            key).

    Returns:
        A new list with the elements whose ``"type"`` is ``"http_request"``,
        in their original order.
    """
    return [
        element
        for element in json_content
        if element["type"] == "http_request"
    ]
def parse_config(self, json_content):
    """Rebuild the command-line arguments and metadata of a recorded run.

    Args:
        json_content: Iterable of report elements (dicts with a ``"type"``
            key).  It must contain an element of type ``"market"`` holding
            ``"args"``, ``"user_id"``, ``"date"`` and ``"market_id"``.

    Returns:
        A 4-tuple ``(config, user, date, market_id)`` where ``config`` is the
        list of command-line arguments, ``user`` the recorded ``user_id``,
        ``date`` the recorded date parsed as a ``datetime.datetime`` and
        ``market_id`` the recorded market id.

    Raises:
        AssertionError: If no ``"market"`` element is present.
    """
    # No early exit: when several "market" elements exist the last one wins.
    market_info = None
    for element in json_content:
        if element["type"] == "market":
            market_info = element
    assert market_info is not None, "Couldn't find market element"

    args = market_info["args"]
    config = []
    # Opt-in switches: emitted only when recorded as truthy.
    for arg in ("before", "after", "quiet", "debug"):
        if args.get(arg, False):
            config.append("--{}".format(arg))
    # Negated switches: emitted as --no-<name> when recorded as falsy or absent.
    for arg in ("parallel", "report_db"):
        if not args.get(arg, False):
            config.append("--no-{}".format(arg.replace("_", "-")))
    # ``or []`` also covers an explicit null "action", which cannot be iterated.
    for action in args.get("action") or []:
        config.extend(["--action", action])
    if args.get("report_path") is not None:
        config.extend(["--report-path", args.get("report_path")])
    if args.get("user") is not None:
        config.extend(["--user", args.get("user")])
    config.extend(["--config", ""])

    user = market_info["user_id"]
    date = datetime.datetime.strptime(market_info["date"], "%Y-%m-%dT%H:%M:%S.%f")
    market_id = market_info["market_id"]
    return config, user, date, market_id
252 def requests_by_market(self
):
257 for user_id
, market_id
, http_requests
, report_lines
in self
.reports
.values():
258 r
[str(market_id
)] = []
259 for http_request
in http_requests
:
260 if http_request
["market_id"] is None:
262 r
[None].append(http_request
)
264 r
[str(market_id
)].append(http_request
)
269 if not hasattr(self
, "files"):
270 raise "This class expects to be inherited with a class defining self.files in setUp"
273 self
.start_date
= datetime
.datetime
.now()
276 self
.config
, user_id
, date
, market_id
, http_requests
, report_lines
= self
.parse_file(f
)
277 if date
< self
.start_date
:
278 self
.start_date
= date
279 self
.reports
[f
] = [user_id
, market_id
, http_requests
, report_lines
]
281 DatabaseMock
.start(self
.reports
, "--no-report-db" not in self
.config
)
282 RequestsMock
.start(self
.requests_by_market())
284 TimeMock
.start(self
.start_date
)
288 main
.main(self
.config
)
289 RequestsMock
.check_calls(self
)
290 DatabaseMock
.check_calls(self
)
291 FileMock
.check_calls(self
)
300 for dirfile
in glob
.glob("tests/acceptance/**/*/", recursive
=True):
301 json_files
= glob
.glob("{}/*.json".format(dirfile
))
302 if len(json_files
) > 0:
303 name
= dirfile
.replace("tests/acceptance/", "").replace("/", "_")[0:-1]
304 cname
= "".join(list(map(lambda x
: x
.capitalize(), name
.split("_"))))
306 globals()[cname
] = type(cname
,
307 (AcceptanceTestCase
,unittest
.TestCase
), {
309 "test_{}".format(name
): AcceptanceTestCase
.base_test
312 if __name__
== '__main__':