]> git.immae.eu Git - perso/Immae/Projets/Cryptomonnaies/Cryptoportfolio/Trader.git/commitdiff
Store reports to database
authorIsmaël Bouya <ismael.bouya@normalesup.org>
Fri, 23 Mar 2018 00:11:34 +0000 (01:11 +0100)
committerIsmaël Bouya <ismael.bouya@normalesup.org>
Sat, 24 Mar 2018 09:39:35 +0000 (10:39 +0100)
Fixes https://git.immae.eu/mantisbt/view.php?id=57

main.py
market.py
store.py
test.py

diff --git a/main.py b/main.py
index 55981bf94d82c966aea93e3f89ea25a01cdb7c79..3e9828952867dc7034fb588d865d713655d1a112 100644 (file)
--- a/main.py
+++ b/main.py
@@ -62,7 +62,7 @@ def make_order(market, value, currency, action="acquire",
 
 def get_user_market(config_path, user_id, debug=False):
     pg_config, report_path = parse_config(config_path)
-    market_config = list(fetch_markets(pg_config, str(user_id)))[0][0]
+    market_config = list(fetch_markets(pg_config, str(user_id)))[0][1]
     args = type('Args', (object,), { "debug": debug, "quiet": False })()
     return market.Market.from_config(market_config, args, user_id=user_id, report_path=report_path)
 
@@ -71,9 +71,9 @@ def fetch_markets(pg_config, user):
     cursor = connection.cursor()
 
     if user is None:
-        cursor.execute("SELECT config,user_id FROM market_configs")
+        cursor.execute("SELECT id,config,user_id FROM market_configs")
     else:
-        cursor.execute("SELECT config,user_id FROM market_configs WHERE user_id = %s", user)
+        cursor.execute("SELECT id,config,user_id FROM market_configs WHERE user_id = %s", user)
 
     for row in cursor:
         yield row
@@ -132,10 +132,12 @@ def parse_args(argv):
 
     return args
 
-def process(market_config, user_id, report_path, args):
+def process(market_id, market_config, user_id, report_path, args, pg_config):
     try:
         market.Market\
-                .from_config(market_config, args, user_id=user_id, report_path=report_path)\
+                .from_config(market_config, args,
+                        pg_config=pg_config, market_id=market_id,
+                        user_id=user_id, report_path=report_path)\
                 .process(args.action, before=args.before, after=args.after)
     except Exception as e:
         print("{}: {}".format(e.__class__.__name__, e))
@@ -149,11 +151,13 @@ def main(argv):
         import threading
         market.Portfolio.start_worker()
 
-        for market_config, user_id in fetch_markets(pg_config, args.user):
-            threading.Thread(target=process, args=[market_config, user_id, report_path, args]).start()
+        for row in fetch_markets(pg_config, args.user):
+            threading.Thread(target=process, args=[
+                *row, report_path, args, pg_config
+                ]).start()
     else:
-        for market_config, user_id in fetch_markets(pg_config, args.user):
-            process(market_config, user_id, report_path, args)
+        for row in fetch_markets(pg_config, args.user):
+            process(*row, report_path, args, pg_config)
 
 if __name__ == '__main__': # pragma: no cover
     main(sys.argv[1:])
index fc5832c089fc6eb63a97ec1621dab65eca406a2e..78ced1a209eea10c181dfd429b238ff7ca30c659 100644 (file)
--- a/market.py
+++ b/market.py
@@ -1,6 +1,7 @@
 from ccxt import ExchangeError, NotSupported
 import ccxt_wrapper as ccxt
 import time
+import psycopg2
 from store import *
 from cachetools.func import ttl_cache
 from datetime import datetime
@@ -13,7 +14,9 @@ class Market:
     trades = None
     balances = None
 
-    def __init__(self, ccxt_instance, args, user_id=None, report_path=None):
+    def __init__(self, ccxt_instance, args,
+            user_id=None, market_id=None,
+            report_path=None, pg_config=None):
         self.args = args
         self.debug = args.debug
         self.ccxt = ccxt_instance
@@ -24,10 +27,13 @@ class Market:
         self.processor = Processor(self)
 
         self.user_id = user_id
+        self.market_id = market_id
         self.report_path = report_path
+        self.pg_config = pg_config
 
     @classmethod
-    def from_config(cls, config, args, user_id=None, report_path=None):
+    def from_config(cls, config, args,
+            user_id=None, market_id=None, report_path=None, pg_config=None):
         config["apiKey"] = config.pop("key", None)
 
         ccxt_instance = ccxt.poloniexE(config)
@@ -44,20 +50,45 @@ class Market:
         ccxt_instance.session.request = request_wrap.__get__(ccxt_instance.session,
                 ccxt_instance.session.__class__)
 
-        return cls(ccxt_instance, args, user_id=user_id, report_path=report_path)
+        return cls(ccxt_instance, args,
+                user_id=user_id, market_id=market_id,
+                pg_config=pg_config, report_path=report_path)
 
     def store_report(self):
         self.report.merge(Portfolio.report)
+        date = datetime.now()
+        if self.report_path is not None:
+            self.store_file_report(date)
+        if self.pg_config is not None:
+            self.store_database_report(date)
+
+    def store_file_report(self, date):
         try:
-            if self.report_path is not None:
-                report_file = "{}/{}_{}".format(self.report_path, datetime.now().isoformat(), self.user_id)
-                with open(report_file + ".json", "w") as f:
-                    f.write(self.report.to_json())
-                with open(report_file + ".log", "w") as f:
-                    f.write("\n".join(map(lambda x: x[1], self.report.print_logs)))
+            report_file = "{}/{}_{}".format(self.report_path, date.isoformat(), self.user_id)
+            with open(report_file + ".json", "w") as f:
+                f.write(self.report.to_json())
+            with open(report_file + ".log", "w") as f:
+                f.write("\n".join(map(lambda x: x[1], self.report.print_logs)))
         except Exception as e:
             print("impossible to store report file: {}; {}".format(e.__class__.__name__, e))
 
+    def store_database_report(self, date):
+        try:
+            report_query = 'INSERT INTO reports("date", "market_config_id", "debug") VALUES (%s, %s, %s) RETURNING id;'
+            line_query = 'INSERT INTO report_lines("date", "report_id", "type", "payload") VALUES (%s, %s, %s, %s);'
+            connection = psycopg2.connect(**self.pg_config)
+            cursor = connection.cursor()
+            cursor.execute(report_query, (date, self.market_id, self.debug))
+            report_id = cursor.fetchone()[0]
+            for date, type_, payload in self.report.to_json_array():
+                cursor.execute(line_query, (date, report_id, type_, payload))
+
+            connection.commit()
+            cursor.close()
+            connection.close()
+        except Exception as e:
+            print("impossible to store report to database: {}; {}".format(e.__class__.__name__, e))
+
     def process(self, actions, before=False, after=False):
         try:
             if len(actions or []) == 0:
index d875a983a77b5d0fa8cbce0f60c923ac5de62832..b3ada4567e38a28613efd0159dafd33ece105a0d 100644 (file)
--- a/store.py
+++ b/store.py
@@ -36,12 +36,22 @@ class ReportStore:
         hash_["date"] = datetime.now()
         self.logs.append(hash_)
 
+    @staticmethod
+    def default_json_serial(obj):
+        if isinstance(obj, (datetime, date)):
+            return obj.isoformat()
+        return str(obj)
+
     def to_json(self):
-        def default_json_serial(obj):
-            if isinstance(obj, (datetime, date)):
-                return obj.isoformat()
-            return str(obj)
-        return json.dumps(self.logs, default=default_json_serial, indent="  ")
+        return json.dumps(self.logs, default=self.default_json_serial, indent="  ")
+
+    def to_json_array(self):
+        for log in (x.copy() for x in self.logs):
+            yield (
+                    log.pop("date"),
+                    log.pop("type"),
+                    json.dumps(log, default=self.default_json_serial, indent="  ")
+                    )
 
     def set_verbose(self, verbose_print):
         self.verbose_print = verbose_print
diff --git a/test.py b/test.py
index 3ee34c69a4ff974e428854107527c4070b858be7..5b9c56c02effd0f7a983f78b7c9b161f614e38bd 100644 (file)
--- a/test.py
+++ b/test.py
@@ -1386,18 +1386,7 @@ class MarketTest(WebMockTestCase):
                     self.ccxt.transfer_balance.assert_any_call("USDT", 100, "exchange", "margin")
                     self.ccxt.transfer_balance.assert_any_call("ETC", 5, "margin", "exchange")
 
-    def test_store_report(self):
-
-        file_open = mock.mock_open()
-        m = market.Market(self.ccxt, self.market_args(), user_id=1)
-        with self.subTest(file=None),\
-                mock.patch.object(m, "report") as report,\
-                mock.patch("market.open", file_open):
-            m.store_report()
-            report.merge.assert_called_with(store.Portfolio.report)
-            file_open.assert_not_called()
-
-        report.reset_mock()
+    def test_store_file_report(self):
         file_open = mock.mock_open()
         m = market.Market(self.ccxt, self.market_args(), report_path="present", user_id=1)
         with self.subTest(file="present"),\
@@ -1405,20 +1394,16 @@ class MarketTest(WebMockTestCase):
                 mock.patch.object(m, "report") as report,\
                 mock.patch.object(market, "datetime") as time_mock:
 
-            time_mock.now.return_value = datetime.datetime(2018, 2, 25)
             report.print_logs = [[time_mock.now(), "Foo"], [time_mock.now(), "Bar"]]
             report.to_json.return_value = "json_content"
 
-            m.store_report()
+            m.store_file_report(datetime.datetime(2018, 2, 25))
 
             file_open.assert_any_call("present/2018-02-25T00:00:00_1.json", "w")
             file_open.assert_any_call("present/2018-02-25T00:00:00_1.log", "w")
             file_open().write.assert_any_call("json_content")
             file_open().write.assert_any_call("Foo\nBar")
             m.report.to_json.assert_called_once_with()
-            report.merge.assert_called_with(store.Portfolio.report)
-
-        report.reset_mock()
 
         m = market.Market(self.ccxt, self.market_args(), report_path="error", user_id=1)
         with self.subTest(file="error"),\
@@ -1427,10 +1412,106 @@ class MarketTest(WebMockTestCase):
                 mock.patch('sys.stdout', new_callable=StringIO) as stdout_mock:
             file_open.side_effect = FileNotFoundError
 
+            m.store_file_report(datetime.datetime(2018, 2, 25))
+
+            self.assertRegex(stdout_mock.getvalue(), "impossible to store report file: FileNotFoundError;")
+
+    @mock.patch.object(market, "psycopg2")
+    def test_store_database_report(self, psycopg2):
+        connect_mock = mock.Mock()
+        cursor_mock = mock.MagicMock()
+
+        connect_mock.cursor.return_value = cursor_mock
+        psycopg2.connect.return_value = connect_mock
+        m = market.Market(self.ccxt, self.market_args(),
+                pg_config={"config": "pg_config"}, user_id=1)
+        cursor_mock.fetchone.return_value = [42]
+
+        with self.subTest(error=False),\
+                mock.patch.object(m, "report") as report:
+            report.to_json_array.return_value = [
+                    ("date1", "type1", "payload1"),
+                    ("date2", "type2", "payload2"),
+                    ]
+            m.store_database_report(datetime.datetime(2018, 3, 24))
+            connect_mock.assert_has_calls([
+                mock.call.cursor(),
+                mock.call.cursor().execute('INSERT INTO reports("date", "market_config_id", "debug") VALUES (%s, %s, %s) RETURNING id;', (datetime.datetime(2018, 3, 24), None, False)),
+                mock.call.cursor().fetchone(),
+                mock.call.cursor().execute('INSERT INTO report_lines("date", "report_id", "type", "payload") VALUES (%s, %s, %s, %s);', ('date1', 42, 'type1', 'payload1')),
+                mock.call.cursor().execute('INSERT INTO report_lines("date", "report_id", "type", "payload") VALUES (%s, %s, %s, %s);', ('date2', 42, 'type2', 'payload2')),
+                mock.call.commit(),
+                mock.call.cursor().close(),
+                mock.call.close()
+                ])
+
+        connect_mock.reset_mock()
+        with self.subTest(error=True),\
+                mock.patch('sys.stdout', new_callable=StringIO) as stdout_mock:
+            psycopg2.connect.side_effect = Exception("Bouh")
+            m.store_database_report(datetime.datetime(2018, 3, 24))
+            self.assertEqual(stdout_mock.getvalue(), "impossible to store report to database: Exception; Bouh\n")
+
+    def test_store_report(self):
+        m = market.Market(self.ccxt, self.market_args(), user_id=1)
+        with self.subTest(file=None, pg_config=None),\
+                mock.patch.object(m, "report") as report,\
+                mock.patch.object(m, "store_database_report") as db_report,\
+                mock.patch.object(m, "store_file_report") as file_report:
+            m.store_report()
+            report.merge.assert_called_with(store.Portfolio.report)
+
+            file_report.assert_not_called()
+            db_report.assert_not_called()
+
+        report.reset_mock()
+        m = market.Market(self.ccxt, self.market_args(), report_path="present", user_id=1)
+        with self.subTest(file="present", pg_config=None),\
+                mock.patch.object(m, "report") as report,\
+                mock.patch.object(m, "store_file_report") as file_report,\
+                mock.patch.object(m, "store_database_report") as db_report,\
+                mock.patch.object(market, "datetime") as time_mock:
+
+            time_mock.now.return_value = datetime.datetime(2018, 2, 25)
+
             m.store_report()
 
             report.merge.assert_called_with(store.Portfolio.report)
-            self.assertRegex(stdout_mock.getvalue(), "impossible to store report file: FileNotFoundError;")
+            file_report.assert_called_once_with(datetime.datetime(2018, 2, 25))
+            db_report.assert_not_called()
+
+        report.reset_mock()
+        m = market.Market(self.ccxt, self.market_args(), pg_config="present", user_id=1)
+        with self.subTest(file=None, pg_config="present"),\
+                mock.patch.object(m, "report") as report,\
+                mock.patch.object(m, "store_file_report") as file_report,\
+                mock.patch.object(m, "store_database_report") as db_report,\
+                mock.patch.object(market, "datetime") as time_mock:
+
+            time_mock.now.return_value = datetime.datetime(2018, 2, 25)
+
+            m.store_report()
+
+            report.merge.assert_called_with(store.Portfolio.report)
+            file_report.assert_not_called()
+            db_report.assert_called_once_with(datetime.datetime(2018, 2, 25))
+
+        report.reset_mock()
+        m = market.Market(self.ccxt, self.market_args(),
+                pg_config="pg_config", report_path="present", user_id=1)
+        with self.subTest(file="present", pg_config="present"),\
+                mock.patch.object(m, "report") as report,\
+                mock.patch.object(m, "store_file_report") as file_report,\
+                mock.patch.object(m, "store_database_report") as db_report,\
+                mock.patch.object(market, "datetime") as time_mock:
+
+            time_mock.now.return_value = datetime.datetime(2018, 2, 25)
+
+            m.store_report()
+
+            report.merge.assert_called_with(store.Portfolio.report)
+            file_report.assert_called_once_with(datetime.datetime(2018, 2, 25))
+            db_report.assert_called_once_with(datetime.datetime(2018, 2, 25))
 
     def test_print_orders(self):
         m = market.Market(self.ccxt, self.market_args())
@@ -3050,6 +3131,14 @@ class ReportStoreTest(WebMockTestCase):
             report_store.print_log(portfolio.Amount("BTC", 1))
             self.assertEqual(stdout_mock.getvalue(), "")
 
+    def test_default_json_serial(self):
+        report_store = market.ReportStore(self.m)
+
+        self.assertEqual("2018-02-24T00:00:00",
+                report_store.default_json_serial(portfolio.datetime(2018, 2, 24)))
+        self.assertEqual("1.00000000 BTC",
+                report_store.default_json_serial(portfolio.Amount("BTC", 1)))
+
     def test_to_json(self):
         report_store = market.ReportStore(self.m)
         report_store.logs.append({"foo": "bar"})
@@ -3059,6 +3148,20 @@ class ReportStoreTest(WebMockTestCase):
         report_store.logs.append({"amount": portfolio.Amount("BTC", 1)})
         self.assertEqual('[\n  {\n    "foo": "bar"\n  },\n  {\n    "date": "2018-02-24T00:00:00"\n  },\n  {\n    "amount": "1.00000000 BTC"\n  }\n]', report_store.to_json())
 
+    def test_to_json_array(self):
+        report_store = market.ReportStore(self.m)
+        report_store.logs.append({
+            "date": "date1", "type": "type1", "foo": "bar", "bla": "bla"
+            })
+        report_store.logs.append({
+            "date": "date2", "type": "type2", "foo": "bar", "bla": "bla"
+            })
+        logs = list(report_store.to_json_array())
+
+        self.assertEqual(2, len(logs))
+        self.assertEqual(("date1", "type1", '{\n  "foo": "bar",\n  "bla": "bla"\n}'), logs[0])
+        self.assertEqual(("date2", "type2", '{\n  "foo": "bar",\n  "bla": "bla"\n}'), logs[1])
+
     @mock.patch.object(market.ReportStore, "print_log")
     @mock.patch.object(market.ReportStore, "add_log")
     def test_log_stage(self, add_log, print_log):
@@ -3552,7 +3655,7 @@ class MainTest(WebMockTestCase):
                 mock.patch("main.parse_config") as main_parse_config:
             with self.subTest(debug=False):
                 main_parse_config.return_value = ["pg_config", "report_path"]
-                main_fetch_markets.return_value = [({"key": "market_config"},)]
+                main_fetch_markets.return_value = [(1, {"key": "market_config"}, 3)]
                 m = main.get_user_market("config_path.ini", 1)
 
                 self.assertIsInstance(m, market.Market)
@@ -3560,7 +3663,7 @@ class MainTest(WebMockTestCase):
 
             with self.subTest(debug=True):
                 main_parse_config.return_value = ["pg_config", "report_path"]
-                main_fetch_markets.return_value = [({"key": "market_config"},)]
+                main_fetch_markets.return_value = [(1, {"key": "market_config"}, 3)]
                 m = main.get_user_market("config_path.ini", 1, debug=True)
 
                 self.assertIsInstance(m, market.Market)
@@ -3579,16 +3682,16 @@ class MainTest(WebMockTestCase):
             args_mock.after = "after"
             self.assertEqual("", stdout_mock.getvalue())
 
-            main.process("config", 1, "report_path", args_mock)
+            main.process(3, "config", 1, "report_path", args_mock, "pg_config")
 
             market_mock.from_config.assert_has_calls([
-                mock.call("config", args_mock, user_id=1, report_path="report_path"),
+                mock.call("config", args_mock, pg_config="pg_config", market_id=3, user_id=1, report_path="report_path"),
                 mock.call().process("action", before="before", after="after"),
                 ])
 
             with self.subTest(exception=True):
                 market_mock.from_config.side_effect = Exception("boo")
-                main.process("config", 1, "report_path", args_mock)
+                main.process(3, "config", 1, "report_path", args_mock, "pg_config")
                 self.assertEqual("Exception: boo\n", stdout_mock.getvalue())
 
     def test_main(self):
@@ -3606,7 +3709,7 @@ class MainTest(WebMockTestCase):
 
                 parse_config.return_value = ["pg_config", "report_path"]
 
-                fetch_markets.return_value = [["config1", 1], ["config2", 2]]
+                fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]]
 
                 main.main(["Foo", "Bar"])
 
