Merge branch 'master' of git.hjp.at:hjp/procrusql

This commit is contained in:
Peter J. Holzer 2024-12-12 21:51:30 +01:00 committed by Peter J. Holzer
commit 12dcf05eaf
3 changed files with 64 additions and 14 deletions

View File

@ -1,6 +1,6 @@
[metadata] [metadata]
name = ProcruSQL name = ProcruSQL
version = 0.0.14 version = 0.0.15
author = Peter J. Holzer author = Peter J. Holzer
author_email = hjp@hjp.at author_email = hjp@hjp.at
description = Make a database fit its description description = Make a database fit its description

View File

@ -154,7 +154,7 @@ class HaveTable(Node):
def __init__(self, name, depends, table, schema="public"): def __init__(self, name, depends, table, schema="public"):
super().__init__(name, depends) super().__init__(name, depends)
self.table = table self.table = table
self.schema = "public" self.schema = schema
def check(self): def check(self):
log_check.info("Checking %s", self.name) log_check.info("Checking %s", self.name)
@ -202,7 +202,7 @@ class HaveColumn(Node):
self.table = table self.table = table
self.column = column self.column = column
self.definition = definition self.definition = definition
self.schema = "public" self.schema = schema
def check(self): def check(self):
log_check.info("Checking %s", self.name) log_check.info("Checking %s", self.name)
@ -356,16 +356,22 @@ def main():
ap = argparse.ArgumentParser() ap = argparse.ArgumentParser()
ap.add_argument("--dbname") ap.add_argument("--dbname")
ap.add_argument("--dbuser") ap.add_argument("--dbuser")
ap.add_argument("--schema", default="public")
ap.add_argument("files", nargs="+") ap.add_argument("files", nargs="+")
args = ap.parse_args() args = ap.parse_args()
db = psycopg2.connect(dbname=args.dbname, user=args.dbuser) db = psycopg2.connect(dbname=args.dbname, user=args.dbuser)
csr = db.cursor()
csr.execute("show search_path")
search_path = csr.fetchone()[0]
search_path = args.schema + ", " + search_path
csr.execute(f"set search_path to {search_path}")
rules = [] rules = []
for f in args.files: for f in args.files:
with open(f) as rf: with open(f) as rf:
text = rf.read() text = rf.read()
ps = parser.ParseState(text) ps = parser.ParseState(text, schema=args.schema)
ps2 = parser.parse_ruleset(ps) ps2 = parser.parse_ruleset(ps)

View File

