5 from unittest
import mock
6 from ssl
import SSLError
7 from decimal
import Decimal
8 import simplejson
as json
10 from io
import StringIO
19 true_sleep
= time
.sleep
24 def travel(cls
, start_date
):
26 cls
.delta_init
= (datetime
.datetime
.now() - start_date
).total_seconds()
33 class fake_datetime(datetime
.datetime
):
35 def now(cls
, tz
=None):
37 return cls
.fromtimestamp(time
.time())
39 return tz
.fromutc(cls
.utcfromtimestamp(time
.time()).replace(tzinfo
=tz
))
41 cls
.time_patch
= mock
.patch
.multiple(time
, time
=cls
.fake_time
, sleep
=cls
.fake_sleep
)
42 cls
.datetime_patch
= mock
.patch
.multiple(datetime
, datetime
=fake_datetime
)
43 cls
.time_patch
.start()
44 cls
.datetime_patch
.start()
53 cls
.delta
.setdefault(threading
.current_thread(), cls
.delta_init
)
54 return cls
.true_time() - cls
.delta
[threading
.current_thread()]
57 def fake_sleep(cls
, duration
):
58 cls
.delta
.setdefault(threading
.current_thread(), cls
.delta_init
)
59 cls
.delta
[threading
.current_thread()] -= float(duration
)
60 cls
.true_sleep(min(float(duration
), 0.1))
66 def __init__(self
, log_files
, quiet
, tester
):
69 if log_files
is not None and len(log_files
) > 0:
70 self
.read_log_files(log_files
)
73 mock
.patch("market.open"),
74 mock
.patch("os.makedirs"),
75 mock
.patch("sys.stdout", new_callable
=StringIO
),
80 for patch
in self
.patches
:
81 self
.mocks
.append(patch
.start())
82 self
.stdout
= self
.mocks
[-1]
84 def check_calls(self
):
85 stdout
= self
.stdout
.getvalue()
87 self
.tester
.assertEqual("", stdout
)
89 log
= self
.strip_log(stdout
)
90 if len(self
.log_files
) != 0:
91 split_logs
= log
.split("\n")
92 self
.tester
.assertEqual(sum(len(f
) for f
in self
.log_files
), len(split_logs
))
94 for log_file
in self
.log_files
:
96 split_logs
.pop(split_logs
.index(line
))
98 if not line
.startswith("[Worker] "):
99 self
.tester
.fail("« {} » not found in log file {}".format(line
, split_logs
))
100 # Le fichier de log est écrit
101 # Le rapport est écrit si pertinent
102 # Le rapport contient le bon nombre de lignes
105 for patch
in self
.patches
[::-1]:
109 def strip_log(self
, log
):
110 log
= log
.replace("\n\n", "\n")
111 return re
.sub(r
"^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}: ", "", log
, flags
=re
.MULTILINE
)
113 def read_log_files(self
, log_files
):
114 for log_file
in log_files
:
115 with open(log_file
, "r") as f
:
116 log
= self
.strip_log(f
.read()).split("\n")
117 if len(log
[-1]) == 0:
119 self
.log_files
.append(log
)
122 def __init__(self
, tester
, reports
, report_db
):
124 self
.reports
= reports
125 self
.report_db
= report_db
127 self
.total_report_lines
= 0
129 for user_id
, market_id
, http_requests
, report_lines
in self
.reports
.values():
130 self
.rows
.append( (market_id
, { "key": "key", "secret": "secret" }
, user_id
) )
131 self
.total_report_lines
+= len(report_lines
)
134 connect_mock
= mock
.Mock()
135 self
.cursor
= mock
.MagicMock()
136 connect_mock
.cursor
.return_value
= self
.cursor
137 def _execute(request
, *args
):
138 self
.requests
.append(request
)
139 self
.cursor
.execute
.side_effect
= _execute
140 self
.cursor
.__iter
__.return_value
= self
.rows
142 self
.db_patch
= mock
.patch("psycopg2.connect")
143 self
.db_mock
= self
.db_patch
.start()
144 self
.db_mock
.return_value
= connect_mock
146 def check_calls(self
):
148 self
.tester
.assertEqual(1 + len(self
.rows
), self
.db_mock
.call_count
)
149 self
.tester
.assertEqual(1 + len(self
.rows
) + self
.total_report_lines
, self
.cursor
.execute
.call_count
)
151 self
.tester
.assertEqual(1, self
.db_mock
.call_count
)
152 self
.tester
.assertEqual(1, self
.cursor
.execute
.call_count
)
158 def __init__(self
, tester
):
160 self
.reports
= tester
.requests_by_market()
163 self
.error_calls
= []
164 self
.mocker
= requests_mock
.Mocker()
165 def not_stubbed(*args
):
166 self
.error_calls
.append([args
[0].method
, args
[0].url
])
167 raise requests_mock
.exceptions
.MockException("Not stubbed URL")
168 self
.mocker
.register_uri(requests_mock
.ANY
, requests_mock
.ANY
,
173 for market_id
, elements
in self
.reports
.items():
174 for element
in elements
:
175 method
= element
["method"]
178 .setdefault((method
, url
), {}) \
179 .setdefault(market_id
, []) \
182 for ((method
, url
), elements
) in self
.mocks
.items():
183 self
.mocker
.register_uri(method
, url
, text
=functools
.partial(callback
, self
, elements
), complete_qs
=True)
192 "https://cryptoportfolio.io/wp-content/uploads/portfolio/json/cryptoportfolio.json",
193 "https://poloniex.com/public?command=returnTicker",
195 def check_calls(self
):
196 self
.tester
.assertEqual([], self
.error_calls
)
197 for (method
, url
), elements
in self
.mocks
.items():
198 for market_id
, element
in elements
.items():
199 if url
not in self
.lazy_calls
:
200 self
.tester
.assertEqual(0, len(element
), "Missing calls to {} {}, market_id {}".format(method
, url
, market_id
))
202 def clean_body(self
, body
):
205 if isinstance(body
, bytes):
207 body
= re
.sub(r
"&nonce=\d*$", "", body
)
208 body
= re
.sub(r
"nonce=\d*&?", "", body
)
211 def callback(self
, elements
, request
, context
):
213 element
= elements
[request
.headers
.get("X-market-id")].pop(0)
214 except (IndexError, KeyError):
215 self
.error_calls
.append([request
.method
, request
.url
, request
.headers
.get("X-market-id")])
216 raise RuntimeError("Unexpected call")
217 if element
["response"] is None and element
["response_same_as"] is not None:
218 element
["response"] = self
.last_https
[element
["response_same_as"]]
219 elif element
["response"] is not None:
220 self
.last_https
[element
["date"]] = element
["response"]
222 time
.sleep(element
.get("duration", 0))
224 assert self
.clean_body(request
.body
) == \
225 self
.clean_body(element
["body"]), "Body does not match"
226 context
.status_code
= element
["status"]
227 if "error" in element
:
228 if element
["error"] == "SSLError":
229 raise SSLError(element
["error_message"])
231 raise getattr(requests
.exceptions
, element
["error"])(element
["error_message"])
232 return element
["response"]
234 class GlobalVariablesMock
:
240 mock
.patch
.multiple(market
.Portfolio
,
241 data
=store
.LockedVar(None),
242 liquidities
=store
.LockedVar({}),
243 last_date
=store
.LockedVar(None),
244 report
=store
.LockedVar(store
.ReportStore(None, no_http_dup
=True)),
248 worker_started
=False,
251 for patcher
in self
.patchers
:
258 class AcceptanceTestCase():
259 def parse_file(self
, report_file
):
260 with open(report_file
, "rb") as f
:
261 json_content
= json
.load(f
, parse_float
=Decimal
)
262 config
, user
, date
, market_id
= self
.parse_config(json_content
)
263 http_requests
= self
.parse_requests(json_content
)
265 return config
, user
, date
, market_id
, http_requests
, json_content
267 def parse_requests(self
, json_content
):
269 for element
in json_content
:
270 if element
["type"] != "http_request":
272 http_requests
.append(element
)
275 def parse_config(self
, json_content
):
277 for element
in json_content
:
278 if element
["type"] != "market":
280 market_info
= element
281 assert market_info
is not None, "Couldn't find market element"
283 args
= market_info
["args"]
285 for arg
in ["before", "after", "quiet", "debug"]:
286 if args
.get(arg
, False):
287 config
.append("--{}".format(arg
))
288 for arg
in ["parallel", "report_db"]:
289 if not args
.get(arg
, False):
290 config
.append("--no-{}".format(arg
.replace("_", "-")))
291 for action
in (args
.get("action", []) or []):
292 config
.extend(["--action", action
])
293 if args
.get("report_path") is not None:
294 config
.extend(["--report-path", args
.get("report_path")])
295 if args
.get("user") is not None:
296 config
.extend(["--user", args
.get("user")])
297 config
.extend(["--config", ""])
299 user
= market_info
["user_id"]
300 date
= datetime
.datetime
.strptime(market_info
["date"], "%Y-%m-%dT%H:%M:%S.%f")
301 market_id
= market_info
["market_id"]
302 return config
, user
, date
, market_id
304 def requests_by_market(self
):
309 for user_id
, market_id
, http_requests
, report_lines
in self
.reports
.values():
310 r
[str(market_id
)] = []
311 for http_request
in http_requests
:
312 if http_request
["market_id"] is None:
314 r
[None].append(http_request
)
316 r
[str(market_id
)].append(http_request
)
321 if not hasattr(self
, "files"):
322 raise "This class expects to be inherited with a class defining 'files' variable"
323 if not hasattr(self
, "log_files"):
327 self
.start_date
= datetime
.datetime
.now()
330 self
.config
, user_id
, date
, market_id
, http_requests
, report_lines
= self
.parse_file(f
)
331 if date
< self
.start_date
:
332 self
.start_date
= date
333 self
.reports
[f
] = [user_id
, market_id
, http_requests
, report_lines
]
335 self
.database_mock
= DatabaseMock(self
, self
.reports
, "--no-report-db" not in self
.config
)
336 self
.requests_mock
= RequestsMock(self
)
337 self
.file_mock
= FileMock(self
.log_files
, "--quiet" in self
.config
, self
)
338 self
.global_variables_mock
= GlobalVariablesMock()
340 self
.database_mock
.start()
341 self
.requests_mock
.start()
342 self
.file_mock
.start()
343 self
.global_variables_mock
.start()
344 TimeMock
.travel(self
.start_date
)
347 main
.main(self
.config
)
348 self
.requests_mock
.check_calls()
349 self
.database_mock
.check_calls()
350 self
.file_mock
.check_calls()
354 self
.global_variables_mock
.stop()
355 self
.file_mock
.stop()
356 self
.requests_mock
.stop()
357 self
.database_mock
.stop()