diff --git a/setup.cfg b/setup.cfg index 7d55f46..d0c6439 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = ProcruSQL -version = 0.0.4 +version = 0.0.5 author = Peter J. Holzer author_email = hjp@hjp.at description = Make a database fit its description diff --git a/src/procrusql/__init__.py b/src/procrusql/__init__.py index 52c6625..7f4a724 100644 --- a/src/procrusql/__init__.py +++ b/src/procrusql/__init__.py @@ -192,6 +192,50 @@ class HaveUniqueConstraint(Node): # ALTER TABLE pass +class HaveIndex(Node): + def __init__(self, name, depends, table, index, type, definition, schema="public"): + super().__init__(name, depends) + self.table = table + self.index = index + self.type = type + self.definition = definition + self.schema = schema + + def check(self): + log_check.info("Checking %s", self.name) + csr = db.cursor(cursor_factory=extras.DictCursor) + # For now just check if index exists. Checking the type etc. will be implemented later + csr.execute( + """ + select * from pg_indexes + where schemaname = %s and tablename = %s and indexname = %s + """, + (self.schema, self.table, self.index, )) + r = csr.fetchall() + if len(r) == 1: + # Index exists, all ok + self.set_order() + self.ok = True + log_state.info("%s is now ok", self.name) + return + if len(r) > 1: + raise RuntimeError(f"Found {len(r)} indexes with name {self.index} on {self.schema}.{self.table}") + + # Create index + log_action.info("Adding index %s to table %s.%s", self.index, self.schema, self.table) + q = sql.SQL("create {type} {index} on {schema}.{table} {definition}").format( + type=sql.SQL(self.type), + schema=sql.Identifier(self.schema), + table=sql.Identifier(self.table), + index=sql.Identifier(self.index), + definition=sql.SQL(self.definition), + ) + csr.execute(q) + self.set_order() + self.ok = True + log_state.info("%s is now ok", self.name) + + def findnode(name): for w in want: if w.name == name: diff --git a/src/procrusql/parser.py b/src/procrusql/parser.py index c087a56..7ade364 100755 --- a/src/procrusql/parser.py +++ b/src/procrusql/parser.py @@ -75,7 +75,10 @@ def parse_ruleset(ps): ps2 = ps.clone() ps2.ast = [] while ps2.rest: - ps3 = parse_table_rule(ps2) or parse_column_rule(ps2) or parse_data_rule(ps2) + ps3 = parse_table_rule(ps2) or \ + parse_column_rule(ps2) or \ + parse_data_rule(ps2) or \ + parse_index_rule(ps2) if ps3: ps2.ast.append(ps3.ast) ps2.position = ps3.position @@ -180,6 +183,52 @@ def parse_data_rule(ps): return ps2 +def parse_index_rule(ps): + ps2 = ps.clone() + ps2.skip_whitespace_and_comments() + m = ps2.match(r"(unique)? index\b") + if not m: + ps.record_child_failure(ps2, "expected “(unique) index”") + return + index_type = m.group(0) + + ps3 = parse_index_name(ps2) + if not ps3: + ps.record_child_failure(ps2, "expected index name") + return + index_name = ps3.ast[0] + ps2.position = ps3.position + if not ps2.match(r"\s+on\b"): + ps.record_child_failure(ps2, "expected “on”") + return + + ps3 = parse_table_name(ps2) + if not ps3: + ps.record_child_failure(ps2, "expected table name") + return + table_name = ps3.ast[0] + ps2.position = ps3.position + + m = ps2.match(r"\s*(using\b|\([\w, ]+\))[^>]*") + if not m: + ps.record_child_failure(ps2, "expected “using” or column list") + index_definition = m.group(0) + + ps3 = parse_label(ps2) + if ps3: + label = ps3.ast + ps2.position = ps3.position + else: + label = rulename() + + ps2.ast = procrusql.HaveIndex(label, [], table_name, index_name, index_type, index_definition) + + ps2.match_newlines() + + return ps2 + + + def parse_table_name(ps): # For now this matches only simple names, not schema-qualified names or # quoted names. @@ -208,6 +257,21 @@ def parse_column_name(ps): ps.record_child_failure(ps2, "expected column name") return +def parse_index_name(ps): + # this is an exact duplicate of parse_table_name and parse_column_name, but they will + # probably diverge, so I duplicated it. I probably should define a + # parse_identifier and redefine them in terms of it. + ps2 = ps.clone() + ps2.ast = [] + ps2.skip_whitespace_and_comments() + if ps2.rest[0].isalpha(): + m = ps2.match(r"\w+") # always succeeds since we already checked the first character + ps2.ast.append(m.group(0)) + return ps2 + else: + ps.record_child_failure(ps2, "expected index name") + return + def parse_column_definition(ps): ps2 = ps.clone() ps2.ast = []