@ -21,13 +21,14 @@ class Failure:
class ParseState: class ParseState:
def __init__(self, text, position=0): def __init__(self, text, position=0, schema="public"):
self.text = text self.text = text
self.position = position self.position = position
self.schema = schema # the default schema. probably doesn't belong into the parser state
self.child_failure = None self.child_failure = None
def clone(self): def clone(self):
ps = ParseState(self.text, self.position) ps = ParseState(self.text, self.position, self.schema)
return ps return ps
@property @property
@ -42,7 +43,7 @@ class ParseState:
linesbefore = self.text[:position].split("\n") linesbefore = self.text[:position].split("\n")
linesafter = self.text[position:].split("\n") linesafter = self.text[position:].split("\n")
good = "\x1B[40;32m" good = "\x1B[40;32m"
bad = "\x1B[40;31m" bad = "\x1B[40;31;1m"
reset = "\x1B[0m" reset = "\x1B[0m"
s = reset + message + "\n" s = reset + message + "\n"
lines = [] lines = []
@ -99,8 +100,16 @@ def parse_table_rule(ps):
if not ps3: if not ps3:
ps.record_child_failure(ps2, "expected table name") ps.record_child_failure(ps2, "expected table name")
return return
if len(ps3.ast) == 2:
schema_name = ps3.ast[0]
table_name = ps3.ast[1]
elif len(ps3.ast) == 1:
schema_name = ps3.schema
table_name = ps3.ast[0]
else:
assert(False)
ps2.ast = procrusql.HaveTable(rulename(), [], ps3.ast[0]) ps2.ast = procrusql.HaveTable(rulename(), [], ps3.ast[0], schema=schema_name)
ps2.position = ps3.position ps2.position = ps3.position
return ps2 return ps2
@ -120,7 +129,14 @@ def parse_column_rule(ps):
if not ps3: if not ps3:
ps.record_child_failure(ps2, "expected table name") ps.record_child_failure(ps2, "expected table name")
return return
if len(ps3.ast) == 2:
schema_name = ps3.ast[0]
table_name = ps3.ast[1]
elif len(ps3.ast) == 1:
schema_name = ps3.schema
table_name = ps3.ast[0] table_name = ps3.ast[0]
else:
assert(False)
ps2.position = ps3.position ps2.position = ps3.position
ps3 = parse_column_name(ps2) ps3 = parse_column_name(ps2)
@ -137,7 +153,9 @@ def parse_column_rule(ps):
column_definition = ps3.ast[0] column_definition = ps3.ast[0]
ps2.position = ps3.position ps2.position = ps3.position
ps2.ast = procrusql.HaveColumn(rulename(), [], table_name, column_name, column_definition) ps2.ast = procrusql.HaveColumn(
rulename(), [],
table_name, column_name, column_definition, schema=schema_name)
ps2.match_newlines() ps2.match_newlines()
@ -154,7 +172,14 @@ def parse_data_rule(ps):
if not ps3: if not ps3:
ps.record_child_failure(ps2, "expected table name") ps.record_child_failure(ps2, "expected table name")
return return
if len(ps3.ast) == 2:
schema_name = ps3.ast[0]
table_name = ps3.ast[1]
elif len(ps3.ast) == 1:
schema_name = ps3.schema
table_name = ps3.ast[0] table_name = ps3.ast[0]
else:
assert(False)
ps2.position = ps3.position ps2.position = ps3.position
if ps3 := parse_dict(ps2): if ps3 := parse_dict(ps2):
@ -175,7 +200,9 @@ def parse_data_rule(ps):
else: else:
label = rulename() label = rulename()
ps2.ast = procrusql.HaveData(label, [], table_name, key_data, extra_data) ps2.ast = procrusql.HaveData(
label, [],
table_name, key_data, extra_data, schema=schema_name)
ps2.match_newlines() ps2.match_newlines()
elif ps3 := parse_init_query(ps2): elif ps3 := parse_init_query(ps2):
@ -216,7 +243,14 @@ def parse_index_rule(ps):
if not ps3: if not ps3:
ps.record_child_failure(ps2, "expected table name") ps.record_child_failure(ps2, "expected table name")
return return
if len(ps3.ast) == 2:
schema_name = ps3.ast[0]
table_name = ps3.ast[1]
elif len(ps3.ast) == 1:
schema_name = ps3.schema
table_name = ps3.ast[0] table_name = ps3.ast[0]
else:
assert(False)
ps2.position = ps3.position ps2.position = ps3.position
m = ps2.match(r"\s*(using\b|\([\w, ]+\))[^>\n]*") m = ps2.match(r"\s*(using\b|\([\w, ]+\))[^>\n]*")
@ -231,7 +265,10 @@ def parse_index_rule(ps):
else: else:
label = rulename() label = rulename()
ps2.ast = procrusql.HaveIndex(label, [], table_name, index_name, index_type, index_definition) ps2.ast = procrusql.HaveIndex(
label, [],
table_name, index_name,
index_type, index_definition, schema=schema_name)
ps2.match_newlines() ps2.match_newlines()
@ -262,6 +299,12 @@ def parse_table_name(ps):
if ps2.rest[0].isalpha(): if ps2.rest[0].isalpha():
m = ps2.match(r"\w+") # always succeeds since we already checked the first character m = ps2.match(r"\w+") # always succeeds since we already checked the first character
ps2.ast.append(m.group(0)) ps2.ast.append(m.group(0))
if ps2.rest[0] == ".":
m = ps2.match(r"\w+")
if not m:
ps.record_child_failure(ps2, "expected table name after schema")
return
ps2.ast.append(m.group(0))
else: else:
ps.record_child_failure(ps2, "expected table name") ps.record_child_failure(ps2, "expected table name")
return ps2 return ps2
@ -312,6 +355,7 @@ def parse_column_definition(ps):
"json", "jsonb", "json", "jsonb",
"uuid", "uuid",
r"integer\[\]", r"int\[\]", r"bigint\[\]", r"integer\[\]", r"int\[\]", r"bigint\[\]",
"bytea",
), ),
key=lambda x: -len(x) # longest match first key=lambda x: -len(x) # longest match first
) )