I'm getting close indeed! I did take a big detour into tropical geometry... The key approach at the moment is deriving a sequence D[t] recording which ReLUs are active at each step of the sequence; this turns out to be quite stable after an impulse. We can then mask W_hh with D[t] and compose all the (now-linear) steps into an effective linear operator for the whole sequence, which we can investigate with standard linear methods.
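A minimal sketch of that masking idea, assuming a bias-free ReLU RNN with scalar inputs and a zero initial state. The function names, shapes and random weights are made up for illustration and are not taken from the actual model:

```python
import numpy as np

def run_and_record(xs, W_hh, W_ih):
    """Run a bias-free ReLU RNN from h = 0; return the final state and the masks D[t]."""
    h = np.zeros(W_hh.shape[0])
    D = []
    for x in xs:
        pre = W_hh @ h + W_ih * x   # scalar input x, input weights W_ih of shape (n,)
        D.append(pre > 0)           # which ReLUs are active at this step
        h = np.maximum(0.0, pre)
    return h, np.array(D)

def effective_operator(D, W_hh, W_ih):
    """Return V of shape (n, T) such that h_T = V @ xs while the masks D stay fixed."""
    T, n = D.shape
    V = np.zeros((n, T))
    P = np.eye(n)                        # product of masked W_hh for the steps after t
    for t in range(T - 1, -1, -1):
        V[:, t] = P @ (D[t] * W_ih)      # how an input at step t reaches the final state
        P = P @ (D[t][:, None] * W_hh)   # prepend step t: diag(D[t]) @ W_hh
    return V

# Hypothetical random weights and inputs, only to exercise the functions.
rng = np.random.default_rng(0)
n, T = 6, 10
W_hh = rng.normal(scale=0.5, size=(n, n))
W_ih = rng.normal(size=n)
xs = rng.uniform(size=T)

h_T, D = run_and_record(xs, W_hh, W_ih)
V = effective_operator(D, W_hh, W_ih)
print(np.allclose(V @ xs, h_T))
```

Here V is linear in the inputs only for as long as the recorded masks stay the same, which is why the stability of D[t] after an impulse matters. Once V is in hand, things like its singular values or the eigenvalues of the masked W_hh products can be inspected directly.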
The readout mechanism for S (2nd max) in the presence of M (max) combines two computations in a shared low-dimensional subspace.
The hidden state follows a spiral trajectory through time, implemented by a rotating phase in the hidden state. The W_out projection converts phase angle to position logits. The main spiral shape does not differ between forward (M first) and reverse (S first) cases.
The network must discriminate between the very similar forward and reverse cases. The final hidden states differ by an offset:
h_forward =...
I likewise got nerd-sniped into taking this one on! It's been good fun to work on.
My current description of the circuit behaviour is pretty lengthy and has a fair amount of hand-waving, so I need to work towards a more compact description of what is going on.
Some notes:
Zeroing out all the inputs except the largest two gets the network to 100% and makes it a lot easier to see the behaviour of some of the oscillatory sub-circuits.
Zeroing out everything except the max helps by showing the impulse-response behaviour.
Almost all ablations hurt t...
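The input-zeroing trick from the notes above can be sketched roughly like this, assuming each input sequence is a 1-D array of values. The helper name keep_top_k is hypothetical:

```python
import numpy as np

def keep_top_k(xs, k=2):
    """Return a copy of xs with everything except the k largest values set to zero."""
    xs = np.asarray(xs, dtype=float)
    out = np.zeros_like(xs)
    start = max(len(xs) - k, 0)
    idx = np.argsort(xs)[start:]   # positions of the k largest entries
    out[idx] = xs[idx]
    return out

seq = [0.1, 0.9, 0.3, 0.7]
top_two = keep_top_k(seq, k=2)     # keeps only the max and the second max
impulse = keep_top_k(seq, k=1)     # keeps only the max: an impulse-response probe
```

Feeding top_two or impulse to the network instead of seq gives the two simplified settings described in the notes.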
Now, I'm not sure I've exactly followed the brief, but I think there is some interesting stuff here: https://gist.github.com/mrsirrisrm/d6850ff8647d1ed2f67cc92d5bce3ed0
If we focus on the compute_final_state function:
With known D sequences, the RNN dynamics are piecewise-linear.
The final state is computed as:
pre = first_val * w_first[gap] + second_val * W_ih
h_second = D_second_mask * max(0, pre)
h_final = Phi_post[(gap,dir)][steps_after] ...
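As a rough NumPy rendering of the first two lines of that pseudocode only: the container types and shapes below are assumptions (w_first as a lookup from gap to a vector, D_second_mask as a 0/1 vector), and the toy values are invented. The Phi_post step is left out because it is not fully shown here.

```python
import numpy as np

def state_at_second_impulse(first_val, second_val, gap, w_first, W_ih, D_second_mask):
    """Hidden state right after the second impulse, for a known activation mask."""
    # Assumed: w_first[gap] holds the first impulse carried forward by `gap` steps.
    pre = first_val * w_first[gap] + second_val * W_ih
    return D_second_mask * np.maximum(0.0, pre)

# Toy values, purely illustrative.
w_first = {1: np.array([1.0, -1.0])}
W_ih = np.array([0.5, 0.5])
h_second = state_at_second_impulse(1.0, 1.0, 1, w_first, W_ih, np.array([1.0, 1.0]))
```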