Compare commits

..

20 Commits

Author SHA1 Message Date
Peter J. Holzer 05e43907a5 Set version to 0.16 2025-03-14 12:44:31 +01:00
Peter J. Holzer 12dcf05eaf Merge branch 'master' of git.hjp.at:hjp/procrusql 2024-12-12 21:51:30 +01:00
Peter J. Holzer 43272bb96a Add schema support for data rules 2024-12-12 21:51:05 +01:00
Peter J. Holzer d31204cd5c Add view definitions 2024-12-12 21:50:25 +01:00
Peter J. Holzer 8f93c2e2da Add type bytea 2024-12-12 21:48:31 +01:00
Peter J. Holzer 363cdcede1 Make error messages more readable 2024-12-12 21:48:06 +01:00
Peter J. Holzer 7970835c8b Add schema support 2024-12-12 21:46:14 +01:00
Peter J. Holzer f6c64a50ab Add uuid 2024-02-06 12:47:00 +01:00
Peter J. Holzer ccf10d5690 Add integer arrays 2023-08-18 15:53:27 +02:00
Peter J. Holzer 6c3437a5f1 Add JSON types 2023-08-18 15:21:06 +02:00
Peter J. Holzer f187fe6dba Bump version number 2023-03-03 07:10:46 +01:00
Peter J. Holzer 4c639df91d Add some more types 2023-03-03 07:05:30 +01:00
Peter J. Holzer b64ec8b8b8 Accept additional SQL type time 2022-11-05 10:36:43 +01:00
Peter J. Holzer 4039a905ae Drop or set NOT NULL if necessary 2022-09-19 14:37:29 +02:00
Peter J. Holzer 2da6559fc2 Allow column constraints in any order 2022-09-08 12:06:14 +02:00
Peter J. Holzer 2cd66bdf50 Initialize table from SQL query 2022-08-02 14:49:07 +02:00
Peter J. Holzer faf4770c5d Update extra columns on existing rows 2022-07-18 15:58:39 +02:00
Peter J. Holzer 2665e56e6f Restrict column and index definitions to one line 2022-05-05 11:27:24 +02:00
Peter J. Holzer f71ce98af5 Add index rules 2022-05-04 15:24:00 +02:00
Peter J. Holzer 33b680c49b Accept some additional SQL types
Still far from complete ...
2022-04-19 14:59:56 +02:00
3 changed files with 323 additions and 51 deletions

View File

@ -1,6 +1,6 @@
[metadata] [metadata]
name = ProcruSQL name = ProcruSQL
version = 0.0.3 version = 0.0.16
author = Peter J. Holzer author = Peter J. Holzer
author_email = hjp@hjp.at author_email = hjp@hjp.at
description = Make a database fit its description description = Make a database fit its description
@ -9,6 +9,7 @@ long_description_content_type = text/markdown
url = https://git.hjp.at:3000/hjp/procrusql url = https://git.hjp.at:3000/hjp/procrusql
project_urls = project_urls =
Bug Tracker = https://git.hjp.at:3000/hjp/procrusql/issues Bug Tracker = https://git.hjp.at:3000/hjp/procrusql/issues
Repository = https://git.hjp.at:3000/hjp/procrusql
classifiers = classifiers =
Programming Language :: Python :: 3 Programming Language :: Python :: 3
License :: OSI Approved :: MIT License License :: OSI Approved :: MIT License

View File

