"""Evaluate a single-neuron threshold model for the 2-out-of-6 function."""

import itertools
import sys

import torch
from safetensors.torch import load_file

# Number of binary inputs the threshold neuron takes.
NUM_INPUTS = 6


def load_model(path='model.safetensors'):
    """Load the tensors stored in a safetensors file.

    Args:
        path: Path to the ``.safetensors`` file.

    Returns:
        The dict of name -> ``torch.Tensor`` returned by
        ``safetensors.torch.load_file``.
    """
    return load_file(path)


def at_least_2_of_6(bits, weights):
    """Return 1 if at least 2 of 6 inputs are high, else 0.

    The decision is made by a single linear threshold unit:
    ``weight @ bits + bias >= 0``.

    Args:
        bits: Iterable of 6 values convertible to ``float`` (typically 0/1).
        weights: Mapping with ``'neuron.weight'`` and ``'neuron.bias'``
            tensors. The weight is presumably shaped ``(1, 6)`` or ``(6,)``
            and the bias ``(1,)`` or scalar -- TODO confirm against the model.

    Returns:
        ``1`` if the neuron's activation is >= 0, otherwise ``0``.

    Raises:
        ValueError: If ``bits`` does not contain exactly 6 values.
    """
    # Materialize first so generators are still accepted, as before.
    values = [float(b) for b in bits]
    if len(values) != NUM_INPUTS:
        raise ValueError(f'expected {NUM_INPUTS} bits, got {len(values)}')
    weight = weights['neuron.weight']
    bias = weights['neuron.bias']
    # Build the input in the stored dtype/device: a hard-coded float32 input
    # makes the matmul fail for fp16/bf16/fp64 checkpoints.
    inputs = torch.tensor(values, dtype=weight.dtype, device=weight.device)
    # weight @ inputs equals inputs @ weight.T for a 2-D weight, and also
    # works for a 1-D weight without relying on the deprecated 1-D ``.T``.
    activation = weight @ inputs + bias
    return int((activation >= 0).item())


def main():
    """Print the 2-out-of-6 truth table by Hamming weight and verify it.

    The table shows one representative input per Hamming weight. Because a
    neuron with non-uniform weights could pass that while misclassifying other
    patterns, all 2**6 inputs are also checked exhaustively.

    Returns:
        ``0`` if every check passes, ``1`` otherwise (used as exit status).
    """
    w = load_model()
    print('2-out-of-6 truth table by Hamming weight:')
    all_ok = True
    for hw in range(NUM_INPUTS + 1):
        # Representative input for this Hamming weight: hw leading ones.
        bits = [1 if i < hw else 0 for i in range(NUM_INPUTS)]
        result = at_least_2_of_6(bits, w)
        expected = 1 if hw >= 2 else 0
        ok = result == expected
        all_ok = all_ok and ok
        status = 'OK' if ok else 'FAIL'
        print(f' HW={hw}: {result} (expected {expected}) {status}')

    mismatches = sum(
        at_least_2_of_6(bits, w) != (1 if sum(bits) >= 2 else 0)
        for bits in itertools.product((0, 1), repeat=NUM_INPUTS)
    )
    total = 2 ** NUM_INPUTS
    exhaustive_status = 'OK' if mismatches == 0 else f'FAIL ({mismatches} mismatches)'
    print(f'Exhaustive check over all {total} inputs: {exhaustive_status}')
    all_ok = all_ok and mismatches == 0
    return 0 if all_ok else 1


if __name__ == '__main__':
    sys.exit(main())