Hmm, I think I don’t fully understand your post. Let me summarize what I get, and what is confusing me:
- I absolutely get the “there are different levels / scales of explaining a network” point.
- It also makes sense to tie this to some level of loss. E.g. explain GPT-2 to a loss level of L=3.0 (rather than L=2.9), or explain IOI with 95% accuracy.
- I’m also a fan of expressing losses in terms of compute or model size (“SAE on Claude 5 recovers Claude 2-levels of performance”).
- I’m confused whether your post tries to tell us (how to determine) what loss our interpretation should recover, or whether you’re describing how to measure whether our interpretation recovers that loss (via constructing the M_c models).
- You first introduce the SLT argument that tells us which loss scale to choose (the “Watanabe scale”, derived from the Watanabe critical temperature).
- And then a second (?) scale, the “natural” scale. That loss scale is the difference between the given model (Claude 2) and a hypothetical near-perfect model (Claude 5).
- I’m confused how these two scales interact—are these just 2 separate things you wanted to discuss, or is there a connection I’m missing?
- Regarding the natural scale specifically: if Claude 5 got a CE loss of 0.5 and Claude 2 got a CE loss of 3.5, are you saying we should explain only the parts/circuits of Claude 2 that are required to get a loss of 6.5 (“degrading a model by [...] its absolute loss gap”)? I spell out the arithmetic I have in mind just below.
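Concretely, the arithmetic I mean (with the made-up numbers above, so purely illustrative):

$$\Delta L = L_{\text{Claude 2}} - L_{\text{Claude 5}} = 3.5 - 0.5 = 3.0, \qquad L_{\text{degraded}} = L_{\text{Claude 2}} + \Delta L = 3.5 + 3.0 = 6.5.$$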
Then there’s the second part, where you discuss how to obtain a model M_c* corresponding to a desired loss L_c*. There are many ways to do this (trivially: just walk a straight line in parameter space until the loss reaches the desired level), but you suggest a specific one (Langevin SGD). You suggest that one because it produces a model implementing a “maximally general algorithm” [1] (with the desired loss, and in the same basin). This makes sense if I were trying to interpret / reverse engineer / decompose M_c*, but I’m running my interpretability technique on M_c, right? I believe I have missed why we bother with creating the intermediate M_c model. (I assume it’s not merely to find the equivalent parameter count / Claude generation.)
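Just to make sure I understand the proposal, here is the kind of procedure I picture for “Langevin SGD down to a target loss”. This is a toy sketch under my own assumptions: the function name `langevin_degrade`, the hyperparameters, and the stopping rule are all my guesses, not from your post.

```python
# Toy sketch of "degrade a model to a target loss with noisy (Langevin-style)
# gradient steps", as I understand the proposal. All names and hyperparameters
# here are illustrative guesses, not taken from the post.
import itertools

import torch
import torch.nn as nn


def langevin_degrade(model, data_loader, target_loss, lr=1e-4, temperature=1e-3, max_steps=10_000):
    """Take SGD steps with added Gaussian noise (temperature-controlled) until the
    minibatch loss has risen to roughly `target_loss`, hopefully staying in the same basin."""
    loss_fn = nn.CrossEntropyLoss()
    for step, (inputs, targets) in enumerate(itertools.cycle(data_loader)):
        loss = loss_fn(model(inputs), targets)
        model.zero_grad()
        loss.backward()
        with torch.no_grad():
            for p in model.parameters():
                if p.grad is None:
                    continue
                noise = torch.randn_like(p) * (2.0 * lr * temperature) ** 0.5
                p.add_(-lr * p.grad + noise)  # gradient step plus Langevin noise
        if loss.item() >= target_loss or step >= max_steps:
            break
    return model
```

The contrast I have in mind: the trivial alternative (interpolating along a straight line in parameter space until the loss hits L_c*) also reaches the target loss, but presumably without the “maximally general, same basin” property you want from the Langevin version.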
[1] Regarding the “maximally general” claim: You have made a good argument that there is a spectrum from generalization to memorization (e.g. knowing which city is where on the globe, or memorizing grammar rules, both seem kinda ambiguous). So “maximally general” seems not uniquely defined (e.g. a model that has some really general and some really memorized circuits, vs a model that has lots of middle-spectrum circuits).
Thanks for the questions!

> You first introduce the SLT argument that tells us which loss scale to choose (the “Watanabe scale”, derived from the Watanabe critical temperature).
Sorry, I think the context of the Watanabe scale is a bit confusing. I’m saying that in fact it’s the wrong scale to use as a “natural scale”. The Watanabe scale depends only on the number of training datapoints, and doesn’t notice any other properties of your NN or your phenomenon of interest.
Roughly, the Watanabe scale is the scale on which loss improves if you memorize a single extra datapoint: memorizing one datapoint improves accuracy by 1/n, with n = #(training set), and, in a suitable operationalization, improves loss by O(log n / n). That O(log n / n) is the Watanabe scale.
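To give a sense of the magnitudes involved (numbers purely illustrative):

$$\frac{\log n}{n}\bigg|_{n=10^9} \approx \frac{20.7}{10^9} \approx 2\times 10^{-8}, \qquad \frac{\log n}{n}\bigg|_{n=10^4} \approx \frac{9.2}{10^4} \approx 10^{-3}.$$

Even for a small synthetic training set, this is far finer than the loss gaps that typical interp reconstructions leave unexplained.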
It’s used in SLT roughly because it’s the minimal temperature scale at which “memorization doesn’t count as relevant”, and so relevant measurements become independent of the n-point sample. However, in most interp experiments the realistic loss reconstruction is much rougher (i.e., further from optimal loss) than the 1/n scale where memorization becomes an issue (even if you conceptualize #(training set) as some small synthetic training set that you were running the experiment on).
For your second question: again, what I wrote is confusing and I really want to rewrite it more clearly later. I tried to clarify what I think you’re asking about in this shortform. Roughly, the point is that to avoid having your results messed up by spurious behaviors, you want to degrade as much as possible while still observing the effect of your experiment. The idea is that if you find some degradation that wasn’t explicitly designed with your experiment in mind (i.e., is natural), but under which your experimental results still hold, then you have “found a phenomenon”. The hope is that if you look at the roughest such scale, you might kill enough confounders and interactions to make your result “clean” (or at least cleaner): optimistically, for example, you might hope to explain all the loss of the degraded model at the degradation scale you chose (whereas at other scales there are a bunch of other effects, not captured by your explanation, improving the loss on the dataset you’re looking at).
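In procedural terms, I’m imagining something like the following sweep. This is only a sketch under my own framing: `degrade`, `effect_size`, and the choice of scales are hypothetical placeholders for whatever degradation procedure and measurement a particular experiment uses.

```python
# Sketch: among increasingly rough degradations of a model, keep the roughest
# one at which the experimental effect of interest still clearly shows up.
# `degrade` and `effect_size` are hypothetical stand-ins, not a real API.

def roughest_scale_where_effect_holds(model, degrade, effect_size, scales, threshold):
    """Return (scale, degraded_model) for the roughest degradation scale at which
    the measured effect still exceeds `threshold`, or None if it never does."""
    best = None
    for scale in sorted(scales):  # sweep from finest to roughest degradation
        degraded = degrade(model, scale)
        if effect_size(degraded) >= threshold:
            best = (scale, degraded)  # effect survives even at this rougher scale
    return best
```

Scanning all scales rather than stopping at the first failure is just a choice in this sketch, to allow for effects that disappear and reappear non-monotonically.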
The question now is, when degrading, in what order you want to “kill confounders” so as to optimally purify the effect you’re considering. The “natural degradation” idea seems like a good place to look since it kills the “small but annoying” confounders: things like memorization, weird specific connotations of the test sentences you used for your experiment, etc. Another reasonable place to look is training checkpoints, as these correspond to killing “hard to learn” effects. Ideally you’d perform several kinds of degradation to “maximally purify” your effect. Here the “natural scales” (loss on the level of Claude 1, say, or BERT) are much too fine for most modern experiments, and I’m envisioning something much rougher.
The intuition here comes from physics. Like if you want to study properties of a hydrogen atom that you don’t see either in water or in hydrogen gas, a natural thing to do is to heat up hydrogen gas to extreme temperatures where the molecules degrade but the atoms are still present, now in “pure” form. Of course not all phenomena can be purified in this way (some are confounded by effects both at higher and at lower temperature, etc.).