@@ -3616,8 +3719,8 @@ class MainTest(WebMockTestCase):
 
                 self.assertEqual(2, process.call_count)
                 process.assert_has_calls([
-                    mock.call("config1", 1, "report_path", args_mock),
-                    mock.call("config2", 2, "report_path", args_mock),
+                    mock.call(3, "config1", 1, "report_path", args_mock, "pg_config"),
+                    mock.call(1, "config2", 2, "report_path", args_mock, "pg_config"),
                     ])
         with self.subTest(parallel=True):
             with mock.patch("main.parse_args") as parse_args,\
@@ -3634,7 +3737,7 @@ class MainTest(WebMockTestCase):
 
                 parse_config.return_value = ["pg_config", "report_path"]
 
-                fetch_markets.return_value = [["config1", 1], ["config2", 2]]
+                fetch_markets.return_value = [[3, "config1", 1], [1, "config2", 2]]
 
                 main.main(["Foo", "Bar"])
 
@@ -3646,9 +3749,9 @@ class MainTest(WebMockTestCase):
                 self.assertEqual(2, process.call_count)
                 process.assert_has_calls([
                     mock.call.__bool__(),
-                    mock.call("config1", 1, "report_path", args_mock),
+                    mock.call(3, "config1", 1, "report_path", args_mock, "pg_config"),
                     mock.call.__bool__(),
-                    mock.call("config2", 2, "report_path", args_mock),
+                    mock.call(1, "config2", 2, "report_path", args_mock, "pg_config"),
                     ])
 
     @mock.patch.object(main.sys, "exit")
@@ -3734,7 +3837,7 @@ class MainTest(WebMockTestCase):
             rows = list(main.fetch_markets({"foo": "bar"}, None))
 
             psycopg2.connect.assert_called_once_with(foo="bar")
-            cursor_mock.execute.assert_called_once_with("SELECT config,user_id FROM market_configs")
+            cursor_mock.execute.assert_called_once_with("SELECT id,config,user_id FROM market_configs")
 
             self.assertEqual(["row_1", "row_2"], rows)
 
@@ -3744,7 +3847,7 @@ class MainTest(WebMockTestCase):
             rows = list(main.fetch_markets({"foo": "bar"}, 1))
 
             psycopg2.connect.assert_called_once_with(foo="bar")
-            cursor_mock.execute.assert_called_once_with("SELECT config,user_id FROM market_configs WHERE user_id = %s", 1)
+            cursor_mock.execute.assert_called_once_with("SELECT id,config,user_id FROM market_configs WHERE user_id = %s", 1)
 
             self.assertEqual(["row_1", "row_2"], rows)