From 2cd66bdf50f7869403afc9df79fe716635a74df3 Mon Sep 17 00:00:00 2001 From: "Peter J. Holzer" Date: Tue, 2 Aug 2022 14:49:07 +0200 Subject: [PATCH] Initialize table from SQL query --- setup.cfg | 2 +- src/procrusql/__init__.py | 32 ++++++++++++++++++++ src/procrusql/parser.py | 61 +++++++++++++++++++++++++-------------- 3 files changed, 73 insertions(+), 22 deletions(-) diff --git a/setup.cfg b/setup.cfg index a622587..e776854 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = ProcruSQL -version = 0.0.6 +version = 0.0.7 author = Peter J. Holzer author_email = hjp@hjp.at description = Make a database fit its description diff --git a/src/procrusql/__init__.py b/src/procrusql/__init__.py index de1ccee..690eee0 100644 --- a/src/procrusql/__init__.py +++ b/src/procrusql/__init__.py @@ -115,6 +115,36 @@ class HaveData(Node): # exception. Success with 0 rows should not happen. raise RuntimeError("Unreachable code reached") +class HaveInit(Node): + def __init__(self, name, depends, table, query): + super().__init__(name, depends) + self.table = table + self.query = query + + def check(self): + log_check.info("Checking %s", self.name) + + for w in want: + if isinstance(w, HaveTable) and w.table == self.table: + if not w.ok: + log_check.error("%s not yet ok", w.name) + raise RuntimeError(f"Cannot insert into table {w.table} from {w.name} which is not yet ok. Please add a dependency") + if w.new: + log_action.info("Executing %s", self.query) + csr = db.cursor(cursor_factory=extras.DictCursor) + csr.execute(self.query) + self.result = csr.fetchall() + log_action.info("Got %d rows", len(self.result)) + self.set_order() + else: + log_check.info("Table %s already exists", w.table) + self.ok = True + log_state.info("%s is now ok", self.name) + break + else: + raise RuntimeError(f"Cannot find a rule which creates table {self.table}") + + class HaveTable(Node): def __init__(self, name, depends, table, schema="public"): @@ -136,6 +166,7 @@ class HaveTable(Node): # Table exists, all ok self.set_order() self.ok = True + self.new = False log_state.info("%s is now ok", self.name) return if len(r) > 1: @@ -151,6 +182,7 @@ class HaveTable(Node): csr.execute(q) self.set_order() self.ok = True + self.new = True log_state.info("%s is now ok", self.name) pass diff --git a/src/procrusql/parser.py b/src/procrusql/parser.py index 5decaff..87a59f2 100755 --- a/src/procrusql/parser.py +++ b/src/procrusql/parser.py @@ -156,30 +156,39 @@ def parse_data_rule(ps): table_name = ps3.ast[0] ps2.position = ps3.position - ps3 = parse_dict(ps2) - if not ps3: - ps.record_child_failure(ps2, "expected key data definition") - return - key_data = ps3.ast - ps2.position = ps3.position - - ps3 = parse_dict(ps2) - if not ps3: - ps.record_child_failure(ps2, "expected extra data definition") - return - extra_data = ps3.ast - ps2.position = ps3.position - - ps3 = parse_label(ps2) - if ps3: - label = ps3.ast + if ps3 := parse_dict(ps2): + key_data = ps3.ast ps2.position = ps3.position - else: + + ps3 = parse_dict(ps2) + if not ps3: + ps.record_child_failure(ps2, "expected extra data definition") + return + extra_data = ps3.ast + ps2.position = ps3.position + + ps3 = parse_label(ps2) + if ps3: + label = ps3.ast + ps2.position = ps3.position + else: + label = rulename() + + ps2.ast = procrusql.HaveData(label, [], table_name, key_data, extra_data) + + ps2.match_newlines() + elif ps3 := parse_init_query(ps2): + # We have a bit of a problem here: The query extends to the end of the + # line so there is no room for a label. I really don't want to parse + # SQL here. label = rulename() + ps2.position = ps3.position + ps2.ast = procrusql.HaveInit(label, [], table_name, ps3.ast) - ps2.ast = procrusql.HaveData(label, [], table_name, key_data, extra_data) - - ps2.match_newlines() + ps2.match_newlines() + else: + ps.record_child_failure(ps2, "expected key data definition or insert query") + return return ps2 @@ -347,6 +356,16 @@ def parse_dict(ps): ps2.ast = d return ps2 +def parse_init_query(ps): + ps2 = ps.clone() + ps2.skip_whitespace_and_comments() + if m := ps2.match(r'(?i:with|insert)\b.*'): + ps2.ast = m.group(0) + return ps2 + else: + ps.record_child_failure(ps2, "expected insert query") + return + def parse_label(ps): ps2 = ps.clone() if m := ps2.match(r"\s*>>\s*(\w+)"):