Add schema support
This commit is contained in:
parent
f6c64a50ab
commit
7970835c8b
|
@ -1,6 +1,6 @@
|
||||||
[metadata]
|
[metadata]
|
||||||
name = ProcruSQL
|
name = ProcruSQL
|
||||||
version = 0.0.14
|
version = 0.0.15
|
||||||
author = Peter J. Holzer
|
author = Peter J. Holzer
|
||||||
author_email = hjp@hjp.at
|
author_email = hjp@hjp.at
|
||||||
description = Make a database fit its description
|
description = Make a database fit its description
|
||||||
|
|
|
@ -48,11 +48,12 @@ class Node:
|
||||||
self.order = order
|
self.order = order
|
||||||
|
|
||||||
class HaveData(Node):
|
class HaveData(Node):
|
||||||
def __init__(self, name, depends, table, key, extra):
|
def __init__(self, name, depends, table, key, extra, schema="public"):
|
||||||
super().__init__(name, depends)
|
super().__init__(name, depends)
|
||||||
self.table = table
|
self.table = table
|
||||||
self.key = key
|
self.key = key
|
||||||
self.extra = extra
|
self.extra = extra
|
||||||
|
self.schema = schema
|
||||||
|
|
||||||
def check(self):
|
def check(self):
|
||||||
log_check.info("Checking %s", self.name)
|
log_check.info("Checking %s", self.name)
|
||||||
|
@ -63,8 +64,9 @@ class HaveData(Node):
|
||||||
]
|
]
|
||||||
key_check = sql.SQL(" and ").join(key_checks)
|
key_check = sql.SQL(" and ").join(key_checks)
|
||||||
q = sql.SQL(
|
q = sql.SQL(
|
||||||
"select * from {table} where {key_check}"
|
"select * from {schema}.{table} where {key_check}"
|
||||||
).format(
|
).format(
|
||||||
|
schema=sql.Identifier(self.schema),
|
||||||
table=sql.Identifier(self.table),
|
table=sql.Identifier(self.table),
|
||||||
key_check=key_check
|
key_check=key_check
|
||||||
)
|
)
|
||||||
|
@ -78,8 +80,9 @@ class HaveData(Node):
|
||||||
if self.result[0][c] != self.extra[c]:
|
if self.result[0][c] != self.extra[c]:
|
||||||
log_action.info("Updating %s: %s <- %s", key_values, c, self.extra[c])
|
log_action.info("Updating %s: %s <- %s", key_values, c, self.extra[c])
|
||||||
q = sql.SQL(
|
q = sql.SQL(
|
||||||
"update {table} set {column}={placeholder} where {key_check}"
|
"update {schema}.{table} set {column}={placeholder} where {key_check}"
|
||||||
).format(
|
).format(
|
||||||
|
schema=sql.Identifier(self.schema),
|
||||||
table=sql.Identifier(self.table),
|
table=sql.Identifier(self.table),
|
||||||
column=sql.Identifier(c),
|
column=sql.Identifier(c),
|
||||||
placeholder=sql.Placeholder(),
|
placeholder=sql.Placeholder(),
|
||||||
|
@ -96,8 +99,9 @@ class HaveData(Node):
|
||||||
columns = list(self.key.keys()) + list(self.extra.keys())
|
columns = list(self.key.keys()) + list(self.extra.keys())
|
||||||
values = key_values + extra_values
|
values = key_values + extra_values
|
||||||
q = sql.SQL(
|
q = sql.SQL(
|
||||||
"insert into {table}({columns}) values({placeholders}) returning *"
|
"insert into {schema}.{table}({columns}) values({placeholders}) returning *"
|
||||||
).format(
|
).format(
|
||||||
|
schema=sql.Identifier(self.schema),
|
||||||
table=sql.Identifier(self.table),
|
table=sql.Identifier(self.table),
|
||||||
columns=sql.SQL(", ").join([sql.Identifier(x) for x in columns]),
|
columns=sql.SQL(", ").join([sql.Identifier(x) for x in columns]),
|
||||||
placeholders=sql.SQL(", ").join([sql.Placeholder() for x in columns]),
|
placeholders=sql.SQL(", ").join([sql.Placeholder() for x in columns]),
|
||||||
|
@ -150,7 +154,7 @@ class HaveTable(Node):
|
||||||
def __init__(self, name, depends, table, schema="public"):
|
def __init__(self, name, depends, table, schema="public"):
|
||||||
super().__init__(name, depends)
|
super().__init__(name, depends)
|
||||||
self.table = table
|
self.table = table
|
||||||
self.schema = "public"
|
self.schema = schema
|
||||||
|
|
||||||
def check(self):
|
def check(self):
|
||||||
log_check.info("Checking %s", self.name)
|
log_check.info("Checking %s", self.name)
|
||||||
|
@ -198,7 +202,7 @@ class HaveColumn(Node):
|
||||||
self.table = table
|
self.table = table
|
||||||
self.column = column
|
self.column = column
|
||||||
self.definition = definition
|
self.definition = definition
|
||||||
self.schema = "public"
|
self.schema = schema
|
||||||
|
|
||||||
def check(self):
|
def check(self):
|
||||||
log_check.info("Checking %s", self.name)
|
log_check.info("Checking %s", self.name)
|
||||||
|
@ -352,16 +356,22 @@ def main():
|
||||||
ap = argparse.ArgumentParser()
|
ap = argparse.ArgumentParser()
|
||||||
ap.add_argument("--dbname")
|
ap.add_argument("--dbname")
|
||||||
ap.add_argument("--dbuser")
|
ap.add_argument("--dbuser")
|
||||||
|
ap.add_argument("--schema", default="public")
|
||||||
ap.add_argument("files", nargs="+")
|
ap.add_argument("files", nargs="+")
|
||||||
args = ap.parse_args()
|
args = ap.parse_args()
|
||||||
|
|
||||||
db = psycopg2.connect(dbname=args.dbname, user=args.dbuser)
|
db = psycopg2.connect(dbname=args.dbname, user=args.dbuser)
|
||||||
|
csr = db.cursor()
|
||||||
|
csr.execute("show search_path")
|
||||||
|
search_path = csr.fetchone()[0]
|
||||||
|
search_path = args.schema + ", " + search_path
|
||||||
|
csr.execute(f"set search_path to {search_path}")
|
||||||
|
|
||||||
rules = []
|
rules = []
|
||||||
for f in args.files:
|
for f in args.files:
|
||||||
with open(f) as rf:
|
with open(f) as rf:
|
||||||
text = rf.read()
|
text = rf.read()
|
||||||
ps = parser.ParseState(text)
|
ps = parser.ParseState(text, schema=args.schema)
|
||||||
|
|
||||||
ps2 = parser.parse_ruleset(ps)
|
ps2 = parser.parse_ruleset(ps)
|
||||||
|
|
||||||
|
|
|
@ -21,13 +21,14 @@ class Failure:
|
||||||
|
|
||||||
class ParseState:
|
class ParseState:
|
||||||
|
|
||||||
def __init__(self, text, position=0):
|
def __init__(self, text, position=0, schema="public"):
|
||||||
self.text = text
|
self.text = text
|
||||||
self.position = position
|
self.position = position
|
||||||
|
self.schema = schema # the default schema. probably doesn't belong into the parser state
|
||||||
self.child_failure = None
|
self.child_failure = None
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
ps = ParseState(self.text, self.position)
|
ps = ParseState(self.text, self.position, self.schema)
|
||||||
return ps
|
return ps
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -98,8 +99,16 @@ def parse_table_rule(ps):
|
||||||
if not ps3:
|
if not ps3:
|
||||||
ps.record_child_failure(ps2, "expected table name")
|
ps.record_child_failure(ps2, "expected table name")
|
||||||
return
|
return
|
||||||
|
if len(ps3.ast) == 2:
|
||||||
|
schema_name = ps3.ast[0]
|
||||||
|
table_name = ps3.ast[1]
|
||||||
|
elif len(ps3.ast) == 1:
|
||||||
|
schema_name = ps3.schema
|
||||||
|
table_name = ps3.ast[0]
|
||||||
|
else:
|
||||||
|
assert(False)
|
||||||
|
|
||||||
ps2.ast = procrusql.HaveTable(rulename(), [], ps3.ast[0])
|
ps2.ast = procrusql.HaveTable(rulename(), [], ps3.ast[0], schema=schema_name)
|
||||||
ps2.position = ps3.position
|
ps2.position = ps3.position
|
||||||
return ps2
|
return ps2
|
||||||
|
|
||||||
|
@ -119,7 +128,14 @@ def parse_column_rule(ps):
|
||||||
if not ps3:
|
if not ps3:
|
||||||
ps.record_child_failure(ps2, "expected table name")
|
ps.record_child_failure(ps2, "expected table name")
|
||||||
return
|
return
|
||||||
table_name = ps3.ast[0]
|
if len(ps3.ast) == 2:
|
||||||
|
schema_name = ps3.ast[0]
|
||||||
|
table_name = ps3.ast[1]
|
||||||
|
elif len(ps3.ast) == 1:
|
||||||
|
schema_name = ps3.schema
|
||||||
|
table_name = ps3.ast[0]
|
||||||
|
else:
|
||||||
|
assert(False)
|
||||||
ps2.position = ps3.position
|
ps2.position = ps3.position
|
||||||
|
|
||||||
ps3 = parse_column_name(ps2)
|
ps3 = parse_column_name(ps2)
|
||||||
|
@ -136,7 +152,9 @@ def parse_column_rule(ps):
|
||||||
column_definition = ps3.ast[0]
|
column_definition = ps3.ast[0]
|
||||||
ps2.position = ps3.position
|
ps2.position = ps3.position
|
||||||
|
|
||||||
ps2.ast = procrusql.HaveColumn(rulename(), [], table_name, column_name, column_definition)
|
ps2.ast = procrusql.HaveColumn(
|
||||||
|
rulename(), [],
|
||||||
|
table_name, column_name, column_definition, schema=schema_name)
|
||||||
|
|
||||||
ps2.match_newlines()
|
ps2.match_newlines()
|
||||||
|
|
||||||
|
@ -153,7 +171,14 @@ def parse_data_rule(ps):
|
||||||
if not ps3:
|
if not ps3:
|
||||||
ps.record_child_failure(ps2, "expected table name")
|
ps.record_child_failure(ps2, "expected table name")
|
||||||
return
|
return
|
||||||
table_name = ps3.ast[0]
|
if len(ps3.ast) == 2:
|
||||||
|
schema_name = ps3.ast[0]
|
||||||
|
table_name = ps3.ast[1]
|
||||||
|
elif len(ps3.ast) == 1:
|
||||||
|
schema_name = ps3.schema
|
||||||
|
table_name = ps3.ast[0]
|
||||||
|
else:
|
||||||
|
assert(False)
|
||||||
ps2.position = ps3.position
|
ps2.position = ps3.position
|
||||||
|
|
||||||
if ps3 := parse_dict(ps2):
|
if ps3 := parse_dict(ps2):
|
||||||
|
@ -174,7 +199,9 @@ def parse_data_rule(ps):
|
||||||
else:
|
else:
|
||||||
label = rulename()
|
label = rulename()
|
||||||
|
|
||||||
ps2.ast = procrusql.HaveData(label, [], table_name, key_data, extra_data)
|
ps2.ast = procrusql.HaveData(
|
||||||
|
label, [],
|
||||||
|
table_name, key_data, extra_data, schema=schema_name)
|
||||||
|
|
||||||
ps2.match_newlines()
|
ps2.match_newlines()
|
||||||
elif ps3 := parse_init_query(ps2):
|
elif ps3 := parse_init_query(ps2):
|
||||||
|
@ -215,7 +242,14 @@ def parse_index_rule(ps):
|
||||||
if not ps3:
|
if not ps3:
|
||||||
ps.record_child_failure(ps2, "expected table name")
|
ps.record_child_failure(ps2, "expected table name")
|
||||||
return
|
return
|
||||||
table_name = ps3.ast[0]
|
if len(ps3.ast) == 2:
|
||||||
|
schema_name = ps3.ast[0]
|
||||||
|
table_name = ps3.ast[1]
|
||||||
|
elif len(ps3.ast) == 1:
|
||||||
|
schema_name = ps3.schema
|
||||||
|
table_name = ps3.ast[0]
|
||||||
|
else:
|
||||||
|
assert(False)
|
||||||
ps2.position = ps3.position
|
ps2.position = ps3.position
|
||||||
|
|
||||||
m = ps2.match(r"\s*(using\b|\([\w, ]+\))[^>\n]*")
|
m = ps2.match(r"\s*(using\b|\([\w, ]+\))[^>\n]*")
|
||||||
|
@ -230,7 +264,10 @@ def parse_index_rule(ps):
|
||||||
else:
|
else:
|
||||||
label = rulename()
|
label = rulename()
|
||||||
|
|
||||||
ps2.ast = procrusql.HaveIndex(label, [], table_name, index_name, index_type, index_definition)
|
ps2.ast = procrusql.HaveIndex(
|
||||||
|
label, [],
|
||||||
|
table_name, index_name,
|
||||||
|
index_type, index_definition, schema=schema_name)
|
||||||
|
|
||||||
ps2.match_newlines()
|
ps2.match_newlines()
|
||||||
|
|
||||||
|
@ -247,6 +284,12 @@ def parse_table_name(ps):
|
||||||
if ps2.rest[0].isalpha():
|
if ps2.rest[0].isalpha():
|
||||||
m = ps2.match(r"\w+") # always succeeds since we already checked the first character
|
m = ps2.match(r"\w+") # always succeeds since we already checked the first character
|
||||||
ps2.ast.append(m.group(0))
|
ps2.ast.append(m.group(0))
|
||||||
|
if ps2.rest[0] == ".":
|
||||||
|
m = ps2.match(r"\w+")
|
||||||
|
if not m:
|
||||||
|
ps.record_child_failure(ps2, "expected table name after schema")
|
||||||
|
return
|
||||||
|
ps2.ast.append(m.group(0))
|
||||||
else:
|
else:
|
||||||
ps.record_child_failure(ps2, "expected table name")
|
ps.record_child_failure(ps2, "expected table name")
|
||||||
return ps2
|
return ps2
|
||||||
|
|
Loading…
Reference in New Issue