Compare commits
20 Commits
| SHA1 |
|---|
| 05e43907a5 |
| 12dcf05eaf |
| 43272bb96a |
| d31204cd5c |
| 8f93c2e2da |
| 363cdcede1 |
| 7970835c8b |
| f6c64a50ab |
| ccf10d5690 |
| 6c3437a5f1 |
| f187fe6dba |
| 4c639df91d |
| b64ec8b8b8 |
| 4039a905ae |
| 2da6559fc2 |
| 2cd66bdf50 |
| faf4770c5d |
| 2665e56e6f |
| f71ce98af5 |
| 33b680c49b |
```diff
@@ -1,6 +1,6 @@
 [metadata]
 name = ProcruSQL
-version = 0.0.3
+version = 0.0.16
 author = Peter J. Holzer
 author_email = hjp@hjp.at
 description = Make a database fit its description
@@ -9,6 +9,7 @@ long_description_content_type = text/markdown
 url = https://git.hjp.at:3000/hjp/procrusql
 project_urls =
     Bug Tracker = https://git.hjp.at:3000/hjp/procrusql/issues
+    Repository = https://git.hjp.at:3000/hjp/procrusql
 classifiers =
     Programming Language :: Python :: 3
     License :: OSI Approved :: MIT License
```
```diff
@@ -48,11 +48,12 @@ class Node:
         self.order = order
 
 class HaveData(Node):
-    def __init__(self, name, depends, table, key, extra):
+    def __init__(self, name, depends, table, key, extra, schema="public"):
         super().__init__(name, depends)
         self.table = table
         self.key = key
         self.extra = extra
+        self.schema = schema
 
     def check(self):
         log_check.info("Checking %s", self.name)
@@ -63,8 +64,9 @@ class HaveData(Node):
         ]
         key_check = sql.SQL(" and ").join(key_checks)
         q = sql.SQL(
-            "select * from {table} where {key_check}"
+            "select * from {schema}.{table} where {key_check}"
         ).format(
+            schema=sql.Identifier(self.schema),
             table=sql.Identifier(self.table),
             key_check=key_check
         )
```
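Note (not part of the diff): psycopg2's `sql` module quotes each `Identifier` separately, so passing the schema as its own `Identifier` yields a safely quoted, schema-qualified table reference. A minimal sketch with made-up schema/table values and an assumed open connection `conn`:

```python
from psycopg2 import sql

# Hypothetical stand-ins for self.schema, self.table and key_checks.
schema, table = "audit", "person"
key_check = sql.SQL("{} = %s").format(sql.Identifier("id"))

q = sql.SQL("select * from {schema}.{table} where {key_check}").format(
    schema=sql.Identifier(schema),
    table=sql.Identifier(table),
    key_check=key_check,
)

# as_string() needs a connection or cursor as context:
# print(q.as_string(conn))  # select * from "audit"."person" where "id" = %s
```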
```diff
@@ -73,17 +75,33 @@ class HaveData(Node):
         self.result = csr.fetchall()
         log_check.info("Got %d rows", len(self.result))
         if self.result:
+            extra_columns = list(self.extra.keys())
+            for c in extra_columns:
+                if self.result[0][c] != self.extra[c]:
+                    log_action.info("Updating %s: %s <- %s", key_values, c, self.extra[c])
+                    q = sql.SQL(
+                        "update {schema}.{table} set {column}={placeholder} where {key_check}"
+                    ).format(
+                        schema=sql.Identifier(self.schema),
+                        table=sql.Identifier(self.table),
+                        column=sql.Identifier(c),
+                        placeholder=sql.Placeholder(),
+                        key_check=key_check,
+                    )
+                    csr.execute(q, [self.extra[c]] + key_values)
+                    self.result[0][c] = self.extra[c]
             self.set_order()
             self.ok = True
             log_state.info("%s is now ok", self.name)
             return
+        else:
             extra_values = [v.resolve() if isinstance(v, Ref) else v for v in self.extra.values()]
             columns = list(self.key.keys()) + list(self.extra.keys())
             values = key_values + extra_values
             q = sql.SQL(
-                "insert into {table}({columns}) values({placeholders}) returning *"
+                "insert into {schema}.{table}({columns}) values({placeholders}) returning *"
             ).format(
+                schema=sql.Identifier(self.schema),
                 table=sql.Identifier(self.table),
                 columns=sql.SQL(", ").join([sql.Identifier(x) for x in columns]),
                 placeholders=sql.SQL(", ").join([sql.Placeholder() for x in columns]),
```
```diff
@@ -101,12 +119,42 @@ class HaveData(Node):
         # exception. Success with 0 rows should not happen.
         raise RuntimeError("Unreachable code reached")
 
+class HaveInit(Node):
+    def __init__(self, name, depends, table, query):
+        super().__init__(name, depends)
+        self.table = table
+        self.query = query
+
+    def check(self):
+        log_check.info("Checking %s", self.name)
+
+        for w in want:
+            if isinstance(w, HaveTable) and w.table == self.table:
+                if not w.ok:
+                    log_check.error("%s not yet ok", w.name)
+                    raise RuntimeError(f"Cannot insert into table {w.table} from {w.name} which is not yet ok. Please add a dependency")
+                if w.new:
+                    log_action.info("Executing %s", self.query)
+                    csr = db.cursor(cursor_factory=extras.DictCursor)
+                    csr.execute(self.query)
+                    self.result = csr.fetchall()
+                    log_action.info("Got %d rows", len(self.result))
+                    self.set_order()
+                else:
+                    log_check.info("Table %s already exists", w.table)
+                self.ok = True
+                log_state.info("%s is now ok", self.name)
+                break
+        else:
+            raise RuntimeError(f"Cannot find a rule which creates table {self.table}")
+
+
 class HaveTable(Node):
 
     def __init__(self, name, depends, table, schema="public"):
         super().__init__(name, depends)
         self.table = table
-        self.schema = "public"
+        self.schema = schema
 
     def check(self):
         log_check.info("Checking %s", self.name)
```
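Note (not part of the diff): the new `HaveInit.check()` relies on Python's `for … else`, where the `else` branch runs only when the loop finishes without `break`; that is what turns "no `HaveTable` rule creates this table" into an error. A tiny self-contained sketch of the same lookup pattern, using invented rule objects:

```python
class Rule:
    """Stand-in for a HaveTable-like rule."""
    def __init__(self, table):
        self.table = table

def find_table_rule(rules, table):
    for r in rules:
        if r.table == table:
            found = r
            break                 # a match skips the else branch
    else:
        # reached only when the loop ran to completion without break
        raise RuntimeError(f"Cannot find a rule which creates table {table}")
    return found

rules = [Rule("person"), Rule("address")]
print(find_table_rule(rules, "person").table)   # person
# find_table_rule(rules, "missing")             # raises RuntimeError
```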
```diff
@@ -122,6 +170,7 @@ class HaveTable(Node):
             # Table exists, all ok
             self.set_order()
             self.ok = True
+            self.new = False
             log_state.info("%s is now ok", self.name)
             return
         if len(r) > 1:
@@ -137,6 +186,7 @@ class HaveTable(Node):
         csr.execute(q)
         self.set_order()
         self.ok = True
+        self.new = True
         log_state.info("%s is now ok", self.name)
 
         pass
```
```diff
@@ -152,7 +202,7 @@ class HaveColumn(Node):
         self.table = table
         self.column = column
         self.definition = definition
-        self.schema = "public"
+        self.schema = schema
 
     def check(self):
         log_check.info("Checking %s", self.name)
```
```diff
@@ -166,7 +216,18 @@ class HaveColumn(Node):
             (self.schema, self.table, self.column, ))
         r = csr.fetchall()
         if len(r) == 1:
-            # Column exists, all ok
+            # Column exists, check attributes
+            if (r[0]["is_nullable"] == "YES") != self.definition["nullable"]:
+                log_action.info("Changing column %s of %s.%s to %s",
+                                self.column, self.schema, self.table,
+                                "null" if self.definition["nullable"] else "not null")
+                q = sql.SQL("alter table {schema}.{table} alter {column} {action} not null").format(
+                    schema=sql.Identifier(self.schema),
+                    table=sql.Identifier(self.table),
+                    column=sql.Identifier(self.column),
+                    action=sql.SQL("drop" if self.definition["nullable"] else "set")
+                )
+                csr.execute(q)
             self.set_order()
             self.ok = True
             log_state.info("%s is now ok", self.name)
```
```diff
@@ -180,7 +241,7 @@ class HaveColumn(Node):
             schema=sql.Identifier(self.schema),
             table=sql.Identifier(self.table),
             column=sql.Identifier(self.column),
-            definition=sql.SQL(self.definition),
+            definition=sql.SQL(self.definition["text"]),
         )
         csr.execute(q)
         self.set_order()
```
```diff
@@ -192,6 +253,50 @@ class HaveUniqueConstraint(Node):
         # ALTER TABLE
         pass
 
+class HaveIndex(Node):
+    def __init__(self, name, depends, table, index, type, definition, schema="public"):
+        super().__init__(name, depends)
+        self.table = table
+        self.index = index
+        self.type = type
+        self.definition = definition
+        self.schema = schema
+
+    def check(self):
+        log_check.info("Checking %s", self.name)
+        csr = db.cursor(cursor_factory=extras.DictCursor)
+        # For now just check if index exists. Checking the type etc. will be implemented later
+        csr.execute(
+            """
+            select * from pg_indexes
+            where schemaname = %s and tablename = %s and indexname = %s
+            """,
+            (self.schema, self.table, self.index, ))
+        r = csr.fetchall()
+        if len(r) == 1:
+            # Index exists, all ok
+            self.set_order()
+            self.ok = True
+            log_state.info("%s is now ok", self.name)
+            return
+        if len(r) > 1:
+            raise RuntimeError(f"Found {len(r)} indexes with name {self.index} on {self.schema}.{self.table}")
+
+        # Create index
+        log_action.info("Adding index %s to table %s.%s", self.index, self.schema, self.table)
+        q = sql.SQL("create {type} {index} on {schema}.{table} {definition}").format(
+            type=sql.SQL(self.type),
+            schema=sql.Identifier(self.schema),
+            table=sql.Identifier(self.table),
+            index=sql.Identifier(self.index),
+            definition=sql.SQL(self.definition),
+        )
+        csr.execute(q)
+        self.set_order()
+        self.ok = True
+        log_state.info("%s is now ok", self.name)
+
+
 def findnode(name):
     for w in want:
         if w.name == name:
```
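Note (not part of the diff): in the new `HaveIndex`, the index type ("index" or "unique index") and the column-list definition are spliced in with `sql.SQL`, which inserts raw SQL verbatim, while the index, schema and table names go through `sql.Identifier`, which double-quotes them. A small sketch with invented values:

```python
from psycopg2 import sql

# Hypothetical values; in the diff they come from the parsed index rule.
index_type = "unique index"       # keyword phrase, must stay unquoted
index_name = "person_email_idx"   # identifier, gets double-quoted
definition = "(email)"            # raw column list, inserted verbatim

q = sql.SQL("create {type} {index} on {schema}.{table} {definition}").format(
    type=sql.SQL(index_type),
    index=sql.Identifier(index_name),
    schema=sql.Identifier("public"),
    table=sql.Identifier("person"),
    definition=sql.SQL(definition),
)
# Rendered via q.as_string(conn) with an open connection:
#   create unique index "person_email_idx" on "public"."person" (email)
```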
```diff
@@ -251,16 +356,22 @@ def main():
     ap = argparse.ArgumentParser()
     ap.add_argument("--dbname")
     ap.add_argument("--dbuser")
+    ap.add_argument("--schema", default="public")
     ap.add_argument("files", nargs="+")
     args = ap.parse_args()
 
     db = psycopg2.connect(dbname=args.dbname, user=args.dbuser)
+    csr = db.cursor()
+    csr.execute("show search_path")
+    search_path = csr.fetchone()[0]
+    search_path = args.schema + ", " + search_path
+    csr.execute(f"set search_path to {search_path}")
 
     rules = []
     for f in args.files:
         with open(f) as rf:
             text = rf.read()
-            ps = parser.ParseState(text)
+            ps = parser.ParseState(text, schema=args.schema)
 
             ps2 = parser.parse_ruleset(ps)
```
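Note (not part of the diff): `main()` now prepends the `--schema` argument to PostgreSQL's `search_path`, so unqualified names in the rule files resolve in that schema first. A rough standalone sketch of the same steps, with assumed connection parameters:

```python
import psycopg2

def connect_with_schema(dbname, user, schema="public"):
    """Open a connection and put `schema` at the front of search_path."""
    db = psycopg2.connect(dbname=dbname, user=user)
    csr = db.cursor()
    csr.execute("show search_path")
    current = csr.fetchone()[0]               # e.g. '"$user", public'
    csr.execute(f"set search_path to {schema}, {current}")
    return db

# db = connect_with_schema("mydb", "myuser", schema="audit")
```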
```diff
@@ -21,13 +21,14 @@ class Failure:
 
 class ParseState:
 
-    def __init__(self, text, position=0):
+    def __init__(self, text, position=0, schema="public"):
         self.text = text
         self.position = position
+        self.schema = schema # the default schema. probably doesn't belong into the parser state
         self.child_failure = None
 
     def clone(self):
-        ps = ParseState(self.text, self.position)
+        ps = ParseState(self.text, self.position, self.schema)
         return ps
 
     @property
```
```diff
@@ -42,7 +43,7 @@ class ParseState:
         linesbefore = self.text[:position].split("\n")
         linesafter = self.text[position:].split("\n")
         good = "\x1B[40;32m"
-        bad = "\x1B[40;31m"
+        bad = "\x1B[40;31;1m"
         reset = "\x1B[0m"
         s = reset + message + "\n"
         lines = []
```
```diff
@@ -75,7 +76,11 @@ def parse_ruleset(ps):
     ps2 = ps.clone()
     ps2.ast = []
     while ps2.rest:
-        ps3 = parse_table_rule(ps2) or parse_column_rule(ps2) or parse_data_rule(ps2)
+        ps3 = parse_table_rule(ps2) or \
+              parse_column_rule(ps2) or \
+              parse_data_rule(ps2) or \
+              parse_index_rule(ps2) or \
+              parse_view_rule(ps2)
         if ps3:
             ps2.ast.append(ps3.ast)
             ps2.position = ps3.position
```
```diff
@@ -95,8 +100,16 @@ def parse_table_rule(ps):
     if not ps3:
         ps.record_child_failure(ps2, "expected table name")
         return
+    if len(ps3.ast) == 2:
+        schema_name = ps3.ast[0]
+        table_name = ps3.ast[1]
+    elif len(ps3.ast) == 1:
+        schema_name = ps3.schema
+        table_name = ps3.ast[0]
+    else:
+        assert(False)
 
-    ps2.ast = procrusql.HaveTable(rulename(), [], ps3.ast[0])
+    ps2.ast = procrusql.HaveTable(rulename(), [], ps3.ast[0], schema=schema_name)
     ps2.position = ps3.position
     return ps2
 
```
```diff
@@ -116,7 +129,14 @@ def parse_column_rule(ps):
     if not ps3:
         ps.record_child_failure(ps2, "expected table name")
         return
+    if len(ps3.ast) == 2:
+        schema_name = ps3.ast[0]
+        table_name = ps3.ast[1]
+    elif len(ps3.ast) == 1:
+        schema_name = ps3.schema
         table_name = ps3.ast[0]
+    else:
+        assert(False)
     ps2.position = ps3.position
 
     ps3 = parse_column_name(ps2)
```
```diff
@@ -133,7 +153,9 @@ def parse_column_rule(ps):
     column_definition = ps3.ast[0]
     ps2.position = ps3.position
 
-    ps2.ast = procrusql.HaveColumn(rulename(), [], table_name, column_name, column_definition)
+    ps2.ast = procrusql.HaveColumn(
+        rulename(), [],
+        table_name, column_name, column_definition, schema=schema_name)
 
     ps2.match_newlines()
 
```
```diff
@@ -150,13 +172,17 @@ def parse_data_rule(ps):
     if not ps3:
         ps.record_child_failure(ps2, "expected table name")
         return
+    if len(ps3.ast) == 2:
+        schema_name = ps3.ast[0]
+        table_name = ps3.ast[1]
+    elif len(ps3.ast) == 1:
+        schema_name = ps3.schema
         table_name = ps3.ast[0]
+    else:
+        assert(False)
     ps2.position = ps3.position
 
-    ps3 = parse_dict(ps2)
-    if not ps3:
-        ps.record_child_failure(ps2, "expected key data definition")
-        return
+    if ps3 := parse_dict(ps2):
         key_data = ps3.ast
         ps2.position = ps3.position
 
```
```diff
@@ -174,12 +200,96 @@ def parse_data_rule(ps):
         else:
             label = rulename()
 
-    ps2.ast = procrusql.HaveData(label, [], table_name, key_data, extra_data)
+        ps2.ast = procrusql.HaveData(
+            label, [],
+            table_name, key_data, extra_data, schema=schema_name)
+
+        ps2.match_newlines()
+    elif ps3 := parse_init_query(ps2):
+        # We have a bit of a problem here: The query extends to the end of the
+        # line so there is no room for a label. I really don't want to parse
+        # SQL here.
+        label = rulename()
+        ps2.position = ps3.position
+        ps2.ast = procrusql.HaveInit(label, [], table_name, ps3.ast)
+
+        ps2.match_newlines()
+    else:
+        ps.record_child_failure(ps2, "expected key data definition or insert query")
+        return
+
+    return ps2
+
+def parse_index_rule(ps):
+    ps2 = ps.clone()
+    ps2.skip_whitespace_and_comments()
+    m = ps2.match(r"(unique)? index\b")
+    if not m:
+        ps.record_child_failure(ps2, "expected “(unique) index”")
+        return
+    index_type = m.group(0)
+
+    ps3 = parse_index_name(ps2)
+    if not ps3:
+        ps.record_child_failure(ps2, "expected index name")
+        return
+    index_name = ps3.ast[0]
+    ps2.position = ps3.position
+    if not ps2.match(r"\s+on\b"):
+        ps.record_child_failure(ps2, "expected “on”")
+        return
+
+    ps3 = parse_table_name(ps2)
+    if not ps3:
+        ps.record_child_failure(ps2, "expected table name")
+        return
+    if len(ps3.ast) == 2:
+        schema_name = ps3.ast[0]
+        table_name = ps3.ast[1]
+    elif len(ps3.ast) == 1:
+        schema_name = ps3.schema
+        table_name = ps3.ast[0]
+    else:
+        assert(False)
+    ps2.position = ps3.position
+
+    m = ps2.match(r"\s*(using\b|\([\w, ]+\))[^>\n]*")
+    if not m:
+        ps.record_child_failure(ps2, "expected “using” or column list")
+    index_definition = m.group(0)
+
+    ps3 = parse_label(ps2)
+    if ps3:
+        label = ps3.ast
+        ps2.position = ps3.position
+    else:
+        label = rulename()
+
+    ps2.ast = procrusql.HaveIndex(
+        label, [],
+        table_name, index_name,
+        index_type, index_definition, schema=schema_name)
+
     ps2.match_newlines()
+
     return ps2
 
+def parse_view_rule(ps):
+    ps2 = ps.clone()
+    ps2.skip_whitespace_and_comments()
+    if not ps2.match(r"view\b"):
+        ps.record_child_failure(ps2, "expected “view”")
+        return
+    ps2.skip_whitespace_and_comments()
+    ps3 = parse_table_name(ps2)
+    if not ps3:
+        ps.record_child_failure(ps2, "expected view name")
+        return
+    ps2.skip_whitespace_and_comments()
+    ps3 = parse_multiline_string(ps2)
+
+
 def parse_table_name(ps):
     # For now this matches only simple names, not schema-qualified names or
     # quoted names.
```
```diff
@@ -189,6 +299,12 @@ def parse_table_name(ps):
     if ps2.rest[0].isalpha():
         m = ps2.match(r"\w+") # always succeeds since we already checked the first character
         ps2.ast.append(m.group(0))
+        if ps2.rest[0] == ".":
+            m = ps2.match(r"\w+")
+            if not m:
+                ps.record_child_failure(ps2, "expected table name after schema")
+                return
+            ps2.ast.append(m.group(0))
     else:
         ps.record_child_failure(ps2, "expected table name")
     return ps2
```
```diff
@@ -208,15 +324,49 @@ def parse_column_name(ps):
         ps.record_child_failure(ps2, "expected column name")
         return
 
+def parse_index_name(ps):
+    # this is an exact duplicate of parse_table_name and parse_column_name, but they will
+    # probably diverge, so I duplicated it. I probably should define a
+    # parse_identifier and redefine them in terms of it.
+    ps2 = ps.clone()
+    ps2.ast = []
+    ps2.skip_whitespace_and_comments()
+    if ps2.rest[0].isalpha():
+        m = ps2.match(r"\w+") # always succeeds since we already checked the first character
+        ps2.ast.append(m.group(0))
+        return ps2
+    else:
+        ps.record_child_failure(ps2, "expected index name")
+        return
+
 def parse_column_definition(ps):
     ps2 = ps.clone()
     ps2.ast = []
     ps2.skip_whitespace_and_comments()
-    m = ps2.match(r"(int|serial|text|boolean)(\s+not null)?(\s+(primary key|unique|references \w+))?\b")
+    sqltypes = sorted(
+        (
+            "integer", "int", "serial", "bigint",
+            "boolean",
+            "text", "character varying",
+            "date", "timestamp with time zone", "timestamptz",
+            "time",
+            "inet",
+            "double precision", "float8", "real", "float4",
+            "json", "jsonb",
+            "uuid",
+            r"integer\[\]", r"int\[\]", r"bigint\[\]",
+            "bytea",
+        ),
+        key=lambda x: -len(x) # longest match first
+    )
+    pattern = "(" + "|".join(sqltypes) + ")" + r"([ \t]+(default .*|not null\b|primary key\b|unique\b|references \w+\b))*"
+    m = ps2.match(pattern)
     if not m:
         ps.record_child_failure(ps2, "expected column definition")
         return
-    ps2.ast.append(m.group(0))
+    text = m.group(0)
+    nullable = not("not null" in text or "primary key" in text)
+    ps2.ast.append({ "text": text, "nullable": nullable })
     return ps2
 
 def parse_dict(ps):
```
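Note (not part of the diff): the `key=lambda x: -len(x)` sort matters because Python's `re` alternation is ordered rather than longest-match, so a shorter type name would otherwise shadow a longer one (`int` swallowing the front of `integer`). A tiny illustration with a reduced, made-up type list:

```python
import re

types_unsorted = ["int", "integer", "text"]
types_sorted = sorted(types_unsorted, key=lambda x: -len(x))  # longest first

unsorted_pat = "(" + "|".join(types_unsorted) + ")"
sorted_pat = "(" + "|".join(types_sorted) + ")"

print(re.match(unsorted_pat, "integer not null").group(0))   # int      (too short)
print(re.match(sorted_pat, "integer not null").group(0))     # integer
```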
```diff
@@ -276,6 +426,16 @@ def parse_dict(ps):
     ps2.ast = d
     return ps2
 
+def parse_init_query(ps):
+    ps2 = ps.clone()
+    ps2.skip_whitespace_and_comments()
+    if m := ps2.match(r'(?i:with|insert)\b.*'):
+        ps2.ast = m.group(0)
+        return ps2
+    else:
+        ps.record_child_failure(ps2, "expected insert query")
+        return
+
 def parse_label(ps):
     ps2 = ps.clone()
     if m := ps2.match(r"\s*>>\s*(\w+)"):
```