aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--config.ini12
-rw-r--r--main.py82
-rw-r--r--market.py11
-rw-r--r--requirements.txt1
-rw-r--r--store.py4
-rw-r--r--test.py149
6 files changed, 133 insertions, 126 deletions
diff --git a/config.ini b/config.ini
index 50cbd1b..4d261c4 100644
--- a/config.ini
+++ b/config.ini
@@ -1,9 +1,9 @@
1[postgresql] 1[postgresql]
2host = localhost 2db-host = localhost
3port = 5432 3db-port = 5432
4user = cryptoportfolio 4db-user = cryptoportfolio
5password = cryptoportfolio 5db-password = cryptoportfolio
6database = cryptoportfolio 6db-database = cryptoportfolio
7 7
8[app] 8[app]
9report_path = reports 9report-path = reports
diff --git a/main.py b/main.py
index 4462192..b68d540 100644
--- a/main.py
+++ b/main.py
@@ -1,6 +1,5 @@
1from datetime import datetime 1from datetime import datetime
2import argparse 2import configargparse
3import configparser
4import psycopg2 3import psycopg2
5import os 4import os
6import sys 5import sys
@@ -80,31 +79,35 @@ def fetch_markets(pg_config, user):
80 for row in cursor: 79 for row in cursor:
81 yield row 80 yield row
82 81
83def parse_config(config_file): 82def parse_config(args):
84 config = configparser.ConfigParser() 83 pg_config = {
85 config.read(config_file) 84 "host": args.db_host,
85 "port": args.db_port,
86 "user": args.db_user,
87 "password": args.db_password,
88 "database": args.db_database,
89 }
90 del(args.db_host)
91 del(args.db_port)
92 del(args.db_user)
93 del(args.db_password)
94 del(args.db_database)
86 95
87 if "postgresql" not in config: 96 report_path = args.report_path
88 print("no configuration for postgresql in config file")
89 sys.exit(1)
90 97
91 if "app" in config and "report_path" in config["app"]: 98 if report_path is not None and not \
92 report_path = config["app"]["report_path"] 99 os.path.exists(report_path):
100 os.makedirs(report_path)
93 101
94 if not os.path.exists(report_path): 102 return pg_config
95 os.makedirs(report_path)
96 else:
97 report_path = None
98
99 return [config["postgresql"], report_path]
100 103
101def parse_args(argv): 104def parse_args(argv):
102 parser = argparse.ArgumentParser( 105 parser = configargparse.ArgumentParser(
103 description="Run the trade bot") 106 description="Run the trade bot.")
104 107
105 parser.add_argument("-c", "--config", 108 parser.add_argument("-c", "--config",
106 default="config.ini", 109 default="config.ini",
107 required=False, 110 required=False, is_config_file=True,
108 help="Config file to load (default: config.ini)") 111 help="Config file to load (default: config.ini)")
109 parser.add_argument("--before", 112 parser.add_argument("--before",
110 default=False, action='store_const', const=True, 113 default=False, action='store_const', const=True,
@@ -125,21 +128,32 @@ def parse_args(argv):
125 help="Do a different action than trading (add several times to chain)") 128 help="Do a different action than trading (add several times to chain)")
126 parser.add_argument("--parallel", action='store_true', default=True, dest="parallel") 129 parser.add_argument("--parallel", action='store_true', default=True, dest="parallel")
127 parser.add_argument("--no-parallel", action='store_false', dest="parallel") 130 parser.add_argument("--no-parallel", action='store_false', dest="parallel")
128 131 parser.add_argument("--report-db", action='store_true', default=True, dest="report_db",
129 args = parser.parse_args(argv) 132 help="Store report to database (default)")
130 133 parser.add_argument("--no-report-db", action='store_false', dest="report_db",
131 if not os.path.exists(args.config): 134 help="Don't store report to database")
132 print("no config file found, exiting") 135 parser.add_argument("--report-path", required=False,
133 sys.exit(1) 136 help="Where to store the reports (default: absent, don't store)")
134 137 parser.add_argument("--no-report-path", action='store_const', dest='report_path', const=None,
135 return args 138 help="Don't store the report to file (default)")
136 139 parser.add_argument("--db-host", default="localhost",
137def process(market_config, market_id, user_id, args, report_path, pg_config): 140 help="Host access to database (default: localhost)")
141 parser.add_argument("--db-port", default=5432,
142 help="Port access to database (default: 5432)")
143 parser.add_argument("--db-user", default="cryptoportfolio",
144 help="User access to database (default: cryptoportfolio)")
145 parser.add_argument("--db-password", default="cryptoportfolio",
146 help="Password access to database (default: cryptoportfolio)")
147 parser.add_argument("--db-database", default="cryptoportfolio",
148 help="Database access to database (default: cryptoportfolio)")
149
150 return parser.parse_args(argv)
151
152def process(market_config, market_id, user_id, args, pg_config):
138 try: 153 try:
139 market.Market\ 154 market.Market\
140 .from_config(market_config, args, 155 .from_config(market_config, args, market_id=market_id,
141 pg_config=pg_config, market_id=market_id, 156 pg_config=pg_config, user_id=user_id)\
142 user_id=user_id, report_path=report_path)\
143 .process(args.action, before=args.before, after=args.after) 157 .process(args.action, before=args.before, after=args.after)
144 except Exception as e: 158 except Exception as e:
145 print("{}: {}".format(e.__class__.__name__, e)) 159 print("{}: {}".format(e.__class__.__name__, e))
@@ -147,7 +161,7 @@ def process(market_config, market_id, user_id, args, report_path, pg_config):
147def main(argv): 161def main(argv):
148 args = parse_args(argv) 162 args = parse_args(argv)
149 163
150 pg_config, report_path = parse_config(args.config) 164 pg_config = parse_config(args)
151 165
152 if args.parallel: 166 if args.parallel:
153 import threading 167 import threading
@@ -159,7 +173,7 @@ def main(argv):
159 process_ = process 173 process_ = process
160 174
161 for market_id, market_config, user_id in fetch_markets(pg_config, args.user): 175 for market_id, market_config, user_id in fetch_markets(pg_config, args.user):
162 process_(market_config, market_id, user_id, args, report_path, pg_config) 176 process_(market_config, market_id, user_id, args, pg_config)
163 177
164if __name__ == '__main__': # pragma: no cover 178if __name__ == '__main__': # pragma: no cover
165 main(sys.argv[1:]) 179 main(sys.argv[1:])
diff --git a/market.py b/market.py
index 10d1ad8..e16641c 100644
--- a/market.py
+++ b/market.py
@@ -25,11 +25,10 @@ class Market:
25 self.balances = BalanceStore(self) 25 self.balances = BalanceStore(self)
26 self.processor = Processor(self) 26 self.processor = Processor(self)
27 27
28 for key in ["user_id", "market_id", "report_path", "pg_config"]: 28 for key in ["user_id", "market_id", "pg_config"]:
29 setattr(self, key, kwargs.get(key, None)) 29 setattr(self, key, kwargs.get(key, None))
30 30
31 self.report.log_market(self.args, self.user_id, self.market_id, 31 self.report.log_market(self.args, self.user_id, self.market_id)
32 self.report_path, self.debug)
33 32
34 @classmethod 33 @classmethod
35 def from_config(cls, config, args, **kwargs): 34 def from_config(cls, config, args, **kwargs):
@@ -42,14 +41,14 @@ class Market:
42 def store_report(self): 41 def store_report(self):
43 self.report.merge(Portfolio.report) 42 self.report.merge(Portfolio.report)
44 date = datetime.now() 43 date = datetime.now()
45 if self.report_path is not None: 44 if self.args.report_path is not None:
46 self.store_file_report(date) 45 self.store_file_report(date)
47 if self.pg_config is not None: 46 if self.pg_config is not None and self.args.report_db:
48 self.store_database_report(date) 47 self.store_database_report(date)
49 48
50 def store_file_report(self, date): 49 def store_file_report(self, date):
51 try: 50 try:
52 report_file = "{}/{}_{}".format(self.report_path, date.isoformat(), self.user_id) 51 report_file = "{}/{}_{}".format(self.args.report_path, date.isoformat(), self.user_id)
53 with open(report_file + ".json", "w") as f: 52 with open(report_file + ".json", "w") as f:
54 f.write(self.report.to_json()) 53 f.write(self.report.to_json())
55 with open(report_file + ".log", "w") as f: 54 with open(report_file + ".log", "w") as f:
diff --git a/requirements.txt b/requirements.txt
index 1bc76ec..2451c80 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,3 +5,4 @@ requests_mock==1.4.0
5psycopg2==2.7.4 5psycopg2==2.7.4
6retry==0.9.2 6retry==0.9.2
7cachetools==2.0.1 7cachetools==2.0.1
8configargparse==0.12.0
diff --git a/store.py b/store.py
index 3f3718f..67e8a8f 100644
--- a/store.py
+++ b/store.py
@@ -222,15 +222,13 @@ class ReportStore:
222 "action": action, 222 "action": action,
223 }) 223 })
224 224
225 def log_market(self, args, user_id, market_id, report_path, debug): 225 def log_market(self, args, user_id, market_id):
226 self.add_log({ 226 self.add_log({
227 "type": "market", 227 "type": "market",
228 "commit": "$Format:%H$", 228 "commit": "$Format:%H$",
229 "args": vars(args), 229 "args": vars(args),
230 "user_id": user_id, 230 "user_id": user_id,
231 "market_id": market_id, 231 "market_id": market_id,
232 "report_path": report_path,
233 "debug": debug,
234 }) 232 })
235 233
236class BalanceStore: 234class BalanceStore:
diff --git a/test.py b/test.py
index bf679bf..854e27b 100644
--- a/test.py
+++ b/test.py
@@ -23,8 +23,9 @@ for test_type in limits:
23class WebMockTestCase(unittest.TestCase): 23class WebMockTestCase(unittest.TestCase):
24 import time 24 import time
25 25
26 def market_args(self, debug=False, quiet=False): 26 def market_args(self, debug=False, quiet=False, report_path=None, **kwargs):
27 return type('Args', (object,), { "debug": debug, "quiet": quiet })() 27 return main.configargparse.Namespace(report_path=report_path,
28 debug=debug, quiet=quiet, **kwargs)
28 29
29 def setUp(self): 30 def setUp(self):
30 super().setUp() 31 super().setUp()
@@ -1632,7 +1633,8 @@ class MarketTest(WebMockTestCase):
1632 1633
1633 def test_store_file_report(self): 1634 def test_store_file_report(self):
1634 file_open = mock.mock_open() 1635 file_open = mock.mock_open()
1635 m = market.Market(self.ccxt, self.market_args(), report_path="present", user_id=1) 1636 m = market.Market(self.ccxt,
1637 self.market_args(report_path="present"), user_id=1)
1636 with self.subTest(file="present"),\ 1638 with self.subTest(file="present"),\
1637 mock.patch("market.open", file_open),\ 1639 mock.patch("market.open", file_open),\
1638 mock.patch.object(m, "report") as report,\ 1640 mock.patch.object(m, "report") as report,\
@@ -1649,7 +1651,7 @@ class MarketTest(WebMockTestCase):
1649 file_open().write.assert_any_call("Foo\nBar") 1651 file_open().write.assert_any_call("Foo\nBar")
1650 m.report.to_json.assert_called_once_with() 1652 m.report.to_json.assert_called_once_with()
1651 1653
1652 m = market.Market(self.ccxt, self.market_args(), report_path="error", user_id=1) 1654 m = market.Market(self.ccxt, self.market_args(report_path="error"), user_id=1)
1653 with self.subTest(file="error"),\ 1655 with self.subTest(file="error"),\
1654 mock.patch("market.open") as file_open,\ 1656 mock.patch("market.open") as file_open,\
1655 mock.patch.object(m, "report") as report,\ 1657 mock.patch.object(m, "report") as report,\
@@ -1697,7 +1699,7 @@ class MarketTest(WebMockTestCase):
1697 self.assertEqual(stdout_mock.getvalue(), "impossible to store report to database: Exception; Bouh\n") 1699 self.assertEqual(stdout_mock.getvalue(), "impossible to store report to database: Exception; Bouh\n")
1698 1700
1699 def test_store_report(self): 1701 def test_store_report(self):
1700 m = market.Market(self.ccxt, self.market_args(), user_id=1) 1702 m = market.Market(self.ccxt, self.market_args(report_db=False), user_id=1)
1701 with self.subTest(file=None, pg_config=None),\ 1703 with self.subTest(file=None, pg_config=None),\
1702 mock.patch.object(m, "report") as report,\ 1704 mock.patch.object(m, "report") as report,\
1703 mock.patch.object(m, "store_database_report") as db_report,\ 1705 mock.patch.object(m, "store_database_report") as db_report,\
@@ -1709,7 +1711,7 @@ class MarketTest(WebMockTestCase):
1709 db_report.assert_not_called() 1711 db_report.assert_not_called()
1710 1712
1711 report.reset_mock() 1713 report.reset_mock()
1712 m = market.Market(self.ccxt, self.market_args(), report_path="present", user_id=1) 1714 m = market.Market(self.ccxt, self.market_args(report_db=False, report_path="present"), user_id=1)
1713 with self.subTest(file="present", pg_config=None),\ 1715 with self.subTest(file="present", pg_config=None),\
1714 mock.patch.object(m, "report") as report,\ 1716 mock.patch.object(m, "report") as report,\
1715 mock.patch.object(m, "store_file_report") as file_report,\ 1717 mock.patch.object(m, "store_file_report") as file_report,\
@@ -1725,7 +1727,23 @@ class MarketTest(WebMockTestCase):
1725 db_report.assert_not_called() 1727 db_report.assert_not_called()
1726 1728
1727 report.reset_mock() 1729 report.reset_mock()
1728 m = market.Market(self.ccxt, self.market_args(), pg_config="present", user_id=1) 1730 m = market.Market(self.ccxt, self.market_args(report_db=True, report_path="present"), user_id=1)
1731 with self.subTest(file="present", pg_config=None, report_db=True),\
1732 mock.patch.object(m, "report") as report,\
1733 mock.patch.object(m, "store_file_report") as file_report,\
1734 mock.patch.object(m, "store_database_report") as db_report,\
1735 mock.patch.object(market, "datetime") as time_mock:
1736
1737 time_mock.now.return_value = datetime.datetime(2018, 2, 25)
1738
1739 m.store_report()
1740
1741 report.merge.assert_called_with(store.Portfolio.report)
1742 file_report.assert_called_once_with(datetime.datetime(2018, 2, 25))
1743 db_report.assert_not_called()
1744
1745 report.reset_mock()
1746 m = market.Market(self.ccxt, self.market_args(report_db=True), pg_config="present", user_id=1)
1729 with self.subTest(file=None, pg_config="present"),\ 1747 with self.subTest(file=None, pg_config="present"),\
1730 mock.patch.object(m, "report") as report,\ 1748 mock.patch.object(m, "report") as report,\
1731 mock.patch.object(m, "store_file_report") as file_report,\ 1749 mock.patch.object(m, "store_file_report") as file_report,\
@@ -1741,8 +1759,8 @@ class MarketTest(WebMockTestCase):
1741 db_report.assert_called_once_with(datetime.datetime(2018, 2, 25)) 1759 db_report.assert_called_once_with(datetime.datetime(2018, 2, 25))
1742 1760
1743 report.reset_mock() 1761 report.reset_mock()
1744 m = market.Market(self.ccxt, self.market_args(), 1762 m = market.Market(self.ccxt, self.market_args(report_db=True, report_path="present"),
1745 pg_config="pg_config", report_path="present", user_id=1) 1763 pg_config="pg_config", user_id=1)
1746 with self.subTest(file="present", pg_config="present"),\ 1764 with self.subTest(file="present", pg_config="present"),\
1747 mock.patch.object(m, "report") as report,\ 1765 mock.patch.object(m, "report") as report,\
1748 mock.patch.object(m, "store_file_report") as file_report,\ 1766 mock.patch.object(m, "store_file_report") as file_report,\
@@ -4383,20 +4401,14 @@ class ReportStoreTest(WebMockTestCase):
4383 @mock.patch.object(market.ReportStore, "add_log") 4401 @mock.patch.object(market.ReportStore, "add_log")
4384 def test_log_market(self, add_log): 4402 def test_log_market(self, add_log):
4385 report_store = market.ReportStore(self.m) 4403 report_store = market.ReportStore(self.m)
4386 class Args:
4387 def __init__(self):
4388 self.debug = True
4389 self.quiet = False
4390 4404
4391 report_store.log_market(Args(), 4, 1, "report", True) 4405 report_store.log_market(self.market_args(debug=True, quiet=False), 4, 1)
4392 add_log.assert_called_once_with({ 4406 add_log.assert_called_once_with({
4393 "type": "market", 4407 "type": "market",
4394 "commit": "$Format:%H$", 4408 "commit": "$Format:%H$",
4395 "args": { "debug": True, "quiet": False }, 4409 "args": { "report_path": None, "debug": True, "quiet": False },
4396 "user_id": 4, 4410 "user_id": 4,
4397 "market_id": 1, 4411 "market_id": 1,
4398 "report_path": "report",
4399 "debug": True
4400 }) 4412 })
4401 4413
4402 @mock.patch.object(market.ReportStore, "print_log") 4414 @mock.patch.object(market.ReportStore, "print_log")
@@ -4603,16 +4615,16 @@ class MainTest(WebMockTestCase):
4603 args_mock.after = "after" 4615 args_mock.after = "after"
4604 self.assertEqual("", stdout_mock.getvalue()) 4616 self.assertEqual("", stdout_mock.getvalue())
4605 4617
4606 main.process("config", 3, 1, args_mock, "report_path", "pg_config") 4618 main.process("config", 3, 1, args_mock, "pg_config")
4607 4619
4608 market_mock.from_config.assert_has_calls([ 4620 market_mock.from_config.assert_has_calls([
4609 mock.call("config", args_mock, pg_config="pg_config", market_id=3, user_id=1, report_path="report_path"), 4621 mock.call("config", args_mock, pg_config="pg_config", market_id=3, user_id=1),
4610 mock.call().process("action", before="before", after="after"), 4622 mock.call().process("action", before="before", after="after"),
4611 ]) 4623 ])
4612 4624
4613 with self.subTest(exception=True): 4625 with self.subTest(exception=True):
4614 market_mock.from_config.side_effect = Exception("boo") 4626 market_mock.from_config.side_effect = Exception("boo")
4615 main.process(3, "config", 1, "report_path", args_mock, "pg_config") 4627 main.process(3, "config", 1, args_mock, "pg_config")
4616 self.assertEqual("Exception: boo\n", stdout_mock.getvalue()) 4628 self.assertEqual("Exception: boo\n", stdout_mock.getvalue())
4617 4629
4618 def test_main(self): 4630 def test_main(self):
@@ -4624,24 +4636,23 @@ class MainTest(WebMockTestCase):
4624 4636
4625 args_mock = mock.Mock() 4637 args_mock = mock.Mock()
4626 args_mock.parallel = False 4638 args_mock.parallel = False
4627 args_mock.config = "config"
4628 args_mock.user = "user" 4639 args_mock.user = "user"
4629 parse_args.return_value = args_mock 4640 parse_args.return_value = args_mock
4630 4641
4631 parse_config.return_value = ["pg_config", "report_path"] 4642 parse_config.return_value = "pg_config"
4632 4643
4633 fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]] 4644 fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]]
4634 4645
4635 main.main(["Foo", "Bar"]) 4646 main.main(["Foo", "Bar"])
4636 4647
4637 parse_args.assert_called_with(["Foo", "Bar"]) 4648 parse_args.assert_called_with(["Foo", "Bar"])
4638 parse_config.assert_called_with("config") 4649 parse_config.assert_called_with(args_mock)
4639 fetch_markets.assert_called_with("pg_config", "user") 4650 fetch_markets.assert_called_with("pg_config", "user")
4640 4651
4641 self.assertEqual(2, process.call_count) 4652 self.assertEqual(2, process.call_count)
4642 process.assert_has_calls([ 4653 process.assert_has_calls([
4643 mock.call("config1", 3, 1, args_mock, "report_path", "pg_config"), 4654 mock.call("config1", 3, 1, args_mock, "pg_config"),
4644 mock.call("config2", 1, 2, args_mock, "report_path", "pg_config"), 4655 mock.call("config2", 1, 2, args_mock, "pg_config"),
4645 ]) 4656 ])
4646 with self.subTest(parallel=True): 4657 with self.subTest(parallel=True):
4647 with mock.patch("main.parse_args") as parse_args,\ 4658 with mock.patch("main.parse_args") as parse_args,\
@@ -4652,79 +4663,66 @@ class MainTest(WebMockTestCase):
4652 4663
4653 args_mock = mock.Mock() 4664 args_mock = mock.Mock()
4654 args_mock.parallel = True 4665 args_mock.parallel = True
4655 args_mock.config = "config"
4656 args_mock.user = "user" 4666 args_mock.user = "user"
4657 parse_args.return_value = args_mock 4667 parse_args.return_value = args_mock
4658 4668
4659 parse_config.return_value = ["pg_config", "report_path"] 4669 parse_config.return_value = "pg_config"
4660 4670
4661 fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]] 4671 fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]]
4662 4672
4663 main.main(["Foo", "Bar"]) 4673 main.main(["Foo", "Bar"])
4664 4674
4665 parse_args.assert_called_with(["Foo", "Bar"]) 4675 parse_args.assert_called_with(["Foo", "Bar"])
4666 parse_config.assert_called_with("config") 4676 parse_config.assert_called_with(args_mock)
4667 fetch_markets.assert_called_with("pg_config", "user") 4677 fetch_markets.assert_called_with("pg_config", "user")
4668 4678
4669 start.assert_called_once_with() 4679 start.assert_called_once_with()
4670 self.assertEqual(2, process.call_count) 4680 self.assertEqual(2, process.call_count)
4671 process.assert_has_calls([ 4681 process.assert_has_calls([
4672 mock.call.__bool__(), 4682 mock.call.__bool__(),
4673 mock.call("config1", 3, 1, args_mock, "report_path", "pg_config"), 4683 mock.call("config1", 3, 1, args_mock, "pg_config"),
4674 mock.call.__bool__(), 4684 mock.call.__bool__(),
4675 mock.call("config2", 1, 2, args_mock, "report_path", "pg_config"), 4685 mock.call("config2", 1, 2, args_mock, "pg_config"),
4676 ]) 4686 ])
4677 4687
4678 @mock.patch.object(main.sys, "exit") 4688 @mock.patch.object(main.sys, "exit")
4679 @mock.patch("main.configparser")
4680 @mock.patch("main.os") 4689 @mock.patch("main.os")
4681 def test_parse_config(self, os, configparser, exit): 4690 def test_parse_config(self, os, exit):
4682 with self.subTest(pg_config=True, report_path=None): 4691 with self.subTest(report_path=None):
4683 config_mock = mock.MagicMock() 4692 args = main.configargparse.Namespace(**{
4684 configparser.ConfigParser.return_value = config_mock 4693 "db_host": "host",
4685 def config(element): 4694 "db_port": "port",
4686 return element == "postgresql" 4695 "db_user": "user",
4687 4696 "db_password": "password",
4688 config_mock.__contains__.side_effect = config 4697 "db_database": "database",
4689 config_mock.__getitem__.return_value = "pg_config" 4698 "report_path": None,
4690 4699 })
4691 result = main.parse_config("configfile")
4692
4693 config_mock.read.assert_called_with("configfile")
4694
4695 self.assertEqual(["pg_config", None], result)
4696
4697 with self.subTest(pg_config=True, report_path="present"):
4698 config_mock = mock.MagicMock()
4699 configparser.ConfigParser.return_value = config_mock
4700 4700
4701 config_mock.__contains__.return_value = True 4701 result = main.parse_config(args)
4702 config_mock.__getitem__.side_effect = [ 4702 self.assertEqual({ "host": "host", "port": "port", "user":
4703 {"report_path": "report_path"}, 4703 "user", "password": "password", "database": "database"
4704 {"report_path": "report_path"}, 4704 }, result)
4705 "pg_config", 4705 with self.assertRaises(AttributeError):
4706 ] 4706 args.db_password
4707
4708 with self.subTest(report_path="present"):
4709 args = main.configargparse.Namespace(**{
4710 "db_host": "host",
4711 "db_port": "port",
4712 "db_user": "user",
4713 "db_password": "password",
4714 "db_database": "database",
4715 "report_path": "report_path",
4716 })
4707 4717
4708 os.path.exists.return_value = False 4718 os.path.exists.return_value = False
4709 result = main.parse_config("configfile")
4710 4719
4711 config_mock.read.assert_called_with("configfile") 4720 result = main.parse_config(args)
4712 self.assertEqual(["pg_config", "report_path"], result) 4721
4713 os.path.exists.assert_called_once_with("report_path") 4722 os.path.exists.assert_called_once_with("report_path")
4714 os.makedirs.assert_called_once_with("report_path") 4723 os.makedirs.assert_called_once_with("report_path")
4715 4724
4716 with self.subTest(pg_config=False),\ 4725 def test_parse_args(self):
4717 mock.patch('sys.stdout', new_callable=StringIO) as stdout_mock:
4718 config_mock = mock.MagicMock()
4719 configparser.ConfigParser.return_value = config_mock
4720 result = main.parse_config("configfile")
4721
4722 config_mock.read.assert_called_with("configfile")
4723 exit.assert_called_once_with(1)
4724 self.assertEqual("no configuration for postgresql in config file\n", stdout_mock.getvalue())
4725
4726 @mock.patch.object(main.sys, "exit")
4727 def test_parse_args(self, exit):
4728 with self.subTest(config="config.ini"): 4726 with self.subTest(config="config.ini"):
4729 args = main.parse_args([]) 4727 args = main.parse_args([])
4730 self.assertEqual("config.ini", args.config) 4728 self.assertEqual("config.ini", args.config)
@@ -4737,13 +4735,10 @@ class MainTest(WebMockTestCase):
4737 self.assertTrue(args.after) 4735 self.assertTrue(args.after)
4738 self.assertTrue(args.debug) 4736 self.assertTrue(args.debug)
4739 4737
4740 exit.assert_not_called() 4738 with self.subTest(config="inexistant"), \
4741 4739 self.assertRaises(SystemExit), \
4742 with self.subTest(config="inexistant"),\ 4740 mock.patch('sys.stderr', new_callable=StringIO) as stdout_mock:
4743 mock.patch('sys.stdout', new_callable=StringIO) as stdout_mock:
4744 args = main.parse_args(["--config", "foo.bar"]) 4741 args = main.parse_args(["--config", "foo.bar"])
4745 exit.assert_called_once_with(1)
4746 self.assertEqual("no config file found, exiting\n", stdout_mock.getvalue())
4747 4742
4748 @mock.patch.object(main, "psycopg2") 4743 @mock.patch.object(main, "psycopg2")
4749 def test_fetch_markets(self, psycopg2): 4744 def test_fetch_markets(self, psycopg2):