I found the Day 17 problem from Advent of Code quite interesting.
Part 1
Part 1 required a simple processor implementation. I modelled it as a Mealy machine in Haskell.
module Main where
import Data.Bits
import System.Environment
-- | Value held by one of the machine's registers.
type Register = Integer

-- | Instruction pointer: an index into the flat program list.
type IP = Integer

-- | Machine state: registers A, B and C followed by the instruction pointer.
type State = (Register, Register, Register, IP)

-- | The eight instructions of the machine, listed in the order of their
-- numeric opcodes 0 through 7 (see 'toOpCode').
data OpCode
  = ADV
  | BXL
  | BST
  | JNZ
  | BXC
  | OUT
  | BDV
  | CDV
  deriving Show

-- | Raw operand that follows an opcode in the program.
type Operand = Integer

-- | An undecoded instruction: numeric opcode paired with its operand.
type Instr = (Integer, Operand)

-- | Values emitted by a single machine step.
type Output = [Integer]
-- | Decode a numeric opcode (0 through 7) into its 'OpCode'.
-- Any other number aborts with @error "impossible"@.
toOpCode :: Integer -> OpCode
toOpCode 0 = ADV
toOpCode 1 = BXL
toOpCode 2 = BST
toOpCode 3 = JNZ
toOpCode 4 = BXC
toOpCode 5 = OUT
toOpCode 6 = BDV
toOpCode 7 = CDV
toOpCode _ = error "impossible"
-- | Resolve a combo operand against the current registers.
-- Operands 4, 5 and 6 read registers A, B and C respectively; every other
-- operand stands for its own literal value.
combo :: Register -> Register -> Register -> Operand -> Operand
combo a b c x
  | x == 4 = a
  | x == 5 = b
  | x == 6 = c
  | otherwise = x
-- | Execute a single instruction: given the current state and one
-- (opcode, operand) pair, produce the next state and whatever was output.
--
-- * ADV, BDV and CDV store @A `div` 2 ^ combo@ into A, B and C respectively.
-- * BXL xors B with the literal operand; BXC xors B with C.
-- * BST stores the combo operand modulo 8 into B.
-- * JNZ jumps to the literal operand unless A is zero.
-- * OUT emits the combo operand modulo 8.
--
-- Every instruction except a taken JNZ advances the instruction pointer by 2.
mealyMachine :: (State, Instr) -> (State, Output)
mealyMachine (state, instr) =
  case toOpCode opcode of
    ADV -> silent (shifted, b, c, next)
    BXL -> silent (a, b `xor` operand, c, next)
    BST -> silent (a, comboValue `mod` 8, c, next)
    JNZ
      | a == 0 -> silent (a, b, c, next)
      | otherwise -> silent (a, b, c, operand)
    BXC -> silent (a, b `xor` c, c, next)
    OUT -> ((a, b, c, next), [comboValue `mod` 8])
    BDV -> silent (a, shifted, c, next)
    CDV -> silent (a, b, shifted, next)
  where
    (a, b, c, ip) = state
    (opcode, operand) = instr
    -- Address of the instruction that follows the current one.
    next = ip + 2
    -- The operand interpreted as a combo operand.
    comboValue = combo a b c operand
    -- Result shared by the three division instructions.
    shifted = a `div` (2 ^ comboValue)
    -- A transition that produces no output.
    silent st = (st, [])
-- | Narrow an 'Integer' to an 'Int', e.g. for use as a list index.
toInt :: Integer -> Int
toInt n = fromInteger n
-- | Run @program@ from the given state until the machine halts and return
-- the accumulated output: @outs@ followed by everything emitted along the way.
--
-- The machine halts as soon as a complete instruction (opcode AND operand)
-- can no longer be read at the instruction pointer. This also covers an
-- instruction pointer resting on the last element of the program, where the
-- opcode exists but the operand does not; indexing the operand there would
-- fail with an out-of-range error. A negative instruction pointer halts too.
run :: [Integer] -> State -> [Integer] -> [Integer]
run program state@(_, _, _, ip) outs
  | ip < 0 = outs
  | otherwise =
      case drop (toInt ip) program of
        opcode : operand : _ ->
          let (nextState, output) = mealyMachine (state, (opcode, operand))
           in run program nextState (outs ++ output)
        -- Fewer than two values left: nothing more to execute.
        _ -> outs
