diff --git a/src/procrusql/__init__.py b/src/procrusql/__init__.py index 7f4a724..de1ccee 100644 --- a/src/procrusql/__init__.py +++ b/src/procrusql/__init__.py @@ -73,32 +73,46 @@ class HaveData(Node): self.result = csr.fetchall() log_check.info("Got %d rows", len(self.result)) if self.result: + extra_columns = list(self.extra.keys()) + for c in extra_columns: + if self.result[0][c] != self.extra[c]: + log_action.info("Updating %s: %s <- %s", key_values, c, self.extra[c]) + q = sql.SQL( + "update {table} set {column}={placeholder} where {key_check}" + ).format( + table=sql.Identifier(self.table), + column=sql.Identifier(c), + placeholder=sql.Placeholder(), + key_check=key_check, + ) + csr.execute(q, [self.extra[c]] + key_values) + self.result[0][c] = self.extra[c] self.set_order() self.ok = True log_state.info("%s is now ok", self.name) return - - extra_values = [v.resolve() if isinstance(v, Ref) else v for v in self.extra.values()] - columns = list(self.key.keys()) + list(self.extra.keys()) - values = key_values + extra_values - q = sql.SQL( - "insert into {table}({columns}) values({placeholders}) returning *" - ).format( - table=sql.Identifier(self.table), - columns=sql.SQL(", ").join([sql.Identifier(x) for x in columns]), - placeholders=sql.SQL(", ").join([sql.Placeholder() for x in columns]), - ) - log_action.info("Inserting data") - csr.execute(q, values) - self.result = csr.fetchall() - log_action.info("Got %d rows", len(self.result)) - if self.result: - self.set_order() - self.ok = True - log_state.info("%s is now ok", self.name) - return - # We shouldn't get here. Either the insert succeeded, or it raised an - # exception. Success with 0 rows should not happen. + else: + extra_values = [v.resolve() if isinstance(v, Ref) else v for v in self.extra.values()] + columns = list(self.key.keys()) + list(self.extra.keys()) + values = key_values + extra_values + q = sql.SQL( + "insert into {table}({columns}) values({placeholders}) returning *" + ).format( + table=sql.Identifier(self.table), + columns=sql.SQL(", ").join([sql.Identifier(x) for x in columns]), + placeholders=sql.SQL(", ").join([sql.Placeholder() for x in columns]), + ) + log_action.info("Inserting data") + csr.execute(q, values) + self.result = csr.fetchall() + log_action.info("Got %d rows", len(self.result)) + if self.result: + self.set_order() + self.ok = True + log_state.info("%s is now ok", self.name) + return + # We shouldn't get here. Either the insert succeeded, or it raised an + # exception. Success with 0 rows should not happen. raise RuntimeError("Unreachable code reached") class HaveTable(Node):