aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/acceptance.py358
-rw-r--r--tests/helper.py12
-rw-r--r--tests/test_acceptance.py25
-rw-r--r--tests/test_ccxt_wrapper.py3
-rw-r--r--tests/test_main.py1
-rw-r--r--tests/test_market.py2
-rw-r--r--tests/test_portfolio.py6
-rw-r--r--tests/test_store.py6
8 files changed, 411 insertions, 2 deletions
diff --git a/tests/acceptance.py b/tests/acceptance.py
new file mode 100644
index 0000000..66014ca
--- /dev/null
+++ b/tests/acceptance.py
@@ -0,0 +1,358 @@
1import requests
2import requests_mock
3import sys, os
4import time, datetime
5from unittest import mock
6from ssl import SSLError
7from decimal import Decimal
8import simplejson as json
9import psycopg2
10from io import StringIO
11import re
12import functools
13import threading
14
15class TimeMock:
16 delta = {}
17 delta_init = 0
18 true_time = time.time
19 true_sleep = time.sleep
20 time_patch = None
21 datetime_patch = None
22
23 @classmethod
24 def travel(cls, start_date):
25 cls.delta = {}
26 cls.delta_init = (datetime.datetime.now() - start_date).total_seconds()
27
28 @classmethod
29 def start(cls):
30 cls.delta = {}
31 cls.delta_init = 0
32
33 class fake_datetime(datetime.datetime):
34 @classmethod
35 def now(cls, tz=None):
36 if tz is None:
37 return cls.fromtimestamp(time.time())
38 else:
39 return tz.fromutc(cls.utcfromtimestamp(time.time()).replace(tzinfo=tz))
40
41 cls.time_patch = mock.patch.multiple(time, time=cls.fake_time, sleep=cls.fake_sleep)
42 cls.datetime_patch = mock.patch.multiple(datetime, datetime=fake_datetime)
43 cls.time_patch.start()
44 cls.datetime_patch.start()
45
46 @classmethod
47 def stop(cls):
48 cls.delta = {}
49 cls.delta_init = 0
50
51 @classmethod
52 def fake_time(cls):
53 cls.delta.setdefault(threading.current_thread(), cls.delta_init)
54 return cls.true_time() - cls.delta[threading.current_thread()]
55
56 @classmethod
57 def fake_sleep(cls, duration):
58 cls.delta.setdefault(threading.current_thread(), cls.delta_init)
59 cls.delta[threading.current_thread()] -= float(duration)
60 cls.true_sleep(min(float(duration), 0.1))
61
62TimeMock.start()
63import main
64
65class FileMock:
66 def __init__(self, log_files, quiet, tester):
67 self.tester = tester
68 self.log_files = []
69 if log_files is not None and len(log_files) > 0:
70 self.read_log_files(log_files)
71 self.quiet = quiet
72 self.patches = [
73 mock.patch("market.open"),
74 mock.patch("os.makedirs"),
75 mock.patch("sys.stdout", new_callable=StringIO),
76 ]
77 self.mocks = []
78
79 def start(self):
80 for patch in self.patches:
81 self.mocks.append(patch.start())
82 self.stdout = self.mocks[-1]
83
84 def check_calls(self):
85 stdout = self.stdout.getvalue()
86 if self.quiet:
87 self.tester.assertEqual("", stdout)
88 else:
89 log = self.strip_log(stdout)
90 if len(self.log_files) != 0:
91 split_logs = log.split("\n")
92 self.tester.assertEqual(sum(len(f) for f in self.log_files), len(split_logs))
93 try:
94 for log_file in self.log_files:
95 for line in log_file:
96 split_logs.pop(split_logs.index(line))
97 except ValueError:
98 if not line.startswith("[Worker] "):
99 self.tester.fail("« {} » not found in log file {}".format(line, split_logs))
100 # Le fichier de log est écrit
101 # Le rapport est écrit si pertinent
102 # Le rapport contient le bon nombre de lignes
103
104 def stop(self):
105 for patch in self.patches[::-1]:
106 patch.stop()
107 self.mocks.pop()
108
109 def strip_log(self, log):
110 log = log.replace("\n\n", "\n")
111 return re.sub(r"^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}: ", "", log, flags=re.MULTILINE)
112
113 def read_log_files(self, log_files):
114 for log_file in log_files:
115 with open(log_file, "r") as f:
116 log = self.strip_log(f.read()).split("\n")
117 if len(log[-1]) == 0:
118 log.pop()
119 self.log_files.append(log)
120
121class DatabaseMock:
122 def __init__(self, tester, reports, report_db):
123 self.tester = tester
124 self.reports = reports
125 self.report_db = report_db
126 self.rows = []
127 self.total_report_lines= 0
128 self.requests = []
129 for user_id, market_id, http_requests, report_lines in self.reports.values():
130 self.rows.append( (market_id, { "key": "key", "secret": "secret" }, user_id) )
131 self.total_report_lines += len(report_lines)
132
133 def start(self):
134 connect_mock = mock.Mock()
135 self.cursor = mock.MagicMock()
136 connect_mock.cursor.return_value = self.cursor
137 def _execute(request, *args):
138 self.requests.append(request)
139 self.cursor.execute.side_effect = _execute
140 self.cursor.__iter__.return_value = self.rows
141
142 self.db_patch = mock.patch("psycopg2.connect")
143 self.db_mock = self.db_patch.start()
144 self.db_mock.return_value = connect_mock
145
146 def check_calls(self):
147 if self.report_db:
148 self.tester.assertEqual(1 + len(self.rows), self.db_mock.call_count)
149 self.tester.assertEqual(1 + len(self.rows) + self.total_report_lines, self.cursor.execute.call_count)
150 else:
151 self.tester.assertEqual(1, self.db_mock.call_count)
152 self.tester.assertEqual(1, self.cursor.execute.call_count)
153
154 def stop(self):
155 self.db_patch.stop()
156
157class RequestsMock:
158 def __init__(self, tester):
159 self.tester = tester
160 self.reports = tester.requests_by_market()
161
162 self.last_https = {}
163 self.error_calls = []
164 self.mocker = requests_mock.Mocker()
165 def not_stubbed(*args):
166 self.error_calls.append([args[0].method, args[0].url])
167 raise requests_mock.exceptions.MockException("Not stubbed URL")
168 self.mocker.register_uri(requests_mock.ANY, requests_mock.ANY,
169 text=not_stubbed)
170
171 self.mocks = {}
172
173 for market_id, elements in self.reports.items():
174 for element in elements:
175 method = element["method"]
176 url = element["url"]
177 self.mocks \
178 .setdefault((method, url), {}) \
179 .setdefault(market_id, []) \
180 .append(element)
181
182 for ((method, url), elements) in self.mocks.items():
183 self.mocker.register_uri(method, url, text=functools.partial(callback, self, elements), complete_qs=True)
184
185 def start(self):
186 self.mocker.start()
187
188 def stop(self):
189 self.mocker.stop()
190
191 lazy_calls = [
192 "https://cryptoportfolio.io/wp-content/uploads/portfolio/json/cryptoportfolio.json",
193 "https://poloniex.com/public?command=returnTicker",
194 ]
195 def check_calls(self):
196 self.tester.assertEqual([], self.error_calls)
197 for (method, url), elements in self.mocks.items():
198 for market_id, element in elements.items():
199 if url not in self.lazy_calls:
200 self.tester.assertEqual(0, len(element), "Missing calls to {} {}, market_id {}".format(method, url, market_id))
201
202 def clean_body(self, body):
203 if body is None:
204 return None
205 if isinstance(body, bytes):
206 body = body.decode()
207 body = re.sub(r"&nonce=\d*$", "", body)
208 body = re.sub(r"nonce=\d*&?", "", body)
209 return body
210
211def callback(self, elements, request, context):
212 try:
213 element = elements[request.headers.get("X-market-id")].pop(0)
214 except (IndexError, KeyError):
215 self.error_calls.append([request.method, request.url, request.headers.get("X-market-id")])
216 raise RuntimeError("Unexpected call")
217 if element["response"] is None and element["response_same_as"] is not None:
218 element["response"] = self.last_https[element["response_same_as"]]
219 elif element["response"] is not None:
220 self.last_https[element["date"]] = element["response"]
221
222 time.sleep(element.get("duration", 0))
223
224 assert self.clean_body(request.body) == \
225 self.clean_body(element["body"]), "Body does not match"
226 context.status_code = element["status"]
227 if "error" in element:
228 if element["error"] == "SSLError":
229 raise SSLError(element["error_message"])
230 else:
231 raise getattr(requests.exceptions, element["error"])(element["error_message"])
232 return element["response"]
233
234class GlobalVariablesMock:
235 def start(self):
236 import market
237 import store
238
239 self.patchers = [
240 mock.patch.multiple(market.Portfolio,
241 data=store.LockedVar(None),
242 liquidities=store.LockedVar({}),
243 last_date=store.LockedVar(None),
244 report=store.LockedVar(store.ReportStore(None, no_http_dup=True)),
245 worker=None,
246 worker_tag="",
247 worker_notify=None,
248 worker_started=False,
249 callback=None)
250 ]
251 for patcher in self.patchers:
252 patcher.start()
253
254 def stop(self):
255 pass
256
257
258class AcceptanceTestCase():
259 def parse_file(self, report_file):
260 with open(report_file, "rb") as f:
261 json_content = json.load(f, parse_float=Decimal)
262 config, user, date, market_id = self.parse_config(json_content)
263 http_requests = self.parse_requests(json_content)
264
265 return config, user, date, market_id, http_requests, json_content
266
267 def parse_requests(self, json_content):
268 http_requests = []
269 for element in json_content:
270 if element["type"] != "http_request":
271 continue
272 http_requests.append(element)
273 return http_requests
274
275 def parse_config(self, json_content):
276 market_info = None
277 for element in json_content:
278 if element["type"] != "market":
279 continue
280 market_info = element
281 assert market_info is not None, "Couldn't find market element"
282
283 args = market_info["args"]
284 config = []
285 for arg in ["before", "after", "quiet", "debug"]:
286 if args.get(arg, False):
287 config.append("--{}".format(arg))
288 for arg in ["parallel", "report_db"]:
289 if not args.get(arg, False):
290 config.append("--no-{}".format(arg.replace("_", "-")))
291 for action in (args.get("action", []) or []):
292 config.extend(["--action", action])
293 if args.get("report_path") is not None:
294 config.extend(["--report-path", args.get("report_path")])
295 if args.get("user") is not None:
296 config.extend(["--user", args.get("user")])
297 config.extend(["--config", ""])
298
299 user = market_info["user_id"]
300 date = datetime.datetime.strptime(market_info["date"], "%Y-%m-%dT%H:%M:%S.%f")
301 market_id = market_info["market_id"]
302 return config, user, date, market_id
303
304 def requests_by_market(self):
305 r = {
306 None: []
307 }
308 got_common = False
309 for user_id, market_id, http_requests, report_lines in self.reports.values():
310 r[str(market_id)] = []
311 for http_request in http_requests:
312 if http_request["market_id"] is None:
313 if not got_common:
314 r[None].append(http_request)
315 else:
316 r[str(market_id)].append(http_request)
317 got_common = True
318 return r
319
320 def setUp(self):
321 if not hasattr(self, "files"):
322 raise "This class expects to be inherited with a class defining 'files' variable"
323 if not hasattr(self, "log_files"):
324 self.log_files = []
325
326 self.reports = {}
327 self.start_date = datetime.datetime.now()
328 self.config = []
329 for f in self.files:
330 self.config, user_id, date, market_id, http_requests, report_lines = self.parse_file(f)
331 if date < self.start_date:
332 self.start_date = date
333 self.reports[f] = [user_id, market_id, http_requests, report_lines]
334
335 self.database_mock = DatabaseMock(self, self.reports, "--no-report-db" not in self.config)
336 self.requests_mock = RequestsMock(self)
337 self.file_mock = FileMock(self.log_files, "--quiet" in self.config, self)
338 self.global_variables_mock = GlobalVariablesMock()
339
340 self.database_mock.start()
341 self.requests_mock.start()
342 self.file_mock.start()
343 self.global_variables_mock.start()
344 TimeMock.travel(self.start_date)
345
346 def base_test(self):
347 main.main(self.config)
348 self.requests_mock.check_calls()
349 self.database_mock.check_calls()
350 self.file_mock.check_calls()
351
352 def tearDown(self):
353 TimeMock.stop()
354 self.global_variables_mock.stop()
355 self.file_mock.stop()
356 self.requests_mock.stop()
357 self.database_mock.stop()
358
diff --git a/tests/helper.py b/tests/helper.py
index 4548b16..b85bf3a 100644
--- a/tests/helper.py
+++ b/tests/helper.py
@@ -6,9 +6,19 @@ import requests_mock
6from io import StringIO 6from io import StringIO
7import portfolio, market, main, store 7import portfolio, market, main, store
8 8
9__all__ = ["unittest", "WebMockTestCase", "mock", "D", 9__all__ = ["limits", "unittest", "WebMockTestCase", "mock", "D",
10 "StringIO"] 10 "StringIO"]
11 11
12limits = ["acceptance", "unit"]
13for test_type in limits:
14 if "--no{}".format(test_type) in sys.argv:
15 sys.argv.remove("--no{}".format(test_type))
16 limits.remove(test_type)
17 if "--only{}".format(test_type) in sys.argv:
18 sys.argv.remove("--only{}".format(test_type))
19 limits = [test_type]
20 break
21
12class WebMockTestCase(unittest.TestCase): 22class WebMockTestCase(unittest.TestCase):
13 import time 23 import time
14 24
diff --git a/tests/test_acceptance.py b/tests/test_acceptance.py
new file mode 100644
index 0000000..77a6cca
--- /dev/null
+++ b/tests/test_acceptance.py
@@ -0,0 +1,25 @@
1from .helper import limits
2from tests.acceptance import AcceptanceTestCase
3
4import unittest
5import glob
6
7__all__ = []
8
9for dirfile in glob.glob("tests/acceptance/**/*/", recursive=True):
10 json_files = glob.glob("{}/*.json".format(dirfile))
11 log_files = glob.glob("{}/*.log".format(dirfile))
12 if len(json_files) > 0:
13 name = dirfile.replace("tests/acceptance/", "").replace("/", "_")[0:-1]
14 cname = "".join(list(map(lambda x: x.capitalize(), name.split("_"))))
15
16 globals()[cname] = unittest.skipUnless("acceptance" in limits, "Acceptance skipped")(
17 type(cname, (AcceptanceTestCase, unittest.TestCase), {
18 "log_files": log_files,
19 "files": json_files,
20 "test_{}".format(name): AcceptanceTestCase.base_test
21 })
22 )
23 __all__.append(cname)
24
25
diff --git a/tests/test_ccxt_wrapper.py b/tests/test_ccxt_wrapper.py
index 18feab3..10e334d 100644
--- a/tests/test_ccxt_wrapper.py
+++ b/tests/test_ccxt_wrapper.py
@@ -1,7 +1,8 @@
1from .helper import unittest, mock, D 1from .helper import limits, unittest, mock, D
2import requests_mock 2import requests_mock
3import market 3import market
4 4
5@unittest.skipUnless("unit" in limits, "Unit skipped")
5class poloniexETest(unittest.TestCase): 6class poloniexETest(unittest.TestCase):
6 def setUp(self): 7 def setUp(self):
7 super().setUp() 8 super().setUp()
diff --git a/tests/test_main.py b/tests/test_main.py
index cee89ce..d2f8029 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -1,6 +1,7 @@
1from .helper import * 1from .helper import *
2import main, market 2import main, market
3 3
4@unittest.skipUnless("unit" in limits, "Unit skipped")
4class MainTest(WebMockTestCase): 5class MainTest(WebMockTestCase):
5 def test_make_order(self): 6 def test_make_order(self):
6 self.m.get_ticker.return_value = { 7 self.m.get_ticker.return_value = {
diff --git a/tests/test_market.py b/tests/test_market.py
index fd23162..14b23b5 100644
--- a/tests/test_market.py
+++ b/tests/test_market.py
@@ -2,6 +2,7 @@ from .helper import *
2import market, store, portfolio 2import market, store, portfolio
3import datetime 3import datetime
4 4
5@unittest.skipUnless("unit" in limits, "Unit skipped")
5class MarketTest(WebMockTestCase): 6class MarketTest(WebMockTestCase):
6 def setUp(self): 7 def setUp(self):
7 super().setUp() 8 super().setUp()
@@ -729,6 +730,7 @@ class MarketTest(WebMockTestCase):
729 store_report.assert_called_once() 730 store_report.assert_called_once()
730 731
731 732
733@unittest.skipUnless("unit" in limits, "Unit skipped")
732class ProcessorTest(WebMockTestCase): 734class ProcessorTest(WebMockTestCase):
733 def test_values(self): 735 def test_values(self):
734 processor = market.Processor(self.m) 736 processor = market.Processor(self.m)
diff --git a/tests/test_portfolio.py b/tests/test_portfolio.py
index 14dc995..4d78996 100644
--- a/tests/test_portfolio.py
+++ b/tests/test_portfolio.py
@@ -2,6 +2,7 @@ from .helper import *
2import portfolio 2import portfolio
3import datetime 3import datetime
4 4
5@unittest.skipUnless("unit" in limits, "Unit skipped")
5class ComputationTest(WebMockTestCase): 6class ComputationTest(WebMockTestCase):
6 def test_compute_value(self): 7 def test_compute_value(self):
7 compute = mock.Mock() 8 compute = mock.Mock()
@@ -25,6 +26,7 @@ class ComputationTest(WebMockTestCase):
25 portfolio.Computation.compute_value("foo", "bid", compute_value="test") 26 portfolio.Computation.compute_value("foo", "bid", compute_value="test")
26 compute.assert_called_with("foo", "bid") 27 compute.assert_called_with("foo", "bid")
27 28
29@unittest.skipUnless("unit" in limits, "Unit skipped")
28class TradeTest(WebMockTestCase): 30class TradeTest(WebMockTestCase):
29 31
30 def test_values_assertion(self): 32 def test_values_assertion(self):
@@ -609,6 +611,7 @@ class TradeTest(WebMockTestCase):
609 self.assertEqual("ETH", as_json["currency"]) 611 self.assertEqual("ETH", as_json["currency"])
610 self.assertEqual("BTC", as_json["base_currency"]) 612 self.assertEqual("BTC", as_json["base_currency"])
611 613
614@unittest.skipUnless("unit" in limits, "Unit skipped")
612class BalanceTest(WebMockTestCase): 615class BalanceTest(WebMockTestCase):
613 def test_values(self): 616 def test_values(self):
614 balance = portfolio.Balance("BTC", { 617 balance = portfolio.Balance("BTC", {
@@ -684,6 +687,7 @@ class BalanceTest(WebMockTestCase):
684 self.assertEqual(D(0), as_json["margin_available"]) 687 self.assertEqual(D(0), as_json["margin_available"])
685 self.assertEqual(D(0), as_json["margin_borrowed"]) 688 self.assertEqual(D(0), as_json["margin_borrowed"])
686 689
690@unittest.skipUnless("unit" in limits, "Unit skipped")
687class OrderTest(WebMockTestCase): 691class OrderTest(WebMockTestCase):
688 def test_values(self): 692 def test_values(self):
689 order = portfolio.Order("buy", portfolio.Amount("ETH", 10), 693 order = portfolio.Order("buy", portfolio.Amount("ETH", 10),
@@ -1745,6 +1749,7 @@ class OrderTest(WebMockTestCase):
1745 result = order.retrieve_order() 1749 result = order.retrieve_order()
1746 self.assertFalse(result) 1750 self.assertFalse(result)
1747 1751
1752@unittest.skipUnless("unit" in limits, "Unit skipped")
1748class MouvementTest(WebMockTestCase): 1753class MouvementTest(WebMockTestCase):
1749 def test_values(self): 1754 def test_values(self):
1750 mouvement = portfolio.Mouvement("ETH", "BTC", { 1755 mouvement = portfolio.Mouvement("ETH", "BTC", {
@@ -1802,6 +1807,7 @@ class MouvementTest(WebMockTestCase):
1802 self.assertEqual("BTC", as_json["base_currency"]) 1807 self.assertEqual("BTC", as_json["base_currency"])
1803 self.assertEqual("ETH", as_json["currency"]) 1808 self.assertEqual("ETH", as_json["currency"])
1804 1809
1810@unittest.skipUnless("unit" in limits, "Unit skipped")
1805class AmountTest(WebMockTestCase): 1811class AmountTest(WebMockTestCase):
1806 def test_values(self): 1812 def test_values(self):
1807 amount = portfolio.Amount("BTC", "0.65") 1813 amount = portfolio.Amount("BTC", "0.65")
diff --git a/tests/test_store.py b/tests/test_store.py
index e281adb..ffd2645 100644
--- a/tests/test_store.py
+++ b/tests/test_store.py
@@ -4,12 +4,14 @@ import datetime
4import threading 4import threading
5import market, portfolio, store 5import market, portfolio, store
6 6
7@unittest.skipUnless("unit" in limits, "Unit skipped")
7class NoopLockTest(unittest.TestCase): 8class NoopLockTest(unittest.TestCase):
8 def test_with(self): 9 def test_with(self):
9 noop_lock = store.NoopLock() 10 noop_lock = store.NoopLock()
10 with noop_lock: 11 with noop_lock:
11 self.assertTrue(True) 12 self.assertTrue(True)
12 13
14@unittest.skipUnless("unit" in limits, "Unit skipped")
13class LockedVarTest(unittest.TestCase): 15class LockedVarTest(unittest.TestCase):
14 16
15 def test_values(self): 17 def test_values(self):
@@ -61,6 +63,7 @@ class LockedVarTest(unittest.TestCase):
61 thread3.join() 63 thread3.join()
62 self.assertEqual("Bar", locked_var.get()[0:3]) 64 self.assertEqual("Bar", locked_var.get()[0:3])
63 65
66@unittest.skipUnless("unit" in limits, "Unit skipped")
64class TradeStoreTest(WebMockTestCase): 67class TradeStoreTest(WebMockTestCase):
65 def test_compute_trades(self): 68 def test_compute_trades(self):
66 self.m.balances.currencies.return_value = ["XMR", "DASH", "XVG", "BTC", "ETH"] 69 self.m.balances.currencies.return_value = ["XMR", "DASH", "XVG", "BTC", "ETH"]
@@ -285,6 +288,7 @@ class TradeStoreTest(WebMockTestCase):
285 288
286 self.assertEqual([trade_mock1, trade_mock2], trade_store.pending) 289 self.assertEqual([trade_mock1, trade_mock2], trade_store.pending)
287 290
291@unittest.skipUnless("unit" in limits, "Unit skipped")
288class BalanceStoreTest(WebMockTestCase): 292class BalanceStoreTest(WebMockTestCase):
289 def setUp(self): 293 def setUp(self):
290 super().setUp() 294 super().setUp()
@@ -437,6 +441,7 @@ class BalanceStoreTest(WebMockTestCase):
437 self.assertEqual(1, as_json["BTC"]) 441 self.assertEqual(1, as_json["BTC"])
438 self.assertEqual(2, as_json["ETH"]) 442 self.assertEqual(2, as_json["ETH"])
439 443
444@unittest.skipUnless("unit" in limits, "Unit skipped")
440class ReportStoreTest(WebMockTestCase): 445class ReportStoreTest(WebMockTestCase):
441 def test_add_log(self): 446 def test_add_log(self):
442 with self.subTest(market=self.m): 447 with self.subTest(market=self.m):
@@ -997,6 +1002,7 @@ class ReportStoreTest(WebMockTestCase):
997 'action': 'Hey' 1002 'action': 'Hey'
998 }) 1003 })
999 1004
1005@unittest.skipUnless("unit" in limits, "Unit skipped")
1000class PortfolioTest(WebMockTestCase): 1006class PortfolioTest(WebMockTestCase):
1001 def setUp(self): 1007 def setUp(self):
1002 super().setUp() 1008 super().setUp()