#!/usr/bin/python3
"""Decode memory writes through a floating-bit address mask and sum memory.

Reads the file ``input`` line by line. Each non-blank line is either
``mask = <36 characters of 0/1/X>`` or ``mem[<address>] = <value>``.

For a ``mem`` line, the most recent mask is applied to the *address*:

* ``1`` bits force the address bit to 1;
* ``0`` bits leave the address bit unchanged;
* ``X`` bits are "floating": the value is written to every address obtained
  by setting each floating bit to both 0 and 1.

After all lines are processed, the sum of every stored value is printed.
"""
from pprint import pprint
import re

MASK_RE = re.compile(r"mask = ([01X]{36})")
MEM_RE = re.compile(r"mem\[(\d+)\] = (\d+)")


def all_values(v, mask_1, mask_float):
    """Return every address produced by applying a floating mask to *v*.

    Each bit set in *mask_float* takes both the value 0 and the value 1. All
    other bits of *v* are kept, and *mask_1* is OR-ed into every result.

    Args:
        v: The base address.
        mask_1: Bits forced to 1 in every result.
        mask_float: Bits that are expanded to both 0 and 1.

    Returns:
        A list of ``2 ** popcount(mask_float)`` ints. It is ordered from the
        combination with all floating bits set down to the combination with
        all floating bits cleared.

    Raises:
        ValueError: If *mask_float* is negative. A negative Python int has
            infinitely many set bits, so there is no finite expansion.
    """
    if mask_float < 0:
        raise ValueError("mask_float must be non-negative")
    # Clear the floating bits of v and apply the forced-1 bits once. Every
    # result is this base with some subset of the floating bits switched on.
    base = (v & ~mask_float) | mask_1
    result = []
    sub = mask_float
    # Submask enumeration: (sub - 1) & mask_float visits every subset of
    # mask_float's bits in strictly decreasing numeric order, ending at 0.
    while True:
        result.append(base | sub)
        if sub == 0:
            break
        sub = (sub - 1) & mask_float
    return result


def run(lines):
    """Execute the mask/mem program in *lines* and return the memory sum.

    Args:
        lines: An iterable of text lines, such as an open file.

    Returns:
        The sum of all values left in memory after every write.

    Raises:
        RuntimeError: If a line is not recognised, or if a ``mem`` line
            appears before any ``mask`` line.
    """
    mem = {}
    mask_1 = mask_float = None
    for ln in lines:
        if not ln.strip():
            continue  # tolerate blank lines (e.g. a stray empty line)
        if m := MASK_RE.match(ln):
            mask = m.group(1)
            mask_1 = int(mask.replace("X", "0"), base=2)
            mask_float = int(mask.replace("1", "0").replace("X", "1"), base=2)
        elif m := MEM_RE.match(ln):
            if mask_1 is None:
                raise RuntimeError(f"mem write before any mask: {ln!r}")
            address = int(m.group(1))
            value = int(m.group(2))
            for addr in all_values(address, mask_1, mask_float):
                mem[addr] = value
        else:
            raise RuntimeError(ln)
    return sum(mem.values())


def main(path="input"):
    """Run the program stored in *path* and print the resulting memory sum."""
    with open(path, encoding="utf-8") as fh:
        print(run(fh))


if __name__ == "__main__":
    main()