221 lines
7.5 KiB
Python
221 lines
7.5 KiB
Python
import logging
|
|
|
|
import psycopg2
|
|
from psycopg2 import sql
|
|
from psycopg2 import extras
|
|
|
|
# Module logger plus one child logger per message category, so each kind of
# output ("<module>.action", "<module>.check", "<module>.state") can be
# filtered or leveled independently.
log = logging.getLogger(__name__)
log_action, log_check, log_state = (
    log.getChild(suffix) for suffix in ("action", "check", "state")
)
|
|
|
|
class Node:
    """Base class for one requirement in the graph processed by ``fit``.

    A node has a ``name`` and a list ``depends`` of names of other nodes.
    It becomes *ready* once every dependency is ok, and *ok* once its
    ``check`` method has verified (or established) the desired state.
    """

    def __init__(self, name, depends):
        self.name = name
        # Names of nodes in the module-level ``want`` list that must be ok first.
        self.depends = depends
        # Set to True by check() once the desired state is confirmed or created.
        self.ok = False
        # Cached result of is_ready(); once True it is never re-evaluated.
        self.ready = False

    def __repr__(self):
        # Use the bare class name; type(self) alone renders as "<class '...'>".
        return f"{type(self).__name__}({self.name}{' ready' if self.ready else ''}{' ok' if self.ok else ''})"

    def is_ready(self):
        """Return True if every dependency of this node is ok.

        Dependencies are looked up by name in the module-level ``want`` list
        (assigned by ``fit``); the first node with a matching name is used.
        The positive result is cached in ``self.ready``.

        Raises:
            RuntimeError: if a dependency name matches no node in ``want``.
        """
        if self.ready:
            return True
        if self.depends:
            # Index ``want`` by name once per call instead of rescanning it for
            # every dependency. setdefault keeps the first node with a given
            # name, matching a linear first-match search.
            by_name = {}
            for w in want:
                by_name.setdefault(w.name, w)
            for d in self.depends:
                node = by_name.get(d)
                if node is None:
                    raise RuntimeError(f"Dependency {d} of {self.name} doesn't exist")
                if not node.ok:
                    log_state.info("%s depends on %s which is not yet ok", self.name, d)
                    return False
        log_state.info("%s is now ready", self.name)
        self.ready = True
        return True

    def check(self):
        """Verify or establish this node's desired state; subclasses must override.

        Implementations set ``self.ok = True`` on success or raise an exception.
        """
        raise NotImplementedError(f"{type(self).__name__} does not implement check()")
|
|
|
|
class HaveData(Node):
    """Requirement that a row identified by ``key`` exists in ``table``.

    If no row matches, one is inserted from the ``key`` and ``extra`` columns.
    Values in ``key`` and ``extra`` may be ``Ref`` objects, which are resolved
    just before they are used.
    """

    def __init__(self, name, depends, table, key, extra):
        super().__init__(name, depends)
        self.table = table
        # Column name -> value; used to look the row up and also inserted if missing.
        self.key = key
        # Column name -> value; only written when a new row is inserted.
        self.extra = extra

    def check(self):
        """Ensure a row matching ``self.key`` exists, inserting one if needed.

        On success ``self.result`` holds the matching (or inserted) rows as
        returned by a DictCursor, and ``self.ok`` is True.

        Raises:
            RuntimeError: if ``key`` is empty, or if the insert returned no rows.
        """
        log_check.info("Checking %s", self.name)
        if not self.key:
            # An empty key would render "where " with no condition: invalid SQL.
            raise RuntimeError(f"{self.name}: key must name at least one column")
        # ``with`` closes the cursor even if a statement fails. It does not
        # commit or roll back; the transaction is committed by fit().
        with db.cursor(cursor_factory=extras.DictCursor) as csr:
            key_check = sql.SQL(" and ").join([
                sql.SQL(" = ").join([sql.Identifier(col), sql.Placeholder()])
                for col in self.key.keys()
            ])
            q = sql.SQL(
                "select * from {table} where {key_check}"
            ).format(
                table=sql.Identifier(self.table),
                key_check=key_check,
            )
            key_values = [v.resolve() if isinstance(v, Ref) else v for v in self.key.values()]
            csr.execute(q, key_values)
            self.result = csr.fetchall()
            log_check.info("Got %d rows", len(self.result))
            if self.result:
                self.ok = True
                log_state.info("%s is now ok", self.name)
                return

            # No matching row: insert key + extra columns and keep the new row(s).
            extra_values = [v.resolve() if isinstance(v, Ref) else v for v in self.extra.values()]
            columns = list(self.key.keys()) + list(self.extra.keys())
            values = key_values + extra_values
            q = sql.SQL(
                "insert into {table}({columns}) values({placeholders}) returning *"
            ).format(
                table=sql.Identifier(self.table),
                columns=sql.SQL(", ").join([sql.Identifier(col) for col in columns]),
                placeholders=sql.SQL(", ").join([sql.Placeholder() for _ in columns]),
            )
            log_action.info("Inserting data")
            csr.execute(q, values)
            self.result = csr.fetchall()
            log_action.info("Got %d rows", len(self.result))
            if self.result:
                self.ok = True
                log_state.info("%s is now ok", self.name)
                return
        # We shouldn't get here. Either the insert succeeded, or it raised an
        # exception. Success with 0 rows should not happen.
        raise RuntimeError("Unreachable code reached")
