Add schema support for data rules

This commit is contained in:
Peter J. Holzer 2024-12-12 21:51:05 +01:00 committed by Peter J. Holzer
parent d31204cd5c
commit 43272bb96a
1 changed files with 8 additions and 4 deletions

View File

@ -48,11 +48,12 @@ class Node:
self.order = order self.order = order
class HaveData(Node): class HaveData(Node):
def __init__(self, name, depends, table, key, extra): def __init__(self, name, depends, table, key, extra, schema="public"):
super().__init__(name, depends) super().__init__(name, depends)
self.table = table self.table = table
self.key = key self.key = key
self.extra = extra self.extra = extra
self.schema = schema
def check(self): def check(self):
log_check.info("Checking %s", self.name) log_check.info("Checking %s", self.name)
@ -63,8 +64,9 @@ class HaveData(Node):
] ]
key_check = sql.SQL(" and ").join(key_checks) key_check = sql.SQL(" and ").join(key_checks)
q = sql.SQL( q = sql.SQL(
"select * from {table} where {key_check}" "select * from {schema}.{table} where {key_check}"
).format( ).format(
schema=sql.Identifier(self.schema),
table=sql.Identifier(self.table), table=sql.Identifier(self.table),
key_check=key_check key_check=key_check
) )
@ -78,8 +80,9 @@ class HaveData(Node):
if self.result[0][c] != self.extra[c]: if self.result[0][c] != self.extra[c]:
log_action.info("Updating %s: %s <- %s", key_values, c, self.extra[c]) log_action.info("Updating %s: %s <- %s", key_values, c, self.extra[c])
q = sql.SQL( q = sql.SQL(
"update {table} set {column}={placeholder} where {key_check}" "update {schema}.{table} set {column}={placeholder} where {key_check}"
).format( ).format(
schema=sql.Identifier(self.schema),
table=sql.Identifier(self.table), table=sql.Identifier(self.table),
column=sql.Identifier(c), column=sql.Identifier(c),
placeholder=sql.Placeholder(), placeholder=sql.Placeholder(),
@ -96,8 +99,9 @@ class HaveData(Node):
columns = list(self.key.keys()) + list(self.extra.keys()) columns = list(self.key.keys()) + list(self.extra.keys())
values = key_values + extra_values values = key_values + extra_values
q = sql.SQL( q = sql.SQL(
"insert into {table}({columns}) values({placeholders}) returning *" "insert into {schema}.{table}({columns}) values({placeholders}) returning *"
).format( ).format(
schema=sql.Identifier(self.schema),
table=sql.Identifier(self.table), table=sql.Identifier(self.table),
columns=sql.SQL(", ").join([sql.Identifier(x) for x in columns]), columns=sql.SQL(", ").join([sql.Identifier(x) for x in columns]),
placeholders=sql.SQL(", ").join([sql.Placeholder() for x in columns]), placeholders=sql.SQL(", ").join([sql.Placeholder() for x in columns]),