@ -48,11 +48,12 @@ class Node:
self.order = order self.order = order
class HaveData(Node): class HaveData(Node):
def __init__(self, name, depends, table, key, extra): def __init__(self, name, depends, table, key, extra, schema="public"):
super().__init__(name, depends) super().__init__(name, depends)
self.table = table self.table = table
self.key = key self.key = key
self.extra = extra self.extra = extra
self.schema = schema
def check(self): def check(self):
log_check.info("Checking %s", self.name) log_check.info("Checking %s", self.name)
@ -63,8 +64,9 @@ class HaveData(Node):
] ]
key_check = sql.SQL(" and ").join(key_checks) key_check = sql.SQL(" and ").join(key_checks)
q = sql.SQL( q = sql.SQL(
"select * from {table} where {key_check}" "select * from {schema}.{table} where {key_check}"
).format( ).format(
schema=sql.Identifier(self.schema),
table=sql.Identifier(self.table), table=sql.Identifier(self.table),
key_check=key_check key_check=key_check
) )
@ -73,17 +75,33 @@ class HaveData(Node):
self.result = csr.fetchall() self.result = csr.fetchall()
log_check.info("Got %d rows", len(self.result)) log_check.info("Got %d rows", len(self.result))
if self.result: if self.result:
extra_columns = list(self.extra.keys())
for c in extra_columns:
if self.result[0][c] != self.extra[c]:
log_action.info("Updating %s: %s <- %s", key_values, c, self.extra[c])
q = sql.SQL(
"update {schema}.{table} set {column}={placeholder} where {key_check}"
).format(
schema=sql.Identifier(self.schema),
table=sql.Identifier(self.table),
column=sql.Identifier(c),
placeholder=sql.Placeholder(),
key_check=key_check,
)
csr.execute(q, [self.extra[c]] + key_values)
self.result[0][c] = self.extra[c]
self.set_order() self.set_order()
self.ok = True self.ok = True
log_state.info("%s is now ok", self.name) log_state.info("%s is now ok", self.name)
return return
else:
extra_values = [v.resolve() if isinstance(v, Ref) else v for v in self.extra.values()] extra_values = [v.resolve() if isinstance(v, Ref) else v for v in self.extra.values()]
columns = list(self.key.keys()) + list(self.extra.keys()) columns = list(self.key.keys()) + list(self.extra.keys())
values = key_values + extra_values values = key_values + extra_values
q = sql.SQL( q = sql.SQL(
"insert into {table}({columns}) values({placeholders}) returning *" "insert into {schema}.{table}({columns}) values({placeholders}) returning *"
).format( ).format(
schema=sql.Identifier(self.schema),
table=sql.Identifier(self.table), table=sql.Identifier(self.table),
columns=sql.SQL(", ").join([sql.Identifier(x) for x in columns]), columns=sql.SQL(", ").join([sql.Identifier(x) for x in columns]),
placeholders=sql.SQL(", ").join([sql.Placeholder() for x in columns]), placeholders=sql.SQL(", ").join([sql.Placeholder() for x in columns]),
@ -101,12 +119,42 @@ class HaveData(Node):
# exception. Success with 0 rows should not happen. # exception. Success with 0 rows should not happen.
raise RuntimeError("Unreachable code reached") raise RuntimeError("Unreachable code reached")
class HaveInit(Node):
def __init__(self, name, depends, table, query):
super().__init__(name, depends)
self.table = table
self.query = query
def check(self):
log_check.info("Checking %s", self.name)
for w in want:
if isinstance(w, HaveTable) and w.table == self.table:
if not w.ok:
log_check.error("%s not yet ok", w.name)
raise RuntimeError(f"Cannot insert into table {w.table} from {w.name} which is not yet ok. Please add a dependency")
if w.new:
log_action.info("Executing %s", self.query)
csr = db.cursor(cursor_factory=extras.DictCursor)
csr.execute(self.query)
self.result = csr.fetchall()
log_action.info("Got %d rows", len(self.result))
self.set_order()
else:
log_check.info("Table %s already exists", w.table)
self.ok = True
log_state.info("%s is now ok", self.name)
break
else:
raise RuntimeError(f"Cannot find a rule which creates table {self.table}")
class HaveTable(Node): class HaveTable(Node):
def __init__(self, name, depends, table, schema="public"): def __init__(self, name, depends, table, schema="public"):
super().__init__(name, depends) super().__init__(name, depends)
self.table = table self.table = table
self.schema = "public" self.schema = schema
def check(self): def check(self):
log_check.info("Checking %s", self.name) log_check.info("Checking %s", self.name)
@ -122,6 +170,7 @@ class HaveTable(Node):
# Table exists, all ok # Table exists, all ok
self.set_order() self.set_order()
self.ok = True self.ok = True
self.new = False
log_state.info("%s is now ok", self.name) log_state.info("%s is now ok", self.name)
return return
if len(r) > 1: if len(r) > 1:
@ -137,6 +186,7 @@ class HaveTable(Node):
csr.execute(q) csr.execute(q)
self.set_order() self.set_order()
self.ok = True self.ok = True
self.new = True
log_state.info("%s is now ok", self.name) log_state.info("%s is now ok", self.name)
pass pass
@ -152,7 +202,7 @@ class HaveColumn(Node):
self.table = table self.table = table
self.column = column self.column = column
self.definition = definition self.definition = definition
self.schema = "public" self.schema = schema
def check(self): def check(self):
log_check.info("Checking %s", self.name) log_check.info("Checking %s", self.name)
@ -166,7 +216,18 @@ class HaveColumn(Node):
(self.schema, self.table, self.column, )) (self.schema, self.table, self.column, ))
r = csr.fetchall() r = csr.fetchall()
if len(r) == 1: if len(r) == 1:
# Column exists, all ok # Column exists, check attributes
if (r[0]["is_nullable"] == "YES") != self.definition["nullable"]:
log_action.info("Changing column %s of %s.%s to %s",
self.column, self.schema, self.table,
"null" if self.definition["nullable"] else "not null")
q = sql.SQL("alter table {schema}.{table} alter {column} {action} not null").format(
schema=sql.Identifier(self.schema),
table=sql.Identifier(self.table),
column=sql.Identifier(self.column),
action=sql.SQL("drop" if self.definition["nullable"] else "set")
)
csr.execute(q)
self.set_order() self.set_order()
self.ok = True self.ok = True
log_state.info("%s is now ok", self.name) log_state.info("%s is now ok", self.name)
@ -180,7 +241,7 @@ class HaveColumn(Node):
schema=sql.Identifier(self.schema), schema=sql.Identifier(self.schema),
table=sql.Identifier(self.table), table=sql.Identifier(self.table),
column=sql.Identifier(self.column), column=sql.Identifier(self.column),
definition=sql.SQL(self.definition), definition=sql.SQL(self.definition["text"]),
) )
csr.execute(q) csr.execute(q)
self.set_order() self.set_order()
@ -192,6 +253,50 @@ class HaveUniqueConstraint(Node):
# ALTER TABLE # ALTER TABLE
pass pass
class HaveIndex(Node):
def __init__(self, name, depends, table, index, type, definition, schema="public"):
super().__init__(name, depends)
self.table = table
self.index = index
self.type = type
self.definition = definition
self.schema = schema
def check(self):
log_check.info("Checking %s", self.name)
csr = db.cursor(cursor_factory=extras.DictCursor)
# For now just check if index exists. Checking the type etc. will be implemented later
csr.execute(
"""
select * from pg_indexes
where schemaname = %s and tablename = %s and indexname = %s
""",
(self.schema, self.table, self.index, ))
r = csr.fetchall()
if len(r) == 1:
# Index exists, all ok
self.set_order()
self.ok = True
log_state.info("%s is now ok", self.name)
return
if len(r) > 1:
raise RuntimeError(f"Found {len(r)} indexes with name {self.index} on {self.schema}.{self.table}")
# Create index
log_action.info("Adding index %s to table %s.%s", self.index, self.schema, self.table)
q = sql.SQL("create {type} {index} on {schema}.{table} {definition}").format(
type=sql.SQL(self.type),
schema=sql.Identifier(self.schema),
table=sql.Identifier(self.table),
index=sql.Identifier(self.index),
definition=sql.SQL(self.definition),
)
csr.execute(q)
self.set_order()
self.ok = True
log_state.info("%s is now ok", self.name)
def findnode(name): def findnode(name):
for w in want: for w in want:
if w.name == name: if w.name == name:
@ -251,16 +356,22 @@ def main():
ap = argparse.ArgumentParser() ap = argparse.ArgumentParser()
ap.add_argument("--dbname") ap.add_argument("--dbname")
ap.add_argument("--dbuser") ap.add_argument("--dbuser")
ap.add_argument("--schema", default="public")
ap.add_argument("files", nargs="+") ap.add_argument("files", nargs="+")
args = ap.parse_args() args = ap.parse_args()
db = psycopg2.connect(dbname=args.dbname, user=args.dbuser) db = psycopg2.connect(dbname=args.dbname, user=args.dbuser)
csr = db.cursor()
csr.execute("show search_path")
search_path = csr.fetchone()[0]
search_path = args.schema + ", " + search_path
csr.execute(f"set search_path to {search_path}")
rules = [] rules = []
for f in args.files: for f in args.files:
with open(f) as rf: with open(f) as rf:
text = rf.read() text = rf.read()
ps = parser.ParseState(text) ps = parser.ParseState(text, schema=args.schema)
ps2 = parser.parse_ruleset(ps) ps2 = parser.parse_ruleset(ps)

