The SAE could learn to represent the true features, A & B, as in the left graph, so the orthogonal regularizer would help. When you say the SAE would learn inhibitory weights*, I’m imagining the graph on the right; however, these features are mostly orthogonal to each other, meaning the proposed solution won’t work AFAIK.
(Also, would the regularizer be abs(cos_sim(x, x’))?)
*In this example, this is because the encoder would need inhibitory weights, e.g. to prevent hidden neuron 1 (A^~B) from activating when both true features A & B are present, as we will discuss shortly.
yeah I was thinking abs(cos_sim(x,x’))
I’m not sure what you’re getting at regarding the inhibitory weights, as the image link is broken.
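(As a minimal sketch of the notation, assuming x and x’ are two feature directions; the helper name is just for illustration:)
import torch

def abs_cos_sim(x: torch.Tensor, x_prime: torch.Tensor) -> torch.Tensor:
    # absolute cosine similarity between two feature directions
    return torch.abs(torch.dot(x, x_prime)) / (x.norm() * x_prime.norm())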
Thanks for letting me know the link is broken!
If the True Features are located at:
A: (0,1)
B: (1,0)
[So A^B: (1,1)]
Given 3 SAE hidden dimensions, a ReLU & a bias, the model could learn 3 sparse features:
1. A^~B: (-1, 1)
2. A^B: (1, 1)
3. ~A^B: (1, -1)
that output 1-hot vectors for each feature. These are also orthogonal to each other.
Concretely:
import torch
# encoder weights: one row per hidden feature (A^~B, A^B, ~A^B)
W = torch.tensor([[-1., 1.], [1., 1.], [1., -1.]])
# inputs: the true features A, A^B, B
x = torch.tensor([[0., 1.], [1., 1.], [1., 0.]])
# the bias keeps the A^B unit silent unless both features are present
b = torch.tensor([0., -1., 0.])
# y is the 3x3 identity: each input activates exactly one hidden feature
y = torch.nn.functional.relu(x @ W.T + b)
Thanks for clarifying! Indeed, the encoder weights here would be orthogonal. But I’m suggesting applying the orthogonality regularisation to the decoder weights, which would not be orthogonal in this case.
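(To illustrate that point with the example above: for the three hidden units to reconstruct the inputs from their one-hot codes, the decoder directions would need to be roughly A, A^B, and B themselves, which are not mutually orthogonal. A minimal sketch, assuming the penalty sums the off-diagonal |cos_sim| of the decoder matrix:)
import torch
# decoder directions needed to reconstruct A=(0,1), A^B=(1,1), B=(1,0)
W_dec = torch.tensor([[0., 1.], [1., 1.], [1., 0.]])
W_unit = torch.nn.functional.normalize(W_dec, dim=1)
cos = W_unit @ W_unit.T                     # pairwise cosine similarities
penalty = cos.abs().triu(diagonal=1).sum()  # ≈ 1.41 here: A^B overlaps both A and B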
Ah, you’re correct. Thanks!
I’m now very interested in implementing this method.
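(For what it’s worth, a rough sketch of how the decoder-weight penalty might slot into an SAE training loss; the coefficient names and the sae.W_dec attribute are hypothetical, not from any particular codebase:)
import torch
import torch.nn.functional as F

def decoder_orthogonality_penalty(W_dec: torch.Tensor) -> torch.Tensor:
    # mean |cos_sim| over distinct pairs of decoder directions (rows of W_dec)
    W = F.normalize(W_dec, dim=1)
    cos = W @ W.T
    n = W.shape[0]
    return (cos.abs().sum() - n) / (n * (n - 1))

# hypothetical training objective:
# loss = recon_loss + l1_coeff * l1_loss + ortho_coeff * decoder_orthogonality_penalty(sae.W_dec)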