We think this sort of approach can be applied layer-by-layer. As long as you know what the features are, you can calculate dL/dC_i for each feature and figure out what’s going on with that. The main challenge to this is feature identification: in a one-layer model with synthetic data it’s often easy to know what the features are. In more complicated settings it’s much less clear what the “right” or “natural” features are...
Right! Two quick ideas:
Although it’s not easy to determine the full set of “natural” features for arbitrary networks, you might still be able to solve an optimization problem that identifies the single feature with the most negative marginal returns to capacity, given the weights of a particular trained network. If you could do this, then perhaps you could apply a regularization to the network that “flattens out” the marginal returns curve for just that one feature, and then train the network further. You would then ask again which single feature has the most negative marginal returns to capacity given the updated weights, flatten out its curve in the same way, and repeat until no feature has negative marginal returns to capacity. Doing this feature-by-feature would be too slow for anything but toy networks, I suppose, but if it worked for toy networks then perhaps it would point the way towards something more scalable.
Suppose instead you can find the least important feature (the one with the lowest absolute value of dL/dC_i) given a particular set of weights for a network, mask that feature out from all the inputs, and then iterate in the same way as above. In the third figure from the top in your post (the one with the big vertical stack of marginal return curves) you would be chopping off the features one-by-one from bottom to top, ideally until you have exactly as many features as you can “fit” monosemantically into a particular architecture. I suppose again that doing this feature-by-feature for anything but a toy model would be prohibitive, but perhaps there is a way to do it more efficiently. I wonder whether there is any way to “find the least important feature” and to “mask it out”.
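To make the loop in the second idea concrete, here is a minimal sketch of its control flow. Everything named here is hypothetical: `score` stands in for whatever computes dL/dC_i per feature under the current weights, `retrain` stands in for further training with the masked features removed, and `capacity_budget` is an assumed count of how many features the architecture can fit monosemantically. The first idea would have the same skeleton, with "most negative marginal return" as the selection rule and a regularizer in place of the mask.

```python
from typing import Callable, Dict, List, Set


def prune_least_important(
    n_features: int,
    capacity_budget: int,
    score: Callable[[Set[int]], Dict[int, float]],
    retrain: Callable[[Set[int]], None],
) -> List[int]:
    """Mask features one at a time, lowest |dL/dC_i| first.

    Returns the order in which features were masked out.
    """
    active = set(range(n_features))
    removed: List[int] = []
    while len(active) > capacity_budget:
        # Hypothetical: {feature index: dL/dC_i} under the current weights.
        scores = score(active)
        weakest = min(active, key=lambda i: abs(scores[i]))
        active.remove(weakest)
        removed.append(weakest)
        # Hypothetical: continue training with `weakest` masked from the inputs.
        retrain(active)
    return removed


# Toy stand-ins: fixed scores that never change, and a no-op retrain.
fixed_scores = {0: -0.9, 1: 0.05, 2: -0.4, 3: 0.2}
order = prune_least_important(
    n_features=4,
    capacity_budget=2,
    score=lambda active: {i: fixed_scores[i] for i in active},
    retrain=lambda active: None,
)
print(order)
```

In a real run the scores would shift after every retraining step, which is why they are recomputed inside the loop rather than sorted once up front.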
In both ideas I’m not sure how you’re identifying features. Manual interpretability work on a (more complicated) toy model?
You write down an optimization problem over (say) linear combinations of image pixels, minimizing some measure of marginal returns to capacity given current network parameters (first idea) or overall importance as measured by absolute value of dL/dC_i, again given current network parameters (second idea). By looking just for the feature that is currently “most problematic” you may be able to sidestep the need to identify the full set of “features” (whatever that really means).
I don’t know exactly how you would formulate these objective functions, but it seems doable, no?
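One possible shape for such a search, as a rough sketch only: treat a candidate feature as a unit-norm vector of coefficients over the inputs, and run projected gradient descent on some scalar objective. The objective below is a made-up quadratic stand-in, not a real measure of marginal returns to capacity; `find_extreme_direction`, `toy_score`, and all parameter values are assumptions for illustration.

```python
import math
import random
from typing import Callable, List

Vector = List[float]


def normalize(v: Vector) -> Vector:
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def find_extreme_direction(
    score: Callable[[Vector], float],
    dim: int,
    steps: int = 500,
    lr: float = 0.1,
    eps: float = 1e-5,
    seed: int = 0,
) -> Vector:
    """Minimize `score` over unit-norm directions.

    Uses forward finite-difference gradients, then projects back onto
    the unit sphere after every step.
    """
    rng = random.Random(seed)
    v = normalize([rng.gauss(0.0, 1.0) for _ in range(dim)])
    for _ in range(steps):
        base = score(v)
        grad = []
        for k in range(dim):
            bumped = list(v)
            bumped[k] += eps
            grad.append((score(bumped) - base) / eps)
        v = normalize([x - lr * g for x, g in zip(v, grad)])
    return v


# Stand-in objective: 3*v0^2 + 1*v1^2 + 2*v2^2, smallest along axis 1.
def toy_score(v: Vector) -> float:
    return 3.0 * v[0] ** 2 + 1.0 * v[1] ** 2 + 2.0 * v[2] ** 2


best = find_extreme_direction(toy_score, dim=3)
print(best, toy_score(best))
```

For a real network the score would have to be computed from the trained weights, and automatic differentiation would replace the finite differences; the hard part, as noted, is writing down that objective.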
Oh I see! Sorry I didn’t realize you were describing a process for picking features.
I think this is a good idea to try, though I do have a concern. My worry is that if you do this on a model where you know what the features actually are, the procedure discovers some heavily polysemantic “feature” that makes better use of capacity than any of the actual features in the problem. Because dL/dC_i is not a linear function of the feature’s embedding vector, there can exist superpositions of features that have greater dL/dC_i than any individual feature.
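A tiny illustration of that concern, using a stand-in rather than the actual dL/dC_i (whose form depends on the model and loss): take two "true" features along the coordinate axes and a made-up nonlinear score with a cross term. Sweeping over unit-norm mixtures of the two shows the score peaking at an even superposition, not at either pure feature. The function `stand_in_score` and the coupling value 0.8 are assumptions chosen only to show the effect.

```python
import math


def stand_in_score(a: float, b: float) -> float:
    """Made-up nonlinear score of a direction (a, b); NOT the real dL/dC_i."""
    return a * a + b * b + 2 * 0.8 * a * b


# The two "true" features are the coordinate axes.
pure_scores = [stand_in_score(1.0, 0.0), stand_in_score(0.0, 1.0)]

# Sweep unit-norm mixtures (cos t, sin t) between the two features.
angles = [k * math.pi / 200 for k in range(101)]  # 0 to pi/2
best_theta = max(angles, key=lambda t: stand_in_score(math.cos(t), math.sin(t)))
best_score = stand_in_score(math.cos(best_theta), math.sin(best_theta))

print("pure features:", pure_scores)
print("best mixture angle:", best_theta, "score:", best_score)
```

If the score were linear in the direction, a mixture's score would just be a weighted combination of the pure scores; the cross term is what lets a polysemantic direction come out ahead.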
Anyway, I think this is a good thing to try and encourage someone to do so! I’m happy to offer guidance/feedback/chat with people interested in pursuing this, as automated feature identification seems like a really useful thing to have even if it turns out to be really expensive.