View File

@ -21,13 +21,14 @@ class Failure:
class ParseState: class ParseState:
def __init__(self, text, position=0): def __init__(self, text, position=0, schema="public"):
self.text = text self.text = text
self.position = position self.position = position
self.schema = schema # the default schema. probably doesn't belong into the parser state
self.child_failure = None self.child_failure = None
def clone(self): def clone(self):
ps = ParseState(self.text, self.position) ps = ParseState(self.text, self.position, self.schema)
return ps return ps
@property @property
@ -42,7 +43,7 @@ class ParseState:
linesbefore = self.text[:position].split("\n") linesbefore = self.text[:position].split("\n")
linesafter = self.text[position:].split("\n") linesafter = self.text[position:].split("\n")
good = "\x1B[40;32m" good = "\x1B[40;32m"
bad = "\x1B[40;31m" bad = "\x1B[40;31;1m"
reset = "\x1B[0m" reset = "\x1B[0m"
s = reset + message + "\n" s = reset + message + "\n"
lines = [] lines = []
@ -75,7 +76,11 @@ def parse_ruleset(ps):
ps2 = ps.clone() ps2 = ps.clone()
ps2.ast = [] ps2.ast = []
while ps2.rest: while ps2.rest:
ps3 = parse_table_rule(ps2) or parse_column_rule(ps2) or parse_data_rule(ps2) ps3 = parse_table_rule(ps2) or \
parse_column_rule(ps2) or \
parse_data_rule(ps2) or \
parse_index_rule(ps2) or \
parse_view_rule(ps2)
if ps3: if ps3:
ps2.ast.append(ps3.ast) ps2.ast.append(ps3.ast)
ps2.position = ps3.position ps2.position = ps3.position
@ -95,8 +100,16 @@ def parse_table_rule(ps):
if not ps3: if not ps3:
ps.record_child_failure(ps2, "expected table name") ps.record_child_failure(ps2, "expected table name")
return return
if len(ps3.ast) == 2:
schema_name = ps3.ast[0]
table_name = ps3.ast[1]
elif len(ps3.ast) == 1:
schema_name = ps3.schema
table_name = ps3.ast[0]
else:
assert(False)
ps2.ast = procrusql.HaveTable(rulename(), [], ps3.ast[0]) ps2.ast = procrusql.HaveTable(rulename(), [], ps3.ast[0], schema=schema_name)
ps2.position = ps3.position ps2.position = ps3.position
return ps2 return ps2
@ -116,7 +129,14 @@ def parse_column_rule(ps):
if not ps3: if not ps3:
ps.record_child_failure(ps2, "expected table name") ps.record_child_failure(ps2, "expected table name")
return return
if len(ps3.ast) == 2:
schema_name = ps3.ast[0]
table_name = ps3.ast[1]
elif len(ps3.ast) == 1:
schema_name = ps3.schema
table_name = ps3.ast[0] table_name = ps3.ast[0]
else:
assert(False)
ps2.position = ps3.position ps2.position = ps3.position
ps3 = parse_column_name(ps2) ps3 = parse_column_name(ps2)
@ -133,7 +153,9 @@ def parse_column_rule(ps):
column_definition = ps3.ast[0] column_definition = ps3.ast[0]
ps2.position = ps3.position ps2.position = ps3.position
ps2.ast = procrusql.HaveColumn(rulename(), [], table_name, column_name, column_definition) ps2.ast = procrusql.HaveColumn(
rulename(), [],
table_name, column_name, column_definition, schema=schema_name)
ps2.match_newlines() ps2.match_newlines()
@ -150,13 +172,17 @@ def parse_data_rule(ps):
if not ps3: if not ps3:
ps.record_child_failure(ps2, "expected table name") ps.record_child_failure(ps2, "expected table name")
return return
if len(ps3.ast) == 2:
schema_name = ps3.ast[0]
table_name = ps3.ast[1]
elif len(ps3.ast) == 1:
schema_name = ps3.schema
table_name = ps3.ast[0] table_name = ps3.ast[0]
else:
assert(False)
ps2.position = ps3.position ps2.position = ps3.position
ps3 = parse_dict(ps2) if ps3 := parse_dict(ps2):
if not ps3:
ps.record_child_failure(ps2, "expected key data definition")
return
key_data = ps3.ast key_data = ps3.ast
ps2.position = ps3.position ps2.position = ps3.position
@ -174,12 +200,96 @@ def parse_data_rule(ps):
else: else:
label = rulename() label = rulename()
ps2.ast = procrusql.HaveData(label, [], table_name, key_data, extra_data) ps2.ast = procrusql.HaveData(
label, [],
table_name, key_data, extra_data, schema=schema_name)
ps2.match_newlines()
elif ps3 := parse_init_query(ps2):
# We have a bit of a problem here: The query extends to the end of the
# line so there is no room for a label. I really don't want to parse
# SQL here.
label = rulename()
ps2.position = ps3.position
ps2.ast = procrusql.HaveInit(label, [], table_name, ps3.ast)
ps2.match_newlines()
else:
ps.record_child_failure(ps2, "expected key data definition or insert query")
return
return ps2
def parse_index_rule(ps):
ps2 = ps.clone()
ps2.skip_whitespace_and_comments()
m = ps2.match(r"(unique)? index\b")
if not m:
ps.record_child_failure(ps2, "expected “(unique) index”")
return
index_type = m.group(0)
ps3 = parse_index_name(ps2)
if not ps3:
ps.record_child_failure(ps2, "expected index name")
return
index_name = ps3.ast[0]
ps2.position = ps3.position
if not ps2.match(r"\s+on\b"):
ps.record_child_failure(ps2, "expected “on”")
return
ps3 = parse_table_name(ps2)
if not ps3:
ps.record_child_failure(ps2, "expected table name")
return
if len(ps3.ast) == 2:
schema_name = ps3.ast[0]
table_name = ps3.ast[1]
elif len(ps3.ast) == 1:
schema_name = ps3.schema
table_name = ps3.ast[0]
else:
assert(False)
ps2.position = ps3.position
m = ps2.match(r"\s*(using\b|\([\w, ]+\))[^>\n]*")
if not m:
ps.record_child_failure(ps2, "expected “using” or column list")
index_definition = m.group(0)
ps3 = parse_label(ps2)
if ps3:
label = ps3.ast
ps2.position = ps3.position
else:
label = rulename()
ps2.ast = procrusql.HaveIndex(
label, [],
table_name, index_name,
index_type, index_definition, schema=schema_name)
ps2.match_newlines() ps2.match_newlines()
return ps2 return ps2
def parse_view_rule(ps):
ps2 = ps.clone()
ps2.skip_whitespace_and_comments()
if not ps2.match(r"view\b"):
ps.record_child_failure(ps2, "expected “view”")
return
ps2.skip_whitespace_and_comments()
ps3 = parse_table_name(ps2)
if not ps3:
ps.record_child_failure(ps2, "expected view name")
return
ps2.skip_whitespace_and_comments()
ps3 = parse_multiline_string(ps2)
def parse_table_name(ps): def parse_table_name(ps):
# For now this matches only simple names, not schema-qualified names or # For now this matches only simple names, not schema-qualified names or
# quoted names. # quoted names.
@ -189,6 +299,12 @@ def parse_table_name(ps):
if ps2.rest[0].isalpha(): if ps2.rest[0].isalpha():
m = ps2.match(r"\w+") # always succeeds since we already checked the first character m = ps2.match(r"\w+") # always succeeds since we already checked the first character
ps2.ast.append(m.group(0)) ps2.ast.append(m.group(0))
if ps2.rest[0] == ".":
m = ps2.match(r"\w+")
if not m:
ps.record_child_failure(ps2, "expected table name after schema")
return
ps2.ast.append(m.group(0))
else: else:
ps.record_child_failure(ps2, "expected table name") ps.record_child_failure(ps2, "expected table name")
return ps2 return ps2
@ -208,15 +324,49 @@ def parse_column_name(ps):
ps.record_child_failure(ps2, "expected column name") ps.record_child_failure(ps2, "expected column name")
return return
def parse_index_name(ps):
# this is an exact duplicate of parse_table_name and parse_column_name, but they will
# probably diverge, so I duplicated it. I probably should define a
# parse_identifier and redefine them in terms of it.
ps2 = ps.clone()
ps2.ast = []
ps2.skip_whitespace_and_comments()
if ps2.rest[0].isalpha():
m = ps2.match(r"\w+") # always succeeds since we already checked the first character
ps2.ast.append(m.group(0))
return ps2
else:
ps.record_child_failure(ps2, "expected index name")
return
def parse_column_definition(ps): def parse_column_definition(ps):
ps2 = ps.clone() ps2 = ps.clone()
ps2.ast = [] ps2.ast = []
ps2.skip_whitespace_and_comments() ps2.skip_whitespace_and_comments()
m = ps2.match(r"(int|serial|text|boolean)(\s+not null)?(\s+(primary key|unique|references \w+))?\b") sqltypes = sorted(
(
"integer", "int", "serial", "bigint",
"boolean",
"text", "character varying",
"date", "timestamp with time zone", "timestamptz",
"time",
"inet",
"double precision", "float8", "real", "float4",
"json", "jsonb",
"uuid",
r"integer\[\]", r"int\[\]", r"bigint\[\]",
"bytea",
),
key=lambda x: -len(x) # longest match first
)
pattern = "(" + "|".join(sqltypes) + ")" + r"([ \t]+(default .*|not null\b|primary key\b|unique\b|references \w+\b))*"
m = ps2.match(pattern)
if not m: if not m:
ps.record_child_failure(ps2, "expected column definition") ps.record_child_failure(ps2, "expected column definition")
return return
ps2.ast.append(m.group(0)) text = m.group(0)
nullable = not("not null" in text or "primary key" in text)
ps2.ast.append({ "text": text, "nullable": nullable })
return ps2 return ps2
def parse_dict(ps): def parse_dict(ps):
@ -276,6 +426,16 @@ def parse_dict(ps):
ps2.ast = d ps2.ast = d
return ps2 return ps2
def parse_init_query(ps):
ps2 = ps.clone()
ps2.skip_whitespace_and_comments()
if m := ps2.match(r'(?i:with|insert)\b.*'):
ps2.ast = m.group(0)
return ps2
else:
ps.record_child_failure(ps2, "expected insert query")
return
def parse_label(ps): def parse_label(ps):
ps2 = ps.clone() ps2 = ps.clone()
if m := ps2.match(r"\s*>>\s*(\w+)"): if m := ps2.match(r"\s*>>\s*(\w+)"):