From 7970835c8b8af9a3d13dec3b193cfce90519ac83 Mon Sep 17 00:00:00 2001 From: "Peter J. Holzer" Date: Thu, 12 Dec 2024 21:46:14 +0100 Subject: [PATCH 1/3] Add schema support --- setup.cfg | 2 +- src/procrusql/__init__.py | 24 ++++++++++----- src/procrusql/parser.py | 61 +++++++++++++++++++++++++++++++++------ 3 files changed, 70 insertions(+), 17 deletions(-) diff --git a/setup.cfg b/setup.cfg index 19d9bcb..9b4797a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = ProcruSQL -version = 0.0.14 +version = 0.0.15 author = Peter J. Holzer author_email = hjp@hjp.at description = Make a database fit its description diff --git a/src/procrusql/__init__.py b/src/procrusql/__init__.py index a0994f7..d4a83ef 100644 --- a/src/procrusql/__init__.py +++ b/src/procrusql/__init__.py @@ -48,11 +48,12 @@ class Node: self.order = order class HaveData(Node): - def __init__(self, name, depends, table, key, extra): + def __init__(self, name, depends, table, key, extra, schema="public"): super().__init__(name, depends) self.table = table self.key = key self.extra = extra + self.schema = schema def check(self): log_check.info("Checking %s", self.name) @@ -63,8 +64,9 @@ class HaveData(Node): ] key_check = sql.SQL(" and ").join(key_checks) q = sql.SQL( - "select * from {table} where {key_check}" + "select * from {schema}.{table} where {key_check}" ).format( + schema=sql.Identifier(self.schema), table=sql.Identifier(self.table), key_check=key_check ) @@ -78,8 +80,9 @@ class HaveData(Node): if self.result[0][c] != self.extra[c]: log_action.info("Updating %s: %s <- %s", key_values, c, self.extra[c]) q = sql.SQL( - "update {table} set {column}={placeholder} where {key_check}" + "update {schema}.{table} set {column}={placeholder} where {key_check}" ).format( + schema=sql.Identifier(self.schema), table=sql.Identifier(self.table), column=sql.Identifier(c), placeholder=sql.Placeholder(), @@ -96,8 +99,9 @@ class HaveData(Node): columns = list(self.key.keys()) + list(self.extra.keys()) values = key_values + extra_values q = sql.SQL( - "insert into {table}({columns}) values({placeholders}) returning *" + "insert into {schema}.{table}({columns}) values({placeholders}) returning *" ).format( + schema=sql.Identifier(self.schema), table=sql.Identifier(self.table), columns=sql.SQL(", ").join([sql.Identifier(x) for x in columns]), placeholders=sql.SQL(", ").join([sql.Placeholder() for x in columns]), @@ -150,7 +154,7 @@ class HaveTable(Node): def __init__(self, name, depends, table, schema="public"): super().__init__(name, depends) self.table = table - self.schema = "public" + self.schema = schema def check(self): log_check.info("Checking %s", self.name) @@ -198,7 +202,7 @@ class HaveColumn(Node): self.table = table self.column = column self.definition = definition - self.schema = "public" + self.schema = schema def check(self): log_check.info("Checking %s", self.name) @@ -352,16 +356,22 @@ def main(): ap = argparse.ArgumentParser() ap.add_argument("--dbname") ap.add_argument("--dbuser") + ap.add_argument("--schema", default="public") ap.add_argument("files", nargs="+") args = ap.parse_args() db = psycopg2.connect(dbname=args.dbname, user=args.dbuser) + csr = db.cursor() + csr.execute("show search_path") + search_path = csr.fetchone()[0] + search_path = args.schema + ", " + search_path + csr.execute(f"set search_path to {search_path}") rules = [] for f in args.files: with open(f) as rf: text = rf.read() - ps = parser.ParseState(text) + ps = parser.ParseState(text, schema=args.schema) ps2 = parser.parse_ruleset(ps) diff --git a/src/procrusql/parser.py b/src/procrusql/parser.py index 7e2c50e..d15e42b 100755 --- a/src/procrusql/parser.py +++ b/src/procrusql/parser.py @@ -21,13 +21,14 @@ class Failure: class ParseState: - def __init__(self, text, position=0): + def __init__(self, text, position=0, schema="public"): self.text = text self.position = position + self.schema = schema # the default schema. probably doesn't belong into the parser state self.child_failure = None def clone(self): - ps = ParseState(self.text, self.position) + ps = ParseState(self.text, self.position, self.schema) return ps @property @@ -98,8 +99,16 @@ def parse_table_rule(ps): if not ps3: ps.record_child_failure(ps2, "expected table name") return + if len(ps3.ast) == 2: + schema_name = ps3.ast[0] + table_name = ps3.ast[1] + elif len(ps3.ast) == 1: + schema_name = ps3.schema + table_name = ps3.ast[0] + else: + assert(False) - ps2.ast = procrusql.HaveTable(rulename(), [], ps3.ast[0]) + ps2.ast = procrusql.HaveTable(rulename(), [], ps3.ast[0], schema=schema_name) ps2.position = ps3.position return ps2 @@ -119,7 +128,14 @@ def parse_column_rule(ps): if not ps3: ps.record_child_failure(ps2, "expected table name") return - table_name = ps3.ast[0] + if len(ps3.ast) == 2: + schema_name = ps3.ast[0] + table_name = ps3.ast[1] + elif len(ps3.ast) == 1: + schema_name = ps3.schema + table_name = ps3.ast[0] + else: + assert(False) ps2.position = ps3.position ps3 = parse_column_name(ps2) @@ -136,7 +152,9 @@ def parse_column_rule(ps): column_definition = ps3.ast[0] ps2.position = ps3.position - ps2.ast = procrusql.HaveColumn(rulename(), [], table_name, column_name, column_definition) + ps2.ast = procrusql.HaveColumn( + rulename(), [], + table_name, column_name, column_definition, schema=schema_name) ps2.match_newlines() @@ -153,7 +171,14 @@ def parse_data_rule(ps): if not ps3: ps.record_child_failure(ps2, "expected table name") return - table_name = ps3.ast[0] + if len(ps3.ast) == 2: + schema_name = ps3.ast[0] + table_name = ps3.ast[1] + elif len(ps3.ast) == 1: + schema_name = ps3.schema + table_name = ps3.ast[0] + else: + assert(False) ps2.position = ps3.position if ps3 := parse_dict(ps2): @@ -174,7 +199,9 @@ def parse_data_rule(ps): else: label = rulename() - ps2.ast = procrusql.HaveData(label, [], table_name, key_data, extra_data) + ps2.ast = procrusql.HaveData( + label, [], + table_name, key_data, extra_data, schema=schema_name) ps2.match_newlines() elif ps3 := parse_init_query(ps2): @@ -215,7 +242,14 @@ def parse_index_rule(ps): if not ps3: ps.record_child_failure(ps2, "expected table name") return - table_name = ps3.ast[0] + if len(ps3.ast) == 2: + schema_name = ps3.ast[0] + table_name = ps3.ast[1] + elif len(ps3.ast) == 1: + schema_name = ps3.schema + table_name = ps3.ast[0] + else: + assert(False) ps2.position = ps3.position m = ps2.match(r"\s*(using\b|\([\w, ]+\))[^>\n]*") @@ -230,7 +264,10 @@ def parse_index_rule(ps): else: label = rulename() - ps2.ast = procrusql.HaveIndex(label, [], table_name, index_name, index_type, index_definition) + ps2.ast = procrusql.HaveIndex( + label, [], + table_name, index_name, + index_type, index_definition, schema=schema_name) ps2.match_newlines() @@ -247,6 +284,12 @@ def parse_table_name(ps): if ps2.rest[0].isalpha(): m = ps2.match(r"\w+") # always succeeds since we already checked the first character ps2.ast.append(m.group(0)) + if ps2.rest[0] == ".": + m = ps2.match(r"\w+") + if not m: + ps.record_child_failure(ps2, "expected table name after schema") + return + ps2.ast.append(m.group(0)) else: ps.record_child_failure(ps2, "expected table name") return ps2 From 363cdcede12391128e30d19316a9655f17468281 Mon Sep 17 00:00:00 2001 From: "Peter J. Holzer" Date: Thu, 12 Dec 2024 21:48:06 +0100 Subject: [PATCH 2/3] Make error messages more readable --- src/procrusql/parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/procrusql/parser.py b/src/procrusql/parser.py index d15e42b..e80b9f8 100755 --- a/src/procrusql/parser.py +++ b/src/procrusql/parser.py @@ -43,7 +43,7 @@ class ParseState: linesbefore = self.text[:position].split("\n") linesafter = self.text[position:].split("\n") good = "\x1B[40;32m" - bad = "\x1B[40;31m" + bad = "\x1B[40;31;1m" reset = "\x1B[0m" s = reset + message + "\n" lines = [] From 8f93c2e2da7c22f84bb6f0d3511cef535ea06be6 Mon Sep 17 00:00:00 2001 From: "Peter J. Holzer" Date: Thu, 12 Dec 2024 21:48:31 +0100 Subject: [PATCH 3/3] Add type bytea --- src/procrusql/parser.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/procrusql/parser.py b/src/procrusql/parser.py index e80b9f8..e5e1265 100755 --- a/src/procrusql/parser.py +++ b/src/procrusql/parser.py @@ -340,6 +340,7 @@ def parse_column_definition(ps): "json", "jsonb", "uuid", r"integer\[\]", r"int\[\]", r"bigint\[\]", + "bytea", ), key=lambda x: -len(x) # longest match first )