-- | The puzzle program as a flat list of (opcode, operand) pairs,
-- laid out one instruction per line.
program :: [Integer]
program =
  [ 2, 4 -- BST: B = A mod 8
  , 1, 1 -- BXL: B = B xor 1
  , 7, 5 -- CDV: C = A div 2^B
  , 1, 4 -- BXL: B = B xor 4
  , 0, 3 -- ADV: A = A div 2^3
  , 4, 5 -- BXC: B = B xor C
  , 5, 5 -- OUT: emit B mod 8
  , 3, 0 -- JNZ: back to the start unless A is 0
  ]
-- | Entry point: read the initial value of register A from the first
-- command-line argument, run 'program' with B = C = 0 and the instruction
-- pointer at 0, and print the output list.
--
-- A missing or non-numeric argument is reported with a usage message
-- instead of the bare exception that @head@ or @read@ would raise.
main :: IO ()
main = do
  args <- getArgs
  case args of
    -- Like 'read', accept the number only when nothing but whitespace follows it.
    (arg : _)
      | [(x, rest)] <- reads arg
      , [("", "")] <- lex rest ->
          print $ run program (x, 0, 0, 0) []
    _ -> ioError (userError "usage: main <initial value of register A>")
Part 2
Problem
Part 2 asks us to find a value for register A
such that running the program (Program: 2,4,1,1,7,5,1,4,0,3,4,5,5,5,3,0)
outputs exactly the program itself, much like a quine.
For example, running the following program with register A set to 117440 produces exactly 0,3,5,4,3,0 as its output:
Register A: 117440
Register B: 0
Register C: 0
Program: 0,3,5,4,3,0
Solution
Since the number in register A can be arbitrarily large, we can’t just brute-force the answer. We need some way to reduce the search space.
To get some intuition, I ran the processor multiple times with increasing value of ‘A’.
./main 0 = [5]
./main 50 = [1, 3]
./main 100 = [2, 1, 5]
./main 150 = [2, 5, 7]
./main 200 = [1, 5, 6]
./main 250 = [0, 2, 6]
./main 300 = [0, 2, 1]
./main 350 = [1, 4, 0]
./main 400 = [5, 1, 3]
./main 450 = [7, 1, 2]
./main 500 = [6, 3, 2]
./main 550 = [7, 3, 1, 5]
./main 600 = [1, 4, 5, 5]
./main 650 = [6, 5, 6, 5]
./main 700 = [4, 3, 6, 5]
./main 750 = [6, 5, 4, 5]
./main 800 = [5, 2, 1, 5]
./main 850 = [5, 2, 0, 5]
./main 900 = [5, 5, 3, 5]
./main 950 = [4, 3, 3, 5]
./main 1000 = [1, 7, 2, 5]
./main 1050 = [4, 6, 5, 7]
./main 1100 = [3, 5, 5, 7]
./main 1150 = [3, 0, 5, 7]
./main 1200 = [5, 2, 5, 7]
./main 1250 = [3, 5, 2, 7]
./main 1300 = [1, 3, 1, 7]
./main 1350 = [1, 1, 1, 7]
./main 1400 = [1, 0, 1, 7]
./main 1450 = [2, 3, 3, 7]
Here we can observe 2 main things:
- When running the processor with increasing values of A, the length of the output increases monotonically. So we can use binary search to find the window in which the output has the desired length.
- As we increase A, the values in the output change more frequently on the left than on the right, which means that to see a change in the rightmost output value we have to move much further along. So I start with a big step size and identify the ranges where the desired rightmost value appears. Within those windows, I reduce the step size by an order of magnitude and identify the ranges where the desired second-to-last value appears. I repeat this on and on, from right to left.
This way we dramatically reduce the search space to find the solution.
import subprocess
def run_machine(x):
    """Run ``./main`` with register A set to ``x`` and return its output values.

    ``./main`` prints a bracketed, comma-separated list such as ``[2,4,1]``.
    An empty list ``[]`` is returned as ``[]``.

    Raises:
        subprocess.CalledProcessError: if ``./main`` exits with a non-zero status.
        ValueError: if the printed text is not a bracketed list of integers.
    """
    # check=True: a failed run must not be mistaken for "no output".
    result = subprocess.run(['./main', f'{x}'], capture_output=True, text=True, check=True)
    text = result.stdout.strip()
    if not (text.startswith('[') and text.endswith(']')):
        raise ValueError(f"unexpected output from ./main: {text!r}")
    body = text[1:-1].strip()
    if not body:
        # ''.split(',') would yield [''] and int('') would raise.
        return []
    return [int(token) for token in body.split(',')]
