about summary refs log blame commit diff
path: root/tests/test_main.py
blob: 1864c062af75fd1bd7d7c0eabca80127e6a2dc13 (plain) (tree)
1
2
3
4


                     
                                                      




































































































                                                                                                









                                                                                        



















                                                                                                   
                                                   

                                                      
                                                                       




                                                                              
                                                       













                                                                            





                                                                                   
                                                        


                                                       

                                                          





                                                                       

                                                                         





                                                   





                                                                                   
                                                        
 
                                              



                                                       
                                                          
                                         
                                                          
                      












                                                                       


















                                                                                   





                                                                                   



                                          


                                                                      

                                                    
                                      


                                    











                                                                      
 


                                               
 


                                                                      

                                                    
                                   


                                    


                                              
 


                                                                      

                                                    
                                      




                                               
                                   





















                                                                               

                                        


                                                              
                                              

                                     
                                                 
 



                                                                                                       

                                        
                                              
 




                                                                                                                             
from .helper import *
import main, market

@unittest.skipUnless("unit" in limits, "Unit skipped")
class MainTest(WebMockTestCase):
    """Unit tests for the functions exposed by the ``main`` module.

    ``self.m`` comes from ``WebMockTestCase`` (``tests.helper``); the tests
    use it as a mock (it supports ``reset_mock``) standing in for a market.
    """

    def test_make_order(self):
        """``main.make_order`` builds one trade, logs it, runs it and follows it."""
        self.m.get_ticker.return_value = {
                "inverted": False,
                "average": D("0.1"),
                "bid": D("0.09"),
                "ask": D("0.11"),
                }

        with self.subTest(description="nominal case"):
            main.make_order(self.m, 10, "ETH")

            self.m.report.log_stage.assert_has_calls([
                mock.call("make_order_begin"),
                mock.call("make_order_end"),
                ])
            self.m.balances.fetch_balances.assert_has_calls([
                mock.call(tag="make_order_begin"),
                mock.call(tag="make_order_end"),
                ])
            self.m.trades.all.append.assert_called_once()
            trade = self.m.trades.all.append.mock_calls[0][1][0]
            self.assertEqual(False, trade.orders[0].close_if_possible)
            self.assertEqual(0, trade.value_from)
            self.assertEqual("ETH", trade.currency)
            self.assertEqual("BTC", trade.base_currency)
            self.m.report.log_orders.assert_called_once_with([trade.orders[0]], None, "average")
            self.m.trades.run_orders.assert_called_once_with()
            self.m.follow_orders.assert_called_once_with()

            order = trade.orders[0]
            # With the default compute_value the rate is the ticker "average".
            self.assertEqual(D("0.10"), order.rate)

            self.m.reset_mock()
            # Explicit compute_value="ask": the rate follows the "ask" field.
            with self.subTest(compute_value="ask"):
                main.make_order(self.m, 10, "ETH", action="dispose",
                        compute_value="ask")

                trade = self.m.trades.all.append.mock_calls[0][1][0]
                order = trade.orders[0]
                self.assertEqual(D("0.11"), order.rate)

        self.m.reset_mock()
        with self.subTest(follow=False):
            result = main.make_order(self.m, 10, "ETH", follow=False)

            self.m.report.log_stage.assert_has_calls([
                mock.call("make_order_begin"),
                mock.call("make_order_end_not_followed"),
                ])
            # Without following, balances are only fetched at the beginning.
            self.m.balances.fetch_balances.assert_called_once_with(tag="make_order_begin")

            self.m.trades.all.append.assert_called_once()
            trade = self.m.trades.all.append.mock_calls[0][1][0]
            self.assertEqual(0, trade.value_from)
            self.assertEqual("ETH", trade.currency)
            self.assertEqual("BTC", trade.base_currency)
            self.m.report.log_orders.assert_called_once_with([trade.orders[0]], None, "average")
            self.m.trades.run_orders.assert_called_once_with()
            self.m.follow_orders.assert_not_called()
            # The unfollowed order is handed back to the caller.
            self.assertEqual(trade.orders[0], result)

        self.m.reset_mock()
        with self.subTest(base_currency="USDT"):
            main.make_order(self.m, 1, "BTC", base_currency="USDT")

            trade = self.m.trades.all.append.mock_calls[0][1][0]
            self.assertEqual("BTC", trade.currency)
            self.assertEqual("USDT", trade.base_currency)

        self.m.reset_mock()
        with self.subTest(close_if_possible=True):
            main.make_order(self.m, 10, "ETH", close_if_possible=True)

            trade = self.m.trades.all.append.mock_calls[0][1][0]
            self.assertEqual(True, trade.orders[0].close_if_possible)

        self.m.reset_mock()
        with self.subTest(action="dispose"):
            main.make_order(self.m, 10, "ETH", action="dispose")

            trade = self.m.trades.all.append.mock_calls[0][1][0]
            self.assertEqual(0, trade.value_to)
            # 10 ETH valued at the "average" rate 0.1 -> 1 BTC.
            self.assertEqual(1, trade.value_from.value)
            self.assertEqual("ETH", trade.currency)
            self.assertEqual("BTC", trade.base_currency)

            self.m.reset_mock()
            # Explicit compute_value="bid": 10 ETH at 0.09 -> 0.9 BTC.
            with self.subTest(compute_value="bid"):
                main.make_order(self.m, 10, "ETH", action="dispose",
                        compute_value="bid")

                trade = self.m.trades.all.append.mock_calls[0][1][0]
                self.assertEqual(D("0.9"), trade.value_from.value)

    def test_get_user_market(self):
        """``main.get_user_market`` builds a ``market.Market`` from CLI-style args."""
        with mock.patch("main.fetch_markets") as main_fetch_markets,\
                mock.patch("main.parse_args") as main_parse_args,\
                mock.patch("main.parse_config") as main_parse_config:
            with self.subTest(debug=False):
                main_parse_args.return_value = self.market_args()
                main_fetch_markets.return_value = [(1, {"key": "market_config"}, 3)]
                m = main.get_user_market("config_path.ini", 1)

                self.assertIsInstance(m, market.Market)
                self.assertFalse(m.debug)
                main_parse_args.assert_called_once_with(["--config", "config_path.ini"])

            main_parse_args.reset_mock()
            with self.subTest(debug=True):
                main_parse_args.return_value = self.market_args(debug=True)
                main_fetch_markets.return_value = [(1, {"key": "market_config"}, 3)]
                m = main.get_user_market("config_path.ini", 1, debug=True)

                self.assertIsInstance(m, market.Market)
                self.assertTrue(m.debug)
                # debug=True is forwarded to the argument parser as "--debug".
                main_parse_args.assert_called_once_with(["--config", "config_path.ini", "--debug"])

    def test_process(self):
        """``main.process`` runs one market, and prints exceptions instead of raising."""
        with mock.patch("market.Market") as market_mock,\
                mock.patch('sys.stdout', new_callable=StringIO) as stdout_mock:

            args_mock = mock.Mock()
            args_mock.action = "action"
            args_mock.config = "config"
            args_mock.user = "user"
            args_mock.debug = "debug"
            args_mock.before = "before"
            args_mock.after = "after"
            self.assertEqual("", stdout_mock.getvalue())

            main.process("config", 3, 1, args_mock)

            market_mock.from_config.assert_has_calls([
                mock.call("config", args_mock, market_id=3, user_id=1),
                mock.call().process("action", before="before", after="after"),
                ])

            with self.subTest(exception=True):
                market_mock.from_config.side_effect = Exception("boo")
                # Same (config, market_id, user_id, args) order as the nominal call.
                main.process("config", 3, 1, args_mock)
                self.assertEqual("Exception: boo\n", stdout_mock.getvalue())

    def test_main(self):
        """``main.main`` parses args/config, fetches markets and processes each one."""
        with self.subTest(parallel=False):
            with mock.patch("main.parse_args") as parse_args,\
                    mock.patch("main.parse_config") as parse_config,\
                    mock.patch("main.fetch_markets") as fetch_markets,\
                    mock.patch("main.process") as process:

                args_mock = mock.Mock()
                args_mock.parallel = False
                args_mock.user = "user"
                parse_args.return_value = args_mock

                fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]]

                main.main(["Foo", "Bar"])

                parse_args.assert_called_with(["Foo", "Bar"])
                parse_config.assert_called_with(args_mock)
                fetch_markets.assert_called_with("user")

                self.assertEqual(2, process.call_count)
                # Each fetched row (id, config, user_id) becomes
                # process(config, id, user_id, args).
                process.assert_has_calls([
                    mock.call("config1", 3, 1, args_mock),
                    mock.call("config2", 1, 2, args_mock),
                    ])
        with self.subTest(parallel=True):
            with mock.patch("main.parse_args") as parse_args,\
                    mock.patch("main.parse_config") as parse_config,\
                    mock.patch("main.fetch_markets") as fetch_markets,\
                    mock.patch("main.process") as process,\
                    mock.patch("store.Portfolio.start_worker") as start,\
                    mock.patch("store.Portfolio.stop_worker") as stop:

                args_mock = mock.Mock()
                args_mock.parallel = True
                args_mock.user = "user"
                parse_args.return_value = args_mock

                fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]]

                main.main(["Foo", "Bar"])

                parse_args.assert_called_with(["Foo", "Bar"])
                parse_config.assert_called_with(args_mock)
                fetch_markets.assert_called_with("user")

                stop.assert_called_once_with()
                start.assert_called_once_with()
                self.assertEqual(2, process.call_count)
                # In parallel mode the test expects ``process`` to be
                # truth-tested before each call, hence the __bool__ entries.
                process.assert_has_calls([
                    mock.call.__bool__(),
                    mock.call("config1", 3, 1, args_mock),
                    mock.call.__bool__(),
                    mock.call("config2", 1, 2, args_mock),
                    ])
        with self.subTest(quiet=True):
            with mock.patch("main.parse_args") as parse_args,\
                    mock.patch("main.parse_config") as parse_config,\
                    mock.patch("main.fetch_markets") as fetch_markets,\
                    mock.patch("store.Portfolio.report") as report,\
                    mock.patch("main.process") as process:

                args_mock = mock.Mock()
                args_mock.parallel = False
                args_mock.quiet = True
                args_mock.user = "user"
                parse_args.return_value = args_mock

                fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]]

                main.main(["Foo", "Bar"])

                # --quiet disables verbose reporting.
                report.set_verbose.assert_called_once_with(False)

        with self.subTest(quiet=False):
            with mock.patch("main.parse_args") as parse_args,\
                    mock.patch("main.parse_config") as parse_config,\
                    mock.patch("main.fetch_markets") as fetch_markets,\
                    mock.patch("store.Portfolio.report") as report,\
                    mock.patch("main.process") as process:

                args_mock = mock.Mock()
                args_mock.parallel = False
                args_mock.quiet = False
                args_mock.user = "user"
                parse_args.return_value = args_mock

                fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]]

                main.main(["Foo", "Bar"])

                report.set_verbose.assert_called_once_with(True)


    @mock.patch.object(main.sys, "exit")
    @mock.patch("main.os")
    def test_parse_config(self, os, exit):
        """``main.parse_config`` connects only configured DBs and creates the report dir."""
        with self.subTest(report_path=None),\
                mock.patch.object(main.dbs, "connect_psql") as psql,\
                mock.patch.object(main.dbs, "connect_redis") as redis:
            args = main.configargparse.Namespace(**{
                "db_host": "host",
                "redis_host": "rhost",
                "report_path": None,
                })

            main.parse_config(args)
            psql.assert_called_once_with(args)
            redis.assert_called_once_with(args)

        with self.subTest(report_path=None, db=None),\
                mock.patch.object(main.dbs, "connect_psql") as psql,\
                mock.patch.object(main.dbs, "connect_redis") as redis:
            args = main.configargparse.Namespace(**{
                "db_host": None,
                "redis_host": "rhost",
                "report_path": None,
                })

            main.parse_config(args)
            psql.assert_not_called()
            redis.assert_called_once_with(args)

        with self.subTest(report_path=None, redis=None),\
                mock.patch.object(main.dbs, "connect_psql") as psql,\
                mock.patch.object(main.dbs, "connect_redis") as redis:
            args = main.configargparse.Namespace(**{
                "db_host": "host",
                "redis_host": None,
                "report_path": None,
                })

            main.parse_config(args)
            redis.assert_not_called()
            psql.assert_called_once_with(args)

        with self.subTest(report_path="present"),\
                mock.patch.object(main.dbs, "connect_psql") as psql,\
                mock.patch.object(main.dbs, "connect_redis") as redis:
            args = main.configargparse.Namespace(**{
                "db_host": "host",
                "redis_host": "rhost",
                "report_path": "report_path",
                })

            os.path.exists.return_value = False

            main.parse_config(args)

            # The ``os`` mock is shared by all subtests above. These
            # "called once" checks rely on the earlier subtests (report_path=None)
            # never touching os.path.exists / os.makedirs.
            os.path.exists.assert_called_once_with("report_path")
            os.makedirs.assert_called_once_with("report_path")

    def test_parse_args(self):
        """``main.parse_args`` defaults and flags; a missing config file exits."""
        with self.subTest(config="config.ini"):
            args = main.parse_args([])
            self.assertEqual("config.ini", args.config)
            self.assertFalse(args.before)
            self.assertFalse(args.after)
            self.assertFalse(args.debug)

            args = main.parse_args(["--before", "--after", "--debug"])
            self.assertTrue(args.before)
            self.assertTrue(args.after)
            self.assertTrue(args.debug)

        # stderr is captured only to keep the argparse error message out of
        # the test output.
        with self.subTest(config="inexistant"), \
                self.assertRaises(SystemExit), \
                mock.patch('sys.stderr', new_callable=StringIO) as stderr_mock:
            main.parse_args(["--config", "foo.bar"])

    @mock.patch.object(main.dbs, "psql")
    def test_fetch_markets(self, psql):
        """``main.fetch_markets`` queries market_configs, optionally filtered by user."""
        cursor_mock = mock.MagicMock()
        cursor_mock.__iter__.return_value = ["row_1", "row_2"]

        psql.cursor.return_value = cursor_mock

        with self.subTest(user=None):
            rows = list(main.fetch_markets(None))

            cursor_mock.execute.assert_called_once_with("SELECT id,config,user_id FROM market_configs")

            self.assertEqual(["row_1", "row_2"], rows)

        cursor_mock.execute.reset_mock()
        with self.subTest(user=1):
            rows = list(main.fetch_markets(1))

            # NOTE(review): DB-API drivers such as psycopg2 expect query
            # parameters as a sequence (e.g. ``(1,)``). This assertion pins
            # ``main.fetch_markets`` passing the bare user id. Confirm against
            # main, and update both sides together if that is a bug.
            cursor_mock.execute.assert_called_once_with("SELECT id,config,user_id FROM market_configs WHERE user_id = %s", 1)

            self.assertEqual(["row_1", "row_2"], rows)