+from .helper import *
+import dbs, main
+
+@unittest.skipUnless("unit" in limits, "Unit skipped")
+class DbsTest(WebMockTestCase):
+ @mock.patch.object(dbs, "psycopg2")
+ def test_connect_psql(self, psycopg2):
+ args = main.configargparse.Namespace(**{
+ "db_host": "host",
+ "db_port": "port",
+ "db_user": "user",
+ "db_password": "password",
+ "db_database": "database",
+ })
+ psycopg2.connect.return_value = "pg_connection"
+ dbs.connect_psql(args)
+
+ psycopg2.connect.assert_called_once_with(host="host",
+ port="port", user="user", password="password",
+ database="database")
+ self.assertEqual("pg_connection", dbs.psql)
+ with self.assertRaises(AttributeError):
+ args.db_password
+
+ psycopg2.connect.reset_mock()
+ args = main.configargparse.Namespace(**{
+ "db_host": "host",
+ "db_port": "port",
+ "db_user": "user",
+ "db_password": "password",
+ "db_database": "database",
+ })
+ dbs.connect_psql(args)
+ psycopg2.connect.assert_not_called()
+
+ @mock.patch.object(dbs, "_redis")
+ def test_connect_redis(self, redis):
+ with self.subTest(redis_host="tcp"):
+ args = main.configargparse.Namespace(**{
+ "redis_host": "host",
+ "redis_port": "port",
+ "redis_database": "database",
+ })
+ redis.Redis.return_value = "redis_connection"
+ dbs.connect_redis(args)
+
+ redis.Redis.assert_called_once_with(host="host",
+ port="port", db="database")
+ self.assertEqual("redis_connection", dbs.redis)
+ with self.assertRaises(AttributeError):
+ args.redis_database
+
+ redis.Redis.reset_mock()
+ args = main.configargparse.Namespace(**{
+ "redis_host": "host",
+ "redis_port": "port",
+ "redis_database": "database",
+ })
+ dbs.connect_redis(args)
+ redis.Redis.assert_not_called()
+
+ dbs.redis = None
+ with self.subTest(redis_host="socket"):
+ args = main.configargparse.Namespace(**{
+ "redis_host": "/run/foo",
+ "redis_port": "port",
+ "redis_database": "database",
+ })
+ redis.Redis.return_value = "redis_socket"
+ dbs.connect_redis(args)
+
+ redis.Redis.assert_called_once_with(unix_socket_path="/run/foo", db="database")
+ self.assertEqual("redis_socket", dbs.redis)
+
+ def test_redis_connected(self):
+ with self.subTest(redis=None):
+ dbs.redis = None
+ self.assertFalse(dbs.redis_connected())
+
+ with self.subTest(redis="mocked_true"):
+ dbs.redis = mock.Mock()
+ dbs.redis.ping.return_value = True
+ self.assertTrue(dbs.redis_connected())
+
+ with self.subTest(redis="mocked_false"):
+ dbs.redis = mock.Mock()
+ dbs.redis.ping.return_value = False
+ self.assertFalse(dbs.redis_connected())
+
+ with self.subTest(redis="mocked_raise"):
+ dbs.redis = mock.Mock()
+ dbs.redis.ping.side_effect = Exception("bouh")
+ self.assertFalse(dbs.redis_connected())
+
+ def test_psql_connected(self):
+ with self.subTest(psql=None):
+ dbs.psql = None
+ self.assertFalse(dbs.psql_connected())
+
+ with self.subTest(psql="connected"):
+ dbs.psql = mock.Mock()
+ dbs.psql.closed = 0
+ self.assertTrue(dbs.psql_connected())
+
+ with self.subTest(psql="not connected"):
+ dbs.psql = mock.Mock()
+ dbs.psql.closed = 3
+ self.assertFalse(dbs.psql_connected())