Compare commits

...

9 Commits

Author SHA1 Message Date
Peter J. Holzer 05e43907a5 Set version to 0.16 2025-03-14 12:44:31 +01:00
Peter J. Holzer 12dcf05eaf Merge branch 'master' of git.hjp.at:hjp/procrusql 2024-12-12 21:51:30 +01:00
Peter J. Holzer 43272bb96a Add schema support for data rules 2024-12-12 21:51:05 +01:00
Peter J. Holzer d31204cd5c Add view definitions 2024-12-12 21:50:25 +01:00
Peter J. Holzer 8f93c2e2da Add type bytea 2024-12-12 21:48:31 +01:00
Peter J. Holzer 363cdcede1 Make error messages more readable 2024-12-12 21:48:06 +01:00
Peter J. Holzer 7970835c8b Add schema support 2024-12-12 21:46:14 +01:00
Peter J. Holzer f6c64a50ab Add uuid 2024-02-06 12:47:00 +01:00
Peter J. Holzer ccf10d5690 Add integer arrays 2023-08-18 15:53:27 +02:00
3 changed files with 103 additions and 28 deletions

View File

@ -1,6 +1,6 @@
[metadata] [metadata]
name = ProcruSQL name = ProcruSQL
version = 0.0.12 version = 0.0.16
author = Peter J. Holzer author = Peter J. Holzer
author_email = hjp@hjp.at author_email = hjp@hjp.at
description = Make a database fit its description description = Make a database fit its description
@ -9,6 +9,7 @@ long_description_content_type = text/markdown
url = https://git.hjp.at:3000/hjp/procrusql url = https://git.hjp.at:3000/hjp/procrusql
project_urls = project_urls =
Bug Tracker = https://git.hjp.at:3000/hjp/procrusql/issues Bug Tracker = https://git.hjp.at:3000/hjp/procrusql/issues
Repository = https://git.hjp.at:3000/hjp/procrusql
classifiers = classifiers =
Programming Language :: Python :: 3 Programming Language :: Python :: 3
License :: OSI Approved :: MIT License License :: OSI Approved :: MIT License

View File