|
|
|
|
class HaveTable(Node):
    """Requirement that table ``schema.table`` exists.

    A missing table is created with no columns; columns are expected to be
    added by separate HaveColumn requirements.
    """

    def __init__(self, name, depends, table, schema="public"):
        super().__init__(name, depends)
        self.table = table
        # Honour the caller-supplied schema (defaults to "public").
        self.schema = schema

    def check(self):
        """Ensure the table exists, creating it if necessary; sets ``self.ok``.

        Raises:
            RuntimeError: if information_schema reports more than one matching table.
        """
        log_check.info("Checking %s", self.name)
        # ``with`` closes the cursor on exit; the transaction is left to fit().
        with db.cursor(cursor_factory=extras.DictCursor) as csr:
            csr.execute(
                """
                select * from information_schema.tables
                where table_schema = %s and table_name = %s
                """,
                (self.schema, self.table,))
            r = csr.fetchall()
            if len(r) == 1:
                # Table exists, all ok
                self.ok = True
                log_state.info("%s is now ok", self.name)
                return
            if len(r) > 1:
                raise RuntimeError(f"Found {len(r)} tables with schema {self.schema} and name {self.table}")

            # Create table
            # (Yes, we can actually create a table with 0 columns)
            log_action.info("Creating table %s.%s", self.schema, self.table)
            q = sql.SQL("create table {schema}.{table}()").format(
                schema=sql.Identifier(self.schema),
                table=sql.Identifier(self.table)
            )
            csr.execute(q)
        self.ok = True
        log_state.info("%s is now ok", self.name)
|
|
|
|
class HaveColumn(Node):
    """Requirement that column ``column`` exists in ``schema.table``.

    A missing column is added with ``alter table ... add``, for example::

        hjp=> alter table service add id serial primary key;
        ALTER TABLE
        hjp=> alter table service add type text;
        ALTER TABLE

    Only existence is checked; the type and constraints are not compared yet.
    """

    def __init__(self, name, depends, table, column, definition, schema="public"):
        super().__init__(name, depends)
        self.table = table
        self.column = column
        # Raw SQL column definition (e.g. "text", "serial primary key"). It is
        # interpolated verbatim, so it must come from trusted configuration.
        self.definition = definition
        # Honour the caller-supplied schema (defaults to "public").
        self.schema = schema

    def check(self):
        """Ensure the column exists, adding it if necessary; sets ``self.ok``.

        Raises:
            RuntimeError: if information_schema reports more than one matching column.
        """
        log_check.info("Checking %s", self.name)
        # ``with`` closes the cursor on exit; the transaction is left to fit().
        with db.cursor(cursor_factory=extras.DictCursor) as csr:
            # For now just check if column exists. Checking the type etc. will be implemented later
            csr.execute(
                """
                select * from information_schema.columns
                where table_schema = %s and table_name = %s and column_name = %s
                """,
                (self.schema, self.table, self.column, ))
            r = csr.fetchall()
            if len(r) == 1:
                # Column exists, all ok
                self.ok = True
                log_state.info("%s is now ok", self.name)
                return
            if len(r) > 1:
                raise RuntimeError(f"Found {len(r)} columns with name {self.column} in {self.schema}.{self.table}")

            # Create column
            log_action.info("Adding column %s to table %s.%s", self.column, self.schema, self.table)
            q = sql.SQL("alter table {schema}.{table} add {column} {definition}").format(
                schema=sql.Identifier(self.schema),
                table=sql.Identifier(self.table),
                column=sql.Identifier(self.column),
                definition=sql.SQL(self.definition),
            )
            csr.execute(q)
        self.ok = True
        log_state.info("%s is now ok", self.name)
