Compare commits

master..v0.0.11

No commits in common. "master" and "v0.0.11" have entirely different histories.

3 changed files with 27 additions and 103 deletions

View File

@@ -1,6 +1,6 @@
 [metadata]
 name = ProcruSQL
-version = 0.0.16
+version = 0.0.11
 author = Peter J. Holzer
 author_email = hjp@hjp.at
 description = Make a database fit its description
@@ -9,7 +9,6 @@ long_description_content_type = text/markdown
 url = https://git.hjp.at:3000/hjp/procrusql
 project_urls =
     Bug Tracker = https://git.hjp.at:3000/hjp/procrusql/issues
-    Repository = https://git.hjp.at:3000/hjp/procrusql
 classifiers =
     Programming Language :: Python :: 3
     License :: OSI Approved :: MIT License
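The [metadata] section above is setuptools-style packaging configuration; the compare view hides the file name, but the format matches setup.cfg. As a side note, the version field that differs between the two refs can be read with the standard-library configparser. A minimal sketch, assuming the file is indeed named setup.cfg:

import configparser

# Read the packaging metadata of the current checkout and print the fields
# that differ between the two refs (0.0.16 on master, 0.0.11 on the tag).
# The file name "setup.cfg" is an assumption, not shown on this page.
cp = configparser.ConfigParser()
cp.read("setup.cfg")
print(cp["metadata"]["name"], cp["metadata"]["version"])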

View File

@@ -48,12 +48,11 @@ class Node:
         self.order = order

 class HaveData(Node):
-    def __init__(self, name, depends, table, key, extra, schema="public"):
+    def __init__(self, name, depends, table, key, extra):
         super().__init__(name, depends)
         self.table = table
         self.key = key
         self.extra = extra
-        self.schema = schema

     def check(self):
         log_check.info("Checking %s", self.name)
@@ -64,9 +63,8 @@ class HaveData(Node):
         ]
         key_check = sql.SQL(" and ").join(key_checks)
         q = sql.SQL(
-            "select * from {schema}.{table} where {key_check}"
+            "select * from {table} where {key_check}"
         ).format(
-            schema=sql.Identifier(self.schema),
             table=sql.Identifier(self.table),
             key_check=key_check
         )
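For context, the {schema}.{table} placeholders removed here are psycopg2.sql composition: sql.Identifier quotes each name and sql.Placeholder stands in for bound values, so the query text is assembled safely at run time. A minimal sketch of the pattern used on the master side; the schema, table, and key names are made up, and the connection is only needed to render the query with as_string():

import psycopg2
from psycopg2 import sql

key = {"id": 42}                        # hypothetical key column and value
key_checks = [
    sql.SQL("{} = {}").format(sql.Identifier(k), sql.Placeholder())
    for k in key
]
q = sql.SQL("select * from {schema}.{table} where {key_check}").format(
    schema=sql.Identifier("public"),    # assumed schema name
    table=sql.Identifier("things"),     # assumed table name
    key_check=sql.SQL(" and ").join(key_checks),
)

conn = psycopg2.connect(dbname="testdb")    # assumes a reachable database
print(q.as_string(conn))                    # select * from "public"."things" where "id" = %s
# csr = conn.cursor()
# csr.execute(q, list(key.values()))        # would run against an existing "things" table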
@@ -80,9 +78,8 @@ class HaveData(Node):
            if self.result[0][c] != self.extra[c]:
                log_action.info("Updating %s: %s <- %s", key_values, c, self.extra[c])
                q = sql.SQL(
-                   "update {schema}.{table} set {column}={placeholder} where {key_check}"
+                   "update {table} set {column}={placeholder} where {key_check}"
                ).format(
-                   schema=sql.Identifier(self.schema),
                    table=sql.Identifier(self.table),
                    column=sql.Identifier(c),
                    placeholder=sql.Placeholder(),
@@ -99,9 +96,8 @@ class HaveData(Node):
            columns = list(self.key.keys()) + list(self.extra.keys())
            values = key_values + extra_values
            q = sql.SQL(
-               "insert into {schema}.{table}({columns}) values({placeholders}) returning *"
+               "insert into {table}({columns}) values({placeholders}) returning *"
            ).format(
-               schema=sql.Identifier(self.schema),
                table=sql.Identifier(self.table),
                columns=sql.SQL(", ").join([sql.Identifier(x) for x in columns]),
                placeholders=sql.SQL(", ").join([sql.Placeholder() for x in columns]),
@@ -154,7 +150,7 @@ class HaveTable(Node):
     def __init__(self, name, depends, table, schema="public"):
         super().__init__(name, depends)
         self.table = table
-        self.schema = schema
+        self.schema = "public"

     def check(self):
         log_check.info("Checking %s", self.name)
@@ -202,7 +198,7 @@ class HaveColumn(Node):
         self.table = table
         self.column = column
         self.definition = definition
-        self.schema = schema
+        self.schema = "public"

     def check(self):
         log_check.info("Checking %s", self.name)
@@ -356,22 +352,16 @@ def main():
     ap = argparse.ArgumentParser()
     ap.add_argument("--dbname")
     ap.add_argument("--dbuser")
-    ap.add_argument("--schema", default="public")
     ap.add_argument("files", nargs="+")
     args = ap.parse_args()

     db = psycopg2.connect(dbname=args.dbname, user=args.dbuser)
-    csr = db.cursor()
-    csr.execute("show search_path")
-    search_path = csr.fetchone()[0]
-    search_path = args.schema + ", " + search_path
-    csr.execute(f"set search_path to {search_path}")

     rules = []
     for f in args.files:
         with open(f) as rf:
             text = rf.read()
-        ps = parser.ParseState(text, schema=args.schema)
+        ps = parser.ParseState(text)
         ps2 = parser.parse_ruleset(ps)
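The lines dropped from main() are what wired up the --schema option: they prepend the requested schema to the connection's search_path so that unqualified names resolve in that schema first. A sketch of that behaviour, with the connection parameters and schema name assumed rather than taken from this repository:

import psycopg2

db = psycopg2.connect(dbname="testdb", user="tester")   # assumed credentials
schema = "myapp"                                         # assumed schema name

csr = db.cursor()
csr.execute("show search_path")
search_path = csr.fetchone()[0]            # e.g. '"$user", public'
# Interpolates the value directly into the statement, as the removed code does.
csr.execute(f"set search_path to {schema}, {search_path}")

csr.execute("show search_path")
print(csr.fetchone()[0])                   # myapp, "$user", public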

View File

@@ -21,14 +21,13 @@ class Failure:

 class ParseState:
-    def __init__(self, text, position=0, schema="public"):
+    def __init__(self, text, position=0):
         self.text = text
         self.position = position
-        self.schema = schema # the default schema. probably doesn't belong into the parser state
         self.child_failure = None

     def clone(self):
-        ps = ParseState(self.text, self.position, self.schema)
+        ps = ParseState(self.text, self.position)
         return ps

     @property
@@ -43,7 +42,7 @@ class ParseState:
         linesbefore = self.text[:position].split("\n")
         linesafter = self.text[position:].split("\n")
         good = "\x1B[40;32m"
-        bad = "\x1B[40;31;1m"
+        bad = "\x1B[40;31m"
         reset = "\x1B[0m"
         s = reset + message + "\n"
         lines = []
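The strings changed here are ANSI SGR escape sequences: 40 selects a black background, 31 a red foreground, and the extra ;1 on the master side adds the bold attribute, while \x1B[0m resets. A quick way to see the difference in a terminal that supports these codes:

# Compare the two "bad" markers from the diff: red on black vs. bold red on black.
print("\x1B[40;31m"   + "bad (v0.0.11)" + "\x1B[0m")
print("\x1B[40;31;1m" + "bad (master)"  + "\x1B[0m")
print("\x1B[40;32m"   + "good"          + "\x1B[0m")   # the unchanged green marker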
@@ -79,8 +78,7 @@ def parse_ruleset(ps):
         ps3 = parse_table_rule(ps2) or \
               parse_column_rule(ps2) or \
               parse_data_rule(ps2) or \
-              parse_index_rule(ps2) or \
-              parse_view_rule(ps2)
+              parse_index_rule(ps2)
         if ps3:
             ps2.ast.append(ps3.ast)
             ps2.position = ps3.position
@@ -100,16 +98,8 @@ def parse_table_rule(ps):
     if not ps3:
         ps.record_child_failure(ps2, "expected table name")
         return

-    if len(ps3.ast) == 2:
-        schema_name = ps3.ast[0]
-        table_name = ps3.ast[1]
-    elif len(ps3.ast) == 1:
-        schema_name = ps3.schema
-        table_name = ps3.ast[0]
-    else:
-        assert(False)
-    ps2.ast = procrusql.HaveTable(rulename(), [], ps3.ast[0], schema=schema_name)
+    ps2.ast = procrusql.HaveTable(rulename(), [], ps3.ast[0])
     ps2.position = ps3.position

     return ps2
@@ -129,14 +119,7 @@ def parse_column_rule(ps):
     if not ps3:
         ps.record_child_failure(ps2, "expected table name")
         return
-    if len(ps3.ast) == 2:
-        schema_name = ps3.ast[0]
-        table_name = ps3.ast[1]
-    elif len(ps3.ast) == 1:
-        schema_name = ps3.schema
-        table_name = ps3.ast[0]
-    else:
-        assert(False)
+    table_name = ps3.ast[0]
     ps2.position = ps3.position

     ps3 = parse_column_name(ps2)
@@ -153,9 +136,7 @@ def parse_column_rule(ps):
     column_definition = ps3.ast[0]
     ps2.position = ps3.position

-    ps2.ast = procrusql.HaveColumn(
-        rulename(), [],
-        table_name, column_name, column_definition, schema=schema_name)
+    ps2.ast = procrusql.HaveColumn(rulename(), [], table_name, column_name, column_definition)

     ps2.match_newlines()
@@ -172,14 +153,7 @@ def parse_data_rule(ps):
     if not ps3:
         ps.record_child_failure(ps2, "expected table name")
         return
-    if len(ps3.ast) == 2:
-        schema_name = ps3.ast[0]
-        table_name = ps3.ast[1]
-    elif len(ps3.ast) == 1:
-        schema_name = ps3.schema
-        table_name = ps3.ast[0]
-    else:
-        assert(False)
+    table_name = ps3.ast[0]
     ps2.position = ps3.position

     if ps3 := parse_dict(ps2):
@@ -200,9 +174,7 @@ def parse_data_rule(ps):
         else:
             label = rulename()
-        ps2.ast = procrusql.HaveData(
-            label, [],
-            table_name, key_data, extra_data, schema=schema_name)
+        ps2.ast = procrusql.HaveData(label, [], table_name, key_data, extra_data)

         ps2.match_newlines()

     elif ps3 := parse_init_query(ps2):
@@ -243,14 +215,7 @@ def parse_index_rule(ps):
     if not ps3:
         ps.record_child_failure(ps2, "expected table name")
         return
-    if len(ps3.ast) == 2:
-        schema_name = ps3.ast[0]
-        table_name = ps3.ast[1]
-    elif len(ps3.ast) == 1:
-        schema_name = ps3.schema
-        table_name = ps3.ast[0]
-    else:
-        assert(False)
+    table_name = ps3.ast[0]
     ps2.position = ps3.position

     m = ps2.match(r"\s*(using\b|\([\w, ]+\))[^>\n]*")
@@ -265,29 +230,12 @@ def parse_index_rule(ps):
     else:
         label = rulename()
-    ps2.ast = procrusql.HaveIndex(
-        label, [],
-        table_name, index_name,
-        index_type, index_definition, schema=schema_name)
+    ps2.ast = procrusql.HaveIndex(label, [], table_name, index_name, index_type, index_definition)

     ps2.match_newlines()

     return ps2

-def parse_view_rule(ps):
-    ps2 = ps.clone()
-    ps2.skip_whitespace_and_comments()
-    if not ps2.match(r"view\b"):
-        ps.record_child_failure(ps2, "expected “view”")
-        return
-    ps2.skip_whitespace_and_comments()
-    ps3 = parse_table_name(ps2)
-    if not ps3:
-        ps.record_child_failure(ps2, "expected view name")
-        return
-    ps2.skip_whitespace_and_comments()
-    ps3 = parse_multiline_string(ps2)
-
 def parse_table_name(ps):
@@ -299,12 +247,6 @@ def parse_table_name(ps):
     if ps2.rest[0].isalpha():
         m = ps2.match(r"\w+") # always succeeds since we already checked the first character
         ps2.ast.append(m.group(0))
-        if ps2.rest[0] == ".":
-            m = ps2.match(r"\w+")
-            if not m:
-                ps.record_child_failure(ps2, "expected table name after schema")
-                return
-            ps2.ast.append(m.group(0))
     else:
         ps.record_child_failure(ps2, "expected table name")
     return ps2
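The block removed here is the schema-qualified half of the table-name parser; what remains still follows the parser's clone-and-advance pattern, where each rule works on a copy of the ParseState and only returns that copy on success, so a failed rule leaves the caller's state untouched. A stripped-down illustration of that pattern; this is a simplified stand-in, not the project's actual ParseState, which additionally records child failures and builds an AST:

import re

class MiniParseState:
    """Minimal stand-in for procrusql's ParseState: text plus a cursor position."""
    def __init__(self, text, position=0):
        self.text = text
        self.position = position

    def clone(self):
        return MiniParseState(self.text, self.position)

    @property
    def rest(self):
        return self.text[self.position:]

    def match(self, pattern):
        m = re.match(pattern, self.rest)
        if m:
            self.position += m.end()    # advance only on success
        return m

def parse_table_name(ps):
    ps2 = ps.clone()                    # work on a copy; failure leaves ps untouched
    if ps2.rest[:1].isalpha() and ps2.match(r"\w+"):
        return ps2
    return None                         # caller keeps the original state

ps = MiniParseState("mytable ( id serial )")
ps2 = parse_table_name(ps)
print(ps2.rest if ps2 else "no match")  # " ( id serial )"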
@@ -343,21 +285,14 @@ def parse_column_definition(ps):
     ps2 = ps.clone()
     ps2.ast = []
     ps2.skip_whitespace_and_comments()
-    sqltypes = sorted(
-        (
-            "integer", "int", "serial", "bigint",
-            "boolean",
-            "text", "character varying",
-            "date", "timestamp with time zone", "timestamptz",
-            "time",
-            "inet",
-            "double precision", "float8", "real", "float4",
-            "json", "jsonb",
-            "uuid",
-            r"integer\[\]", r"int\[\]", r"bigint\[\]",
-            "bytea",
-        ),
-        key=lambda x: -len(x) # longest match first
-    )
+    sqltypes = (
+        "integer", "int", "serial", "bigint",
+        "boolean",
+        "text", "character varying",
+        "date", "timestamp with time zone", "timestamptz",
+        "time",
+        "inet",
+        "double precision", "float8", "real", "float4",
+    )
    pattern = "(" + "|".join(sqltypes) + ")" + r"([ \t]+(default .*|not null\b|primary key\b|unique\b|references \w+\b))*"
    m = ps2.match(pattern)
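The sorted(..., key=lambda x: -len(x)) wrapper that exists only on the master side matters because Python's regex alternation is not longest-match: it takes the first alternative that matches at a position, so a carelessly ordered list would let "int" win over "integer" and leave the rest of the type name unconsumed. The hand-written v0.0.11 list above already places longer names before their prefixes; sorting just makes that robust. A small demonstration, using type names taken from the list:

import re

types_unsorted = ("int", "integer", "timestamp with time zone", "timestamptz")
types_sorted = sorted(types_unsorted, key=lambda x: -len(x))   # longest first

text = "integer not null"
print(re.match("(" + "|".join(types_unsorted) + ")", text).group(1))  # int
print(re.match("(" + "|".join(types_sorted) + ")", text).group(1))    # integer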