def binary_search_length(target_len):
    """Return the smallest A in [1, 10**22] whose output has at least ``target_len`` values.

    Relies on the output length never decreasing as A grows. If even the
    upper bound produces a shorter output, the upper bound itself is returned.
    """
    lo = 1
    hi = 10 ** 22
    # The loop below assumes the output at `lo` is too short. Check that
    # first, otherwise a target already met at `lo` would be reported as 2.
    if len(run_machine(lo)) >= target_len:
        return lo
    # Invariant: output at `lo` is shorter than target_len; `hi` is the
    # smallest candidate known (or assumed) to be long enough.
    while lo + 1 < hi:
        md = (lo + hi) // 2
        if len(run_machine(md)) < target_len:
            lo = md
        else:
            hi = md
    return hi
# Target output: the program itself, laid out as (opcode, operand) pairs.
desired_output = [
    2, 4,
    1, 1,
    7, 5,
    1, 4,
    0, 3,
    4, 5,
    5, 5,
    3, 0,
]
def search_in(lo, hi, step, index):
    """Sample A over ``range(lo, hi, step)`` and return the windows where output[index] matches.

    Each window is a ``(start, end)`` pair bracketing a run of consecutive
    samples whose output at ``index`` equals ``desired_output[index]``:
    ``start`` is the last non-matching sample before the run (or ``lo``) and
    ``end`` is the first non-matching sample after it (or ``hi``).
    """
    print(lo, hi, step, index)
    windows = []
    inside = False
    last_miss = lo
    for i in range(lo, hi, step):
        if run_machine(i)[index] == desired_output[index]:
            inside = True
            continue
        if inside:
            windows.append((last_miss, i))
            inside = False
        # Remember every miss, including the one that closed a window.
        # Otherwise the next window would start at a stale bound and
        # overlap the previous one.
        last_miss = i
    if inside:
        windows.append((last_miss, hi))
    return windows
def solve():
    """Narrow down the ranges of A whose output can equal ``desired_output``.

    Starts from the range of A values that produce an output of the right
    length, then refines it output position by output position from the last
    index down to index 2, sampling ten times more finely at each position.

    Returns:
        list of ``(lo, hi)`` windows that still have to be scanned exhaustively.
    """
    lo = binary_search_length(len(desired_output))
    hi = binary_search_length(len(desired_output) + 1)
    step = 1_000_000_000_000
    possibles = [(lo, hi)]
    for i in reversed(range(2, len(desired_output))):
        print(i, possibles, step)
        new_possibles = []
        for (l, h) in possibles:
            new_possibles.extend(search_in(l, h, step, i))
        possibles = new_possibles
        # Never let the step reach 0: range() rejects a zero step.
        step = max(1, step // 10)
    print(i, possibles, step)
    return possibles
# Scan every remaining window exhaustively, in order, and stop at the first
# value of A whose output reproduces the program.
possibles = solve()
for (l, h) in possibles:
    print(l, h)
    for i in range(l, h):
        # Compare the whole lists: zip() stops at the shorter one and would
        # accept an output that is only a prefix of the program.
        if run_machine(i) == desired_output:
            print(f"answer found: {i}")
            # SystemExit works even where the site-provided exit() is unavailable.
            raise SystemExit(0)