|
|
|
|
class HaveUniqueConstraint(Node):
    """Placeholder for a unique-constraint requirement (not implemented yet).

    Intended to issue something like::

        hjp=> alter table service add unique (type, feature);
        ALTER TABLE
    """

    def check(self):
        # Fail loudly with a clear message until this requirement is implemented.
        raise NotImplementedError("HaveUniqueConstraint is not implemented yet")
|
|
|
|
def findnode(name):
    """Look up a node by name in the module-level ``want`` list.

    Returns the first node whose ``name`` equals *name*, or None when there
    is no such node.
    """
    return next((node for node in want if node.name == name), None)
|
|
|
|
class Ref:
    """Lazy reference to one value fetched by another node.

    The referenced node must expose a ``result`` list of rows (HaveData sets
    one in its ``check``); the value is ``result[row][column]``.
    """

    def __init__(self, datanode, row, column):
        # Name of the node to read from; looked up with findnode() at resolve time.
        self.datanode = datanode
        # Index into that node's result rows.
        self.row = row
        # Key used to pick the value out of the selected row.
        self.column = column

    def resolve(self):
        """Return ``result[row][column]`` of the referenced node.

        Raises:
            RuntimeError: if no node has the referenced name, or it is not yet ok.
        """
        datanode = findnode(self.datanode)
        if datanode is None:
            # findnode() returns None for unknown names; report it instead of
            # failing later with an AttributeError on None.
            raise RuntimeError(f"Referenced node {self.datanode} doesn't exist")
        if not datanode.ok:
            # XXX - We might try to resolve this, but for now the user is responsible to declare the dependency explicitly
            raise RuntimeError(f"Cannot get data from {datanode.name} which is not yet ok. Please add a dependency")
        return datanode.result[self.row][self.column]
|
|
|
|
|
|
def fit(_db, _want):
    """Bring every requirement in ``_want`` to the ok state, then commit.

    ``_db`` is stored in the module-level ``db`` (used by the nodes' check
    methods) and ``_want`` in ``want`` (used for dependency lookup). The nodes
    are processed in repeated passes: in each pass every node that is not yet
    ok but whose dependencies are ok gets its ``check()`` called. Once all
    nodes are ok, ``_db.commit()`` is called.

    Raises:
        RuntimeError: if a pass makes no node ok while some are still not ok
            (unsatisfiable or cyclic dependencies), plus whatever the nodes'
            is_ready()/check() raise.
    """
    global db, want
    db = _db
    want = _want
    while True:
        progress = False
        for w in want:
            if w.ok or not w.is_ready():
                continue
            w.check()
            # Count progress only if check() actually satisfied the node; a
            # check that returns without setting ok would otherwise make this
            # loop spin forever.
            if w.ok:
                progress = True
        not_ok = sum(1 for w in want if not w.ok)
        if not_ok == 0:
            break
        if not progress:
            raise RuntimeError(f"Didn't make any progress in this round, but {not_ok} requirements are still not ok")
        log_state.info("%d requirements are not yet ok", not_ok)

    db.commit()
    log_state.info("Done")
|