Merge branch 'master' of git.hjp.at:hjp/procrusql
This commit is contained in:
commit
12dcf05eaf
|
@ -1,6 +1,6 @@
|
|||
[metadata]
|
||||
name = ProcruSQL
|
||||
version = 0.0.14
|
||||
version = 0.0.15
|
||||
author = Peter J. Holzer
|
||||
author_email = hjp@hjp.at
|
||||
description = Make a database fit its description
|
||||
|
|
|
@ -154,7 +154,7 @@ class HaveTable(Node):
|
|||
def __init__(self, name, depends, table, schema="public"):
|
||||
super().__init__(name, depends)
|
||||
self.table = table
|
||||
self.schema = "public"
|
||||
self.schema = schema
|
||||
|
||||
def check(self):
|
||||
log_check.info("Checking %s", self.name)
|
||||
|
@ -202,7 +202,7 @@ class HaveColumn(Node):
|
|||
self.table = table
|
||||
self.column = column
|
||||
self.definition = definition
|
||||
self.schema = "public"
|
||||
self.schema = schema
|
||||
|
||||
def check(self):
|
||||
log_check.info("Checking %s", self.name)
|
||||
|
@ -356,16 +356,22 @@ def main():
|
|||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--dbname")
|
||||
ap.add_argument("--dbuser")
|
||||
ap.add_argument("--schema", default="public")
|
||||
ap.add_argument("files", nargs="+")
|
||||
args = ap.parse_args()
|
||||
|
||||
db = psycopg2.connect(dbname=args.dbname, user=args.dbuser)
|
||||
csr = db.cursor()
|
||||
csr.execute("show search_path")
|
||||
search_path = csr.fetchone()[0]
|
||||
search_path = args.schema + ", " + search_path
|
||||
csr.execute(f"set search_path to {search_path}")
|
||||
|
||||
rules = []
|
||||
for f in args.files:
|
||||
with open(f) as rf:
|
||||
text = rf.read()
|
||||
ps = parser.ParseState(text)
|
||||
ps = parser.ParseState(text, schema=args.schema)
|
||||
|
||||
ps2 = parser.parse_ruleset(ps)
|
||||
|
||||
|
|
|
@ -21,13 +21,14 @@ class Failure:
|
|||
|
||||
class ParseState:
|
||||
|
||||
def __init__(self, text, position=0):
|
||||
def __init__(self, text, position=0, schema="public"):
|
||||
self.text = text
|
||||
self.position = position
|
||||
self.schema = schema # the default schema. probably doesn't belong into the parser state
|
||||
self.child_failure = None
|
||||
|
||||
def clone(self):
|
||||
ps = ParseState(self.text, self.position)
|
||||
ps = ParseState(self.text, self.position, self.schema)
|
||||
return ps
|
||||
|
||||
@property
|
||||
|
@ -42,7 +43,7 @@ class ParseState:
|
|||
linesbefore = self.text[:position].split("\n")
|
||||
linesafter = self.text[position:].split("\n")
|
||||
good = "\x1B[40;32m"
|
||||
bad = "\x1B[40;31m"
|
||||
bad = "\x1B[40;31;1m"
|
||||
reset = "\x1B[0m"
|
||||
s = reset + message + "\n"
|
||||
lines = []
|
||||
|
@ -99,8 +100,16 @@ def parse_table_rule(ps):
|
|||
if not ps3:
|
||||
ps.record_child_failure(ps2, "expected table name")
|
||||
return
|
||||
if len(ps3.ast) == 2:
|
||||
schema_name = ps3.ast[0]
|
||||
table_name = ps3.ast[1]
|
||||
elif len(ps3.ast) == 1:
|
||||
schema_name = ps3.schema
|
||||
table_name = ps3.ast[0]
|
||||
else:
|
||||
assert(False)
|
||||
|
||||
ps2.ast = procrusql.HaveTable(rulename(), [], ps3.ast[0])
|
||||
ps2.ast = procrusql.HaveTable(rulename(), [], ps3.ast[0], schema=schema_name)
|
||||
ps2.position = ps3.position
|
||||
return ps2
|
||||
|
||||
|
@ -120,7 +129,14 @@ def parse_column_rule(ps):
|
|||
if not ps3:
|
||||
ps.record_child_failure(ps2, "expected table name")
|
||||
return
|
||||
table_name = ps3.ast[0]
|
||||
if len(ps3.ast) == 2:
|
||||
schema_name = ps3.ast[0]
|
||||
table_name = ps3.ast[1]
|
||||
elif len(ps3.ast) == 1:
|
||||
schema_name = ps3.schema
|
||||
table_name = ps3.ast[0]
|
||||
else:
|
||||
assert(False)
|
||||
ps2.position = ps3.position
|
||||
|
||||
ps3 = parse_column_name(ps2)
|
||||
|
@ -137,7 +153,9 @@ def parse_column_rule(ps):
|
|||
column_definition = ps3.ast[0]
|
||||
ps2.position = ps3.position
|
||||
|
||||
ps2.ast = procrusql.HaveColumn(rulename(), [], table_name, column_name, column_definition)
|
||||
ps2.ast = procrusql.HaveColumn(
|
||||
rulename(), [],
|
||||
table_name, column_name, column_definition, schema=schema_name)
|
||||
|
||||
ps2.match_newlines()
|
||||
|
||||
|
@ -154,7 +172,14 @@ def parse_data_rule(ps):
|
|||
if not ps3:
|
||||
ps.record_child_failure(ps2, "expected table name")
|
||||
return
|
||||
table_name = ps3.ast[0]
|
||||
if len(ps3.ast) == 2:
|
||||
schema_name = ps3.ast[0]
|
||||
table_name = ps3.ast[1]
|
||||
elif len(ps3.ast) == 1:
|
||||
schema_name = ps3.schema
|
||||
table_name = ps3.ast[0]
|
||||
else:
|
||||
assert(False)
|
||||
ps2.position = ps3.position
|
||||
|
||||
if ps3 := parse_dict(ps2):
|
||||
|
@ -175,7 +200,9 @@ def parse_data_rule(ps):
|
|||
else:
|
||||
label = rulename()
|
||||
|
||||
ps2.ast = procrusql.HaveData(label, [], table_name, key_data, extra_data)
|
||||
ps2.ast = procrusql.HaveData(
|
||||
label, [],
|
||||
table_name, key_data, extra_data, schema=schema_name)
|
||||
|
||||
ps2.match_newlines()
|
||||
elif ps3 := parse_init_query(ps2):
|
||||
|
@ -216,7 +243,14 @@ def parse_index_rule(ps):
|
|||
if not ps3:
|
||||
ps.record_child_failure(ps2, "expected table name")
|
||||
return
|
||||
table_name = ps3.ast[0]
|
||||
if len(ps3.ast) == 2:
|
||||
schema_name = ps3.ast[0]
|
||||
table_name = ps3.ast[1]
|
||||
elif len(ps3.ast) == 1:
|
||||
schema_name = ps3.schema
|
||||
table_name = ps3.ast[0]
|
||||
else:
|
||||
assert(False)
|
||||
ps2.position = ps3.position
|
||||
|
||||
m = ps2.match(r"\s*(using\b|\([\w, ]+\))[^>\n]*")
|
||||
|
@ -231,7 +265,10 @@ def parse_index_rule(ps):
|
|||
else:
|
||||
label = rulename()
|
||||
|
||||
ps2.ast = procrusql.HaveIndex(label, [], table_name, index_name, index_type, index_definition)
|
||||
ps2.ast = procrusql.HaveIndex(
|
||||
label, [],
|
||||
table_name, index_name,
|
||||
index_type, index_definition, schema=schema_name)
|
||||
|
||||
ps2.match_newlines()
|
||||
|
||||
|
@ -262,6 +299,12 @@ def parse_table_name(ps):
|
|||
if ps2.rest[0].isalpha():
|
||||
m = ps2.match(r"\w+") # always succeeds since we already checked the first character
|
||||
ps2.ast.append(m.group(0))
|
||||
if ps2.rest[0] == ".":
|
||||
m = ps2.match(r"\w+")
|
||||
if not m:
|
||||
ps.record_child_failure(ps2, "expected table name after schema")
|
||||
return
|
||||
ps2.ast.append(m.group(0))
|
||||
else:
|
||||
ps.record_child_failure(ps2, "expected table name")
|
||||
return ps2
|
||||
|
@ -312,6 +355,7 @@ def parse_column_definition(ps):
|
|||
"json", "jsonb",
|
||||
"uuid",
|
||||
r"integer\[\]", r"int\[\]", r"bigint\[\]",
|
||||
"bytea",
|
||||
),
|
||||
key=lambda x: -len(x) # longest match first
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue