]> git.immae.eu Git - perso/Immae/Projets/Cryptomonnaies/Cryptoportfolio/Trader.git/blobdiff - tests/test_main.py
Refactor databases access
[perso/Immae/Projets/Cryptomonnaies/Cryptoportfolio/Trader.git] / tests / test_main.py
index 55b1382551aa34da39ec9b35fcd8b4bd7e8c7ee5..1864c062af75fd1bd7d7c0eabca80127e6a2dc13 100644 (file)
@@ -103,7 +103,6 @@ class MainTest(WebMockTestCase):
                 mock.patch("main.parse_config") as main_parse_config:
             with self.subTest(debug=False):
                 main_parse_args.return_value = self.market_args()
-                main_parse_config.return_value = ["pg_config", "redis_config"]
                 main_fetch_markets.return_value = [(1, {"key": "market_config"}, 3)]
                 m = main.get_user_market("config_path.ini", 1)
 
@@ -114,7 +113,6 @@ class MainTest(WebMockTestCase):
             main_parse_args.reset_mock()
             with self.subTest(debug=True):
                 main_parse_args.return_value = self.market_args(debug=True)
-                main_parse_config.return_value = ["pg_config", "redis_config"]
                 main_fetch_markets.return_value = [(1, {"key": "market_config"}, 3)]
                 m = main.get_user_market("config_path.ini", 1, debug=True)
 
@@ -135,16 +133,16 @@ class MainTest(WebMockTestCase):
             args_mock.after = "after"
             self.assertEqual("", stdout_mock.getvalue())
 
-            main.process("config", 3, 1, args_mock, "pg_config", "redis_config")
+            main.process("config", 3, 1, args_mock)
 
             market_mock.from_config.assert_has_calls([
-                mock.call("config", args_mock, pg_config="pg_config", redis_config="redis_config", market_id=3, user_id=1),
+                mock.call("config", args_mock, market_id=3, user_id=1),
                 mock.call().process("action", before="before", after="after"),
                 ])
 
             with self.subTest(exception=True):
                 market_mock.from_config.side_effect = Exception("boo")
-                main.process(3, "config", 1, args_mock, "pg_config", "redis_config")
+                main.process(3, "config", 1, args_mock)
                 self.assertEqual("Exception: boo\n", stdout_mock.getvalue())
 
     def test_main(self):
@@ -159,20 +157,18 @@ class MainTest(WebMockTestCase):
                 args_mock.user = "user"
                 parse_args.return_value = args_mock
 
-                parse_config.return_value = ["pg_config", "redis_config"]
-
                 fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]]
 
                 main.main(["Foo", "Bar"])
 
                 parse_args.assert_called_with(["Foo", "Bar"])
                 parse_config.assert_called_with(args_mock)
-                fetch_markets.assert_called_with("pg_config", "user")
+                fetch_markets.assert_called_with("user")
 
                 self.assertEqual(2, process.call_count)
                 process.assert_has_calls([
-                    mock.call("config1", 3, 1, args_mock, "pg_config", "redis_config"),
-                    mock.call("config2", 1, 2, args_mock, "pg_config", "redis_config"),
+                    mock.call("config1", 3, 1, args_mock),
+                    mock.call("config2", 1, 2, args_mock),
                     ])
         with self.subTest(parallel=True):
             with mock.patch("main.parse_args") as parse_args,\
@@ -187,24 +183,22 @@ class MainTest(WebMockTestCase):
                 args_mock.user = "user"
                 parse_args.return_value = args_mock
 
-                parse_config.return_value = ["pg_config", "redis_config"]
-
                 fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]]
 
                 main.main(["Foo", "Bar"])
 
                 parse_args.assert_called_with(["Foo", "Bar"])
                 parse_config.assert_called_with(args_mock)
-                fetch_markets.assert_called_with("pg_config", "user")
+                fetch_markets.assert_called_with("user")
 
                 stop.assert_called_once_with()
                 start.assert_called_once_with()
                 self.assertEqual(2, process.call_count)
                 process.assert_has_calls([
                     mock.call.__bool__(),
-                    mock.call("config1", 3, 1, args_mock, "pg_config", "redis_config"),
+                    mock.call("config1", 3, 1, args_mock),
                     mock.call.__bool__(),
-                    mock.call("config2", 1, 2, args_mock, "pg_config", "redis_config"),
+                    mock.call("config2", 1, 2, args_mock),
                     ])
         with self.subTest(quiet=True):
             with mock.patch("main.parse_args") as parse_args,\
@@ -219,8 +213,6 @@ class MainTest(WebMockTestCase):
                 args_mock.user = "user"
                 parse_args.return_value = args_mock
 
-                parse_config.return_value = ["pg_config", "redis_config"]
-
                 fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]]
 
                 main.main(["Foo", "Bar"])
@@ -240,8 +232,6 @@ class MainTest(WebMockTestCase):
                 args_mock.user = "user"
                 parse_args.return_value = args_mock
 
-                parse_config.return_value = ["pg_config", "redis_config"]
-
                 fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]]
 
                 main.main(["Foo", "Bar"])