@ -48,11 +48,12 @@ class Node:
self.order = order self.order = order
class HaveData(Node): class HaveData(Node):
def __init__(self, name, depends, table, key, extra): def __init__(self, name, depends, table, key, extra, schema="public"):
super().__init__(name, depends) super().__init__(name, depends)
self.table = table self.table = table
self.key = key self.key = key
self.extra = extra self.extra = extra
self.schema = schema
def check(self): def check(self):
log_check.info("Checking %s", self.name) log_check.info("Checking %s", self.name)
@ -63,8 +64,9 @@ class HaveData(Node):
] ]
key_check = sql.SQL(" and ").join(key_checks) key_check = sql.SQL(" and ").join(key_checks)
q = sql.SQL( q = sql.SQL(
"select * from {table} where {key_check}" "select * from {schema}.{table} where {key_check}"
).format( ).format(
schema=sql.Identifier(self.schema),
table=sql.Identifier(self.table), table=sql.Identifier(self.table),
key_check=key_check key_check=key_check
) )
@ -78,8 +80,9 @@ class HaveData(Node):
if self.result[0][c] != self.extra[c]: if self.result[0][c] != self.extra[c]:
log_action.info("Updating %s: %s <- %s", key_values, c, self.extra[c]) log_action.info("Updating %s: %s <- %s", key_values, c, self.extra[c])
q = sql.SQL( q = sql.SQL(
"update {table} set {column}={placeholder} where {key_check}" "update {schema}.{table} set {column}={placeholder} where {key_check}"
).format( ).format(
schema=sql.Identifier(self.schema),
table=sql.Identifier(self.table), table=sql.Identifier(self.table),
column=sql.Identifier(c), column=sql.Identifier(c),
placeholder=sql.Placeholder(), placeholder=sql.Placeholder(),
@ -96,8 +99,9 @@ class HaveData(Node):
columns = list(self.key.keys()) + list(self.extra.keys()) columns = list(self.key.keys()) + list(self.extra.keys())
values = key_values + extra_values values = key_values + extra_values
q = sql.SQL( q = sql.SQL(
"insert into {table}({columns}) values({placeholders}) returning *" "insert into {schema}.{table}({columns}) values({placeholders}) returning *"
).format( ).format(
schema=sql.Identifier(self.schema),
table=sql.Identifier(self.table), table=sql.Identifier(self.table),
columns=sql.SQL(", ").join([sql.Identifier(x) for x in columns]), columns=sql.SQL(", ").join([sql.Identifier(x) for x in columns]),
placeholders=sql.SQL(", ").join([sql.Placeholder() for x in columns]), placeholders=sql.SQL(", ").join([sql.Placeholder() for x in columns]),
@ -150,7 +154,7 @@ class HaveTable(Node):
def __init__(self, name, depends, table, schema="public"): def __init__(self, name, depends, table, schema="public"):
super().__init__(name, depends) super().__init__(name, depends)
self.table = table self.table = table
self.schema = "public" self.schema = schema
def check(self): def check(self):
log_check.info("Checking %s", self.name) log_check.info("Checking %s", self.name)
@ -198,7 +202,7 @@ class HaveColumn(Node):
self.table = table self.table = table
self.column = column self.column = column
self.definition = definition self.definition = definition
self.schema = "public" self.schema = schema
def check(self): def check(self):
log_check.info("Checking %s", self.name) log_check.info("Checking %s", self.name)
@ -352,16 +356,22 @@ def main():
ap = argparse.ArgumentParser() ap = argparse.ArgumentParser()
ap.add_argument("--dbname") ap.add_argument("--dbname")
ap.add_argument("--dbuser") ap.add_argument("--dbuser")
ap.add_argument("--schema", default="public")
ap.add_argument("files", nargs="+") ap.add_argument("files", nargs="+")
args = ap.parse_args() args = ap.parse_args()
db = psycopg2.connect(dbname=args.dbname, user=args.dbuser) db = psycopg2.connect(dbname=args.dbname, user=args.dbuser)
csr = db.cursor()
csr.execute("show search_path")
search_path = csr.fetchone()[0]
search_path = args.schema + ", " + search_path
csr.execute(f"set search_path to {search_path}")
rules = [] rules = []
for f in args.files: for f in args.files:
with open(f) as rf: with open(f) as rf:
text = rf.read() text = rf.read()
ps = parser.ParseState(text) ps = parser.ParseState(text, schema=args.schema)
ps2 = parser.parse_ruleset(ps) ps2 = parser.parse_ruleset(ps)

View File

@ -21,13 +21,14 @@ class Failure:
class ParseState: class ParseState:
def __init__(self, text, position=0): def __init__(self, text, position=0, schema="public"):
self.text = text self.text = text
self.position = position self.position = position
self.schema = schema # the default schema. probably doesn't belong in the parser state
self.child_failure = None self.child_failure = None
def clone(self): def clone(self):
ps = ParseState(self.text, self.position) ps = ParseState(self.text, self.position, self.schema)
return ps return ps
@property @property
@ -42,7 +43,7 @@ class ParseState:
linesbefore = self.text[:position].split("\n") linesbefore = self.text[:position].split("\n")
linesafter = self.text[position:].split("\n") linesafter = self.text[position:].split("\n")
good = "\x1B[40;32m" good = "\x1B[40;32m"
bad = "\x1B[40;31m" bad = "\x1B[40;31;1m"
reset = "\x1B[0m" reset = "\x1B[0m"
s = reset + message + "\n" s = reset + message + "\n"
lines = [] lines = []
@ -78,7 +79,8 @@ def parse_ruleset(ps):
ps3 = parse_table_rule(ps2) or \ ps3 = parse_table_rule(ps2) or \
parse_column_rule(ps2) or \ parse_column_rule(ps2) or \
parse_data_rule(ps2) or \ parse_data_rule(ps2) or \
parse_index_rule(ps2) parse_index_rule(ps2) or \
parse_view_rule(ps2)
if ps3: if ps3:
ps2.ast.append(ps3.ast) ps2.ast.append(ps3.ast)
ps2.position = ps3.position ps2.position = ps3.position
@ -98,8 +100,16 @@ def parse_table_rule(ps):
if not ps3: if not ps3:
ps.record_child_failure(ps2, "expected table name") ps.record_child_failure(ps2, "expected table name")
return return
if len(ps3.ast) == 2:
schema_name = ps3.ast[0]
table_name = ps3.ast[1]
elif len(ps3.ast) == 1:
schema_name = ps3.schema
table_name = ps3.ast[0]
else:
assert(False)
ps2.ast = procrusql.HaveTable(rulename(), [], ps3.ast[0]) ps2.ast = procrusql.HaveTable(rulename(), [], ps3.ast[0], schema=schema_name)
ps2.position = ps3.position ps2.position = ps3.position
return ps2 return ps2
@ -119,7 +129,14 @@ def parse_column_rule(ps):
if not ps3: if not ps3:
ps.record_child_failure(ps2, "expected table name") ps.record_child_failure(ps2, "expected table name")
return return
table_name = ps3.ast[0] if len(ps3.ast) == 2:
schema_name = ps3.ast[0]
table_name = ps3.ast[1]
elif len(ps3.ast) == 1:
schema_name = ps3.schema
table_name = ps3.ast[0]
else:
assert(False)
ps2.position = ps3.position ps2.position = ps3.position
ps3 = parse_column_name(ps2) ps3 = parse_column_name(ps2)
@ -136,7 +153,9 @@ def parse_column_rule(ps):
column_definition = ps3.ast[0] column_definition = ps3.ast[0]
ps2.position = ps3.position ps2.position = ps3.position
ps2.ast = procrusql.HaveColumn(rulename(), [], table_name, column_name, column_definition) ps2.ast = procrusql.HaveColumn(
rulename(), [],
table_name, column_name, column_definition, schema=schema_name)
ps2.match_newlines() ps2.match_newlines()
@ -153,7 +172,14 @@ def parse_data_rule(ps):
if not ps3: if not ps3:
ps.record_child_failure(ps2, "expected table name") ps.record_child_failure(ps2, "expected table name")
return return
table_name = ps3.ast[0] if len(ps3.ast) == 2:
schema_name = ps3.ast[0]
table_name = ps3.ast[1]
elif len(ps3.ast) == 1:
schema_name = ps3.schema
table_name = ps3.ast[0]
else:
assert(False)
ps2.position = ps3.position ps2.position = ps3.position
if ps3 := parse_dict(ps2): if ps3 := parse_dict(ps2):
@ -174,7 +200,9 @@ def parse_data_rule(ps):
else: else:
label = rulename() label = rulename()
ps2.ast = procrusql.HaveData(label, [], table_name, key_data, extra_data) ps2.ast = procrusql.HaveData(
label, [],
table_name, key_data, extra_data, schema=schema_name)
ps2.match_newlines() ps2.match_newlines()
elif ps3 := parse_init_query(ps2): elif ps3 := parse_init_query(ps2):
@ -215,7 +243,14 @@ def parse_index_rule(ps):
if not ps3: if not ps3:
ps.record_child_failure(ps2, "expected table name") ps.record_child_failure(ps2, "expected table name")
return return
table_name = ps3.ast[0] if len(ps3.ast) == 2:
schema_name = ps3.ast[0]
table_name = ps3.ast[1]
elif len(ps3.ast) == 1:
schema_name = ps3.schema
table_name = ps3.ast[0]
else:
assert(False)
ps2.position = ps3.position ps2.position = ps3.position
m = ps2.match(r"\s*(using\b|\([\w, ]+\))[^>\n]*") m = ps2.match(r"\s*(using\b|\([\w, ]+\))[^>\n]*")
@ -230,12 +265,29 @@ def parse_index_rule(ps):
else: else:
label = rulename() label = rulename()
ps2.ast = procrusql.HaveIndex(label, [], table_name, index_name, index_type, index_definition) ps2.ast = procrusql.HaveIndex(
label, [],
table_name, index_name,
index_type, index_definition, schema=schema_name)
ps2.match_newlines() ps2.match_newlines()
return ps2 return ps2
def parse_view_rule(ps):
ps2 = ps.clone()
ps2.skip_whitespace_and_comments()
if not ps2.match(r"view\b"):
ps.record_child_failure(ps2, "expected “view”")
return
ps2.skip_whitespace_and_comments()
ps3 = parse_table_name(ps2)
if not ps3:
ps.record_child_failure(ps2, "expected view name")
return
ps2.skip_whitespace_and_comments()
ps3 = parse_multiline_string(ps2)
def parse_table_name(ps): def parse_table_name(ps):
@ -247,6 +299,12 @@ def parse_table_name(ps):
if ps2.rest[0].isalpha(): if ps2.rest[0].isalpha():
m = ps2.match(r"\w+") # always succeeds since we already checked the first character m = ps2.match(r"\w+") # always succeeds since we already checked the first character
ps2.ast.append(m.group(0)) ps2.ast.append(m.group(0))
if ps2.rest[0] == ".":
m = ps2.match(r"\w+")
if not m:
ps.record_child_failure(ps2, "expected table name after schema")
return
ps2.ast.append(m.group(0))
else: else:
ps.record_child_failure(ps2, "expected table name") ps.record_child_failure(ps2, "expected table name")
return ps2 return ps2
@ -285,15 +343,21 @@ def parse_column_definition(ps):
ps2 = ps.clone() ps2 = ps.clone()
ps2.ast = [] ps2.ast = []
ps2.skip_whitespace_and_comments() ps2.skip_whitespace_and_comments()
sqltypes = ( sqltypes = sorted(
"integer", "int", "serial", "bigint", (
"boolean", "integer", "int", "serial", "bigint",
"text", "character varying", "boolean",
"date", "timestamp with time zone", "timestamptz", "text", "character varying",
"time", "date", "timestamp with time zone", "timestamptz",
"inet", "time",
"double precision", "float8", "real", "float4", "inet",
"json", "jsonb", "double precision", "float8", "real", "float4",
"json", "jsonb",
"uuid",
r"integer\[\]", r"int\[\]", r"bigint\[\]",
"bytea",
),
key=lambda x: -len(x) # longest match first
) )
pattern = "(" + "|".join(sqltypes) + ")" + r"([ \t]+(default .*|not null\b|primary key\b|unique\b|references \w+\b))*" pattern = "(" + "|".join(sqltypes) + ")" + r"([ \t]+(default .*|not null\b|primary key\b|unique\b|references \w+\b))*"
m = ps2.match(pattern) m = ps2.match(pattern)