@@ -252,63 +242,57 @@ class MainTest(WebMockTestCase):
     @mock.patch.object(main.sys, "exit")
     @mock.patch("main.os")
     def test_parse_config(self, os, exit):
-        with self.subTest(report_path=None):
+        with self.subTest(report_path=None),\
+                mock.patch.object(main.dbs, "connect_psql") as psql,\
+                mock.patch.object(main.dbs, "connect_redis") as redis:
             args = main.configargparse.Namespace(**{
                 "db_host": "host",
-                "db_port": "port",
-                "db_user": "user",
-                "db_password": "password",
-                "db_database": "database",
                 "redis_host": "rhost",
-                "redis_port": "rport",
-                "redis_database": "rdb",
                 "report_path": None,
                 })
 
-            db_config, redis_config = main.parse_config(args)
-            self.assertEqual({ "host": "host", "port": "port", "user":
-                "user", "password": "password", "database": "database"
-                }, db_config)
-            self.assertEqual({ "host": "rhost", "port": "rport", "db":
-                "rdb"}, redis_config)
+            main.parse_config(args)
+            psql.assert_called_once_with(args)
+            redis.assert_called_once_with(args)
+
+        with self.subTest(report_path=None, db=None),\
+                mock.patch.object(main.dbs, "connect_psql") as psql,\
+                mock.patch.object(main.dbs, "connect_redis") as redis:
+            args = main.configargparse.Namespace(**{
+                "db_host": None,
+                "redis_host": "rhost",
+                "report_path": None,
+                })
 
-            with self.assertRaises(AttributeError):
-                args.db_password
-            with self.assertRaises(AttributeError):
-                args.redis_host
+            main.parse_config(args)
+            psql.assert_not_called()
+            redis.assert_called_once_with(args)
 
-        with self.subTest(redis_host="socket"):
+        with self.subTest(report_path=None, redis=None),\
+                mock.patch.object(main.dbs, "connect_psql") as psql,\
+                mock.patch.object(main.dbs, "connect_redis") as redis:
             args = main.configargparse.Namespace(**{
                 "db_host": "host",
-                "db_port": "port",
-                "db_user": "user",
-                "db_password": "password",
-                "db_database": "database",
-                "redis_host": "/run/foo",
-                "redis_port": "rport",
-                "redis_database": "rdb",
+                "redis_host": None,
                 "report_path": None,
                 })
 
-            db_config, redis_config = main.parse_config(args)
-            self.assertEqual({ "unix_socket_path": "/run/foo", "db": "rdb"}, redis_config)
+            main.parse_config(args)
+            redis.assert_not_called()
+            psql.assert_called_once_with(args)
 
-        with self.subTest(report_path="present"):
+        with self.subTest(report_path="present"),\
+                mock.patch.object(main.dbs, "connect_psql") as psql,\
+                mock.patch.object(main.dbs, "connect_redis") as redis:
             args = main.configargparse.Namespace(**{
                 "db_host": "host",
-                "db_port": "port",
-                "db_user": "user",
-                "db_password": "password",
-                "db_database": "database",
                 "redis_host": "rhost",
-                "redis_port": "rport",
-                "redis_database": "rdb",
                 "report_path": "report_path",
                 })
 
             os.path.exists.return_value = False
 
-            result = main.parse_config(args)
+            main.parse_config(args)
 
             os.path.exists.assert_called_once_with("report_path")
             os.makedirs.assert_called_once_with("report_path")
@@ -331,29 +315,24 @@ class MainTest(WebMockTestCase):
                 mock.patch('sys.stderr', new_callable=StringIO) as stdout_mock:
             args = main.parse_args(["--config", "foo.bar"])
 
-    @mock.patch.object(main, "psycopg2")
-    def test_fetch_markets(self, psycopg2):
-        connect_mock = mock.Mock()
+    @mock.patch.object(main.dbs, "psql")
+    def test_fetch_markets(self, psql):
         cursor_mock = mock.MagicMock()
         cursor_mock.__iter__.return_value = ["row_1", "row_2"]
 
-        connect_mock.cursor.return_value = cursor_mock
-        psycopg2.connect.return_value = connect_mock
+        psql.cursor.return_value = cursor_mock
 
         with self.subTest(user=None):
-            rows = list(main.fetch_markets({"foo": "bar"}, None))
+            rows = list(main.fetch_markets(None))
 
-            psycopg2.connect.assert_called_once_with(foo="bar")
             cursor_mock.execute.assert_called_once_with("SELECT id,config,user_id FROM market_configs")
 
             self.assertEqual(["row_1", "row_2"], rows)
 
-        psycopg2.connect.reset_mock()
         cursor_mock.execute.reset_mock()
         with self.subTest(user=1):
-            rows = list(main.fetch_markets({"foo": "bar"}, 1))
+            rows = list(main.fetch_markets(1))
 
-            psycopg2.connect.assert_called_once_with(foo="bar")
             cursor_mock.execute.assert_called_once_with("SELECT id,config,user_id FROM market_configs WHERE user_id = %s", 1)
 
             self.assertEqual(["row_1", "row_2"], rows)