Anthropic recently published the paper Studying Large Language Model Generalization with Influence Functions, which describes a scalable technique for measuring which training examples were most influential for a particular set of weights/outputs of a trained model. This can help us better understand model generalization, offering insights into the emergent properties of AI systems. For instance, influence functions could help us answer questions like “is the model asking not to be shut down because it has generalized that this is a generically good strategy for pursuing some goal, or simply because texts where AIs ask not to be shut down are commonly found in the training corpus?”.
In this post, I aim to summarize the approximations used in the paper to calculate the influence of different training examples and outline how the approximations can be implemented in PyTorch to form the basis of further research on influence functions by the AI safety community.
(Note: most formulae are copied or adapted from the original paper, with a few additional derivation steps / simplified notation used for clarity in some places.)
Deriving the exact form of the influence function
Before we go into approximations, it is necessary to understand what specifically we are trying to measure.
Given some element $z_m=(x_m,y_m)$ of a dataset $D=\{z_i\}_{i=1}^N$, we define the response function as the optimal solution $\theta^*$ (the weights that minimize the expected loss $L$) as a function of the weighting $\epsilon$ of this example:

$$\theta^*(\epsilon)=\underset{\theta\in\mathbb{R}^D}{\arg\min}\;\frac{1}{N}\sum_i L(z_i,\theta)+\epsilon L(z_m,\theta)$$

We define the influence $\mathcal{I}_{\theta^*}(z_m)$ of $z_m$ on $\theta^*$ using the first-order Taylor approximation to the response function at $\epsilon=0$:

$$\Delta\theta=\theta^*(\epsilon)-\theta^*(0)\approx\left.\frac{\partial\theta^*(\epsilon)}{\partial\epsilon}\right|_{\epsilon=0}\cdot\epsilon=\mathcal{I}_{\theta^*}(z_m)\cdot\epsilon$$

We can get $\frac{\partial\theta^*(\epsilon)}{\partial\epsilon}$ the following way:

We know $\theta^*$ is a minimum of $\frac{1}{N}\sum_i L(z_i,\theta)+\epsilon L(z_m,\theta)$, and so the gradient wrt $\theta$ is zero at that point:

$$\nabla_\theta\left(\frac{1}{N}\sum_i L(z_i,\theta^*)+\epsilon L(z_m,\theta^*)\right)=0$$
Differentiating each side wrt $\epsilon$ (the LHS depends on $\epsilon$ both directly and indirectly via $\theta^*$, so we use the Implicit Function Theorem: $u(x)=f(x,g(x))=0\;\Rightarrow\;\frac{du}{dx}=\frac{\partial f}{\partial g}\frac{\partial g}{\partial x}+\frac{\partial f}{\partial x}$):

$$\nabla^2_\theta\left(\frac{1}{N}\sum_i L(z_i,\theta^*)+\epsilon L(z_m,\theta^*)\right)\cdot\frac{\partial\theta^*}{\partial\epsilon}+\frac{\partial}{\partial\epsilon}\nabla_\theta\left(\frac{1}{N}\sum_i L(z_i,\theta^*)+\epsilon L(z_m,\theta^*)\right)=0$$

The second term can be simplified:

$$\frac{\partial}{\partial\epsilon}\nabla_\theta\left(\frac{1}{N}\sum_i L(z_i,\theta^*)+\epsilon L(z_m,\theta^*)\right)=\nabla_\theta L(z_m,\theta^*)$$

And so we can rearrange to get an expression for $\frac{\partial\theta^*}{\partial\epsilon}$:

$$\frac{\partial\theta^*}{\partial\epsilon}=-\nabla^2_\theta\left(\frac{1}{N}\sum_i L(z_i,\theta^*)+\epsilon L(z_m,\theta^*)\right)^{-1}\nabla_\theta L(z_m,\theta^*)$$
This tells us how the optimal parameters change with a perturbation $\epsilon$ to the weighting of an added data point $z_m$. The change is proportional to the negative product of the inverse Hessian of the loss on all the data and the gradient of the loss on the data point in question with respect to the model parameters (both evaluated at the optimal parameters).

For simplicity, as in the paper, we'll denote $\nabla^2_\theta\left(\frac{1}{N}\sum_i L(z_i,\theta^*)+\epsilon L(z_m,\theta^*)\right)$ as $H$.

Therefore, $\mathcal{I}_{\theta^*}(z_m)\approx-H^{-1}\nabla_\theta L(z_m,\theta^*)$ (this corresponds to Equation 3 in the paper).
Influence on some function of the model weights
So far, we have derived an expression for the influence of an added data point on the parameters $\theta$. However, we are more interested in the influence of particular data points on some measurable properties of the model, such as the output logits or validation loss. We can see this as some function $f(\theta^*)$ of the trained parameters.

By the chain rule, $\frac{\partial f(\theta^*(\epsilon))}{\partial\epsilon}=\nabla_\theta f(\theta^*)\frac{\partial\theta^*(\epsilon)}{\partial\epsilon}$, and so

$$\mathcal{I}_f(z_m)\approx-\nabla_\theta f(\theta^*)^T H^{-1}\nabla_\theta L(z_m,\theta^*)$$

(This corresponds to Equation 5 in the paper.)
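To make this concrete, here is a minimal sketch of computing Equation 5 exactly for a toy model small enough that the full Hessian can be formed and inverted directly (the setup and all names are illustrative and mine, not the paper's code; it also assumes $H$ is invertible and the model is trained to convergence, which the next section questions):

```python
import torch

# Toy setup: a tiny linear model with squared-error loss, small enough to
# form the exact Hessian. Purely illustrative; not the paper's method at scale.
torch.manual_seed(0)
N, D = 100, 5
X = torch.randn(N, D)
y = X @ torch.randn(D) + 0.1 * torch.randn(N)

# Pretend these are the trained parameters theta* (training loop omitted).
theta = torch.randn(D, requires_grad=True)

def mean_loss(t):
    return ((X @ t - y) ** 2).mean()

# H: Hessian of the mean training loss wrt the parameters.
H = torch.autograd.functional.hessian(mean_loss, theta.detach())

# Gradient of the loss on the training point z_m whose influence we want.
loss_zm = (X[0] @ theta - y[0]) ** 2
grad_zm = torch.autograd.grad(loss_zm, theta)[0]

# f(theta): here, the loss on a held-out "query" point.
x_q, y_q = torch.randn(D), torch.tensor(0.5)
grad_f = torch.autograd.grad((x_q @ theta - y_q) ** 2, theta)[0]

# I_f(z_m) ~ -grad_f^T H^{-1} grad_zm
influence = -grad_f @ torch.linalg.solve(H, grad_zm)
print(influence.item())
```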
Problems with this expression
The Hessian could have zero eigenvalues and therefore not be invertible (the optimal parameters can be underspecified by the loss function in the case of overparameterized models)
We often don't train to convergence, so the first derivative of the loss wrt the parameters is not necessarily zero at the final parameters, as previously assumed
The paper mentions that because of these problems, “past works have found influence functions to be inaccurate for modern neural networks.”
How do we fix this?
One approach is to define a new objective that:
Has a single defined optimum in parameter space
Is fully optimized when the model stops training
This is what the proximal Bregman objective (PBO) attempts to define:

$$\theta^s(\epsilon)=\underset{\theta\in\mathbb{R}^D}{\arg\min}\;\frac{1}{N}\sum_i\Big(L(z_i,\theta)-L(z_i,\theta^s)-\nabla_{y_i}L(z_i,\theta^s)^T(y_i-y_i^s)\Big)+\epsilon L(z_m,\theta)+\frac{\lambda}{2}\lVert\theta-\theta^s\rVert^2$$

($y_i$ here is the output of the model at the parameters $\theta$ on input $x_i$, and $y_i^s$ is the output of the model at the parameters $\theta^s$ on input $x_i$.)

The PBO essentially introduces a penalty for diverging too far from the current (partially trained) parameters $\theta^s$, so there is a well-defined optimum that balances staying close to $\theta^s$ against achieving good loss.

So we can redefine the gradients used in $\mathcal{I}_f$ in terms of this new objective, which considers both the loss given a new training data point and the divergence from the current parameters.
From Bae et al.'s 2022 paper If Influence Functions are the Answer, Then What is the Question?:

...while influence functions for neural networks are often a poor match to LOO [Leave One Out] retraining, they are a much better match to what we term the proximal Bregman response function (PBRF). Intuitively, the PBRF approximates the effect of removing a data point while trying to keep the predictions consistent with those of the (partially) trained model.
...
In addition, although the PBRF may not necessarily align with LOO retraining due to the warm-start[1], proximity, and non-convergence gaps, the motivating use cases for influence functions typically do not rely on exact LOO retraining. This means that the PBRF can be used in place of LOO retraining for many tasks such as identifying influential or mislabelled examples.
Applying the Implicit Function Theorem to the PBO, we can obtain an influence function with respect to the PBO objective (Equation 9 in the paper):
$$\mathcal{I}_{\theta^s}(z_m)=\frac{d\theta^s}{d\epsilon}=-(G+\lambda I)^{-1}\nabla_\theta L(z_m,\theta^s)$$

$$\mathcal{I}_{f,\theta^s}(z_m)=-\nabla_\theta f(\theta^s)^T(G+\lambda I)^{-1}\nabla_\theta L(z_m,\theta^s)$$

where $G$ is the Gauss-Newton Hessian, $G=\mathbb{E}[J^T H_y J]$. Here $J$ is the Jacobian (the first derivative of the network's outputs with respect to the parameters) and $H_y$ is the Hessian of the loss with respect to the network's outputs.
Efficient calculation
So, we want to compute $-\nabla_\theta f(\theta^s)^T(G+\lambda I)^{-1}\nabla_\theta L(z_m,\theta^s)$.
Let’s assume we have the following:
A trained network with parameters θs
An observable property of the network, f(θs), for instance, its output logits ytarget for some chosen input xtarget
The training dataset D
The loss function L the model was trained on
The key ingredients needed to calculate If are:
1. The gradient of the property $f$ of interest with respect to the parameters $\theta$, evaluated at $\theta^s$
2. A way of getting the inverse damped Gauss-Newton Hessian vector product
3. The gradient of the loss $L$ on the training data points (which we want to calculate the influence of) with respect to the parameters $\theta$, evaluated at $\theta^s$
Notice that only key ingredient 1 depends on the property of interest. We can pre-compute ingredients 2 and 3 and then use this to test a bunch of different properties (for example, find the most influential training examples for a bunch of different model input-output pairs).
We can also calculate the influence as a batched operation over many training data points (batching $\nabla_\theta L(z_m,\theta^s)$ over multiple $z_m$'s) to increase efficiency via vectorization.
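For illustration, here is a hedged sketch of how the final scores can be assembled once the pieces exist; `inverse_damped_gnh_vp` is a placeholder for ingredient 2 (derived in the next section), not a real API:

```python
import torch

def influence_scores(grad_f, train_grads, inverse_damped_gnh_vp):
    """Sketch of I_f(z_m) = -grad_f^T (G + lambda*I)^{-1} grad_{z_m}.

    grad_f:       (P,)   gradient of the query property f wrt the parameters
    train_grads:  (B, P) per-example training-loss gradients (a batch of z_m's)
    inverse_damped_gnh_vp: placeholder callable v -> (G + lambda*I)^{-1} v
    """
    # (G + lambda*I)^{-1} is symmetric, so we can precondition the query
    # gradient once and reuse it against every training gradient.
    preconditioned_query = inverse_damped_gnh_vp(grad_f)   # (P,)
    return -train_grads @ preconditioned_query             # (B,) influence scores
```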
Which leaves the final key question: how do we get $(G+\lambda I)^{-1}v$?
Kronecker-Factored Approximate Curvature (KFAC)
Originally introduced in the 2015 paper Optimizing Neural Networks with Kronecker-factored Approximate Curvature by Martens and Grosse, KFAC is an approximation to the Fisher information matrix (FIM), $F=\mathbb{E}_{x\sim p(x),\,y\sim P(y|x;\theta)}\left[\nabla_\theta\log p(y|x;\theta)\,\nabla_\theta\log p(y|x;\theta)^\top\right]$, that can be inverted very efficiently. For many models where the loss is given by the negative log probability under a simple predictive distribution, the FIM is equal to the Gauss-Newton Hessian.
KFAC for MLP models involves the following:
Given a fully connected model with L layers, let’s assume each layer l’s output is:
$$a_l=\phi_l(W_l a_{l-1})$$

where $a_{l-1}\in\mathbb{R}^M$, $W_l\in\mathbb{R}^{P\times M}$, and $\phi_l$ is a nonlinear activation function.[2]

When we backpropagate $\log p(y|x;\theta)$ to get $\nabla_\theta\log p(y|x;\theta)$, we need to calculate the derivative of $\log p(y|x;\theta)$ with respect to intermediate stages of the computation at each layer. So as we go backward through the computational graph, once we reach the output of $W_l a_{l-1}$, we will have computed $\nabla_{W_l a_{l-1}}\log p(y|x;\theta)$.

By the chain rule, and using the fact that $\nabla_{W_l}(W_l a_{l-1})=a_{l-1}$:

$$\nabla_{W_l}\log p(y|x;\theta)=\nabla_{W_l a_{l-1}}\log p(y|x;\theta)\cdot a_{l-1}^T$$
This means we can decompose the gradient[3] of the log-likelihood loss on some data point (x,y) with respect to the weight matrix W into the intermediate gradients of the loss with respect to the output of applying the weight matrix and the activations prior to that layer.
Working with gradients of weight matrices is inconvenient though, as we end up with 3D tensors for the Jacobian. We can instead consider θl, the unrolled weight matrix for layer l.
Then, defining $\mathcal{D}v=\nabla_v\log p(y|x;\theta)$, $s_l=W_l a_{l-1}$, and $\otimes$ as the Kronecker product:

$$\mathcal{D}\theta_l=a_{l-1}\otimes\mathcal{D}s_l$$
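This factorization is easy to sanity-check numerically for a single linear layer (a small illustrative check, not from the paper's code):

```python
import torch

torch.manual_seed(0)
a = torch.randn(4)                          # a_{l-1}: layer input
W = torch.randn(3, 4, requires_grad=True)   # W_l
s = W @ a                                   # s_l = W_l a_{l-1}
log_probs = torch.log_softmax(s, dim=0)     # pretend s_l are the final logits
out = log_probs[1]                          # stand-in for log p(y | x; theta)

# Gradients wrt the pre-nonlinearity output s_l and the weight matrix W_l.
Ds, grad_W = torch.autograd.grad(out, (s, W))

# grad_W equals the outer product Ds_l a_{l-1}^T, i.e. vec(grad_W) = a (x) Ds.
assert torch.allclose(grad_W, torch.outer(Ds, a))
```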
So far, so exact… But now, time for approximations. KFAC makes things simpler by assuming:
Gradients $\mathcal{D}\theta_l$ are uncorrelated between different layers
Activations $a_{l-1}$ are independent of the pre-nonlinearity gradients $\mathcal{D}s_l$
This allows us to write down a simple block-diagonal approximation for $G$:

$$G_l=\mathbb{E}[\mathcal{D}\theta_l\mathcal{D}\theta_l^T]=\mathbb{E}[a_{l-1}a_{l-1}^T\otimes\mathcal{D}s_l\mathcal{D}s_l^T]\approx\mathbb{E}[a_{l-1}a_{l-1}^T]\otimes\mathbb{E}[\mathcal{D}s_l\mathcal{D}s_l^T]=A_{l-1}\otimes S_l$$

where $A_{l-1}$ and $S_l$ are uncentered covariance matrices of the layer's input activations and pre-nonlinearity gradients, respectively.
This structure enables us to efficiently get the inverse (approximate) Gauss-Newton Hessian vector product:
Let $V_l$ denote the entries of $v$ for layer $l$, reshaped to match $W_l$, and let $v_l=\mathrm{vec}(V_l)$.

Using various Kronecker product identities, we can compute the inverse (approximate) Gauss-Newton Hessian vector product as:

$$\hat{G}_l^{-1}v_l=\mathrm{vec}\left(S_l^{-1}V_l A_{l-1}^{-1}\right)$$
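A sketch of this per-layer inverse vector product in PyTorch (variable names are mine): two small solves replace the inversion of a $PM\times PM$ matrix.

```python
import torch

def kfac_inverse_vp(V, A, S):
    """Sketch of G_l^{-1} v_l = vec(S_l^{-1} V_l A_{l-1}^{-1}).

    V: (P, M) vector v_l reshaped to match the weight matrix W_l
    A: (M, M) uncentered covariance of layer inputs,      E[a a^T]
    S: (P, P) uncentered covariance of output gradients,  E[Ds Ds^T]
    """
    left = torch.linalg.solve(S, V)          # S^{-1} V           -> (P, M)
    return torch.linalg.solve(A, left.T).T   # (S^{-1} V) A^{-1}, using A = A^T
```

In practice one would add damping to the factors (or work in the eigenbasis, as described next) before solving.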
Eigenvalue correction
We made an approximation earlier when we went from $\mathbb{E}[a_{l-1}a_{l-1}^T\otimes\mathcal{D}s_l\mathcal{D}s_l^T]$ to $\mathbb{E}[a_{l-1}a_{l-1}^T]\otimes\mathbb{E}[\mathcal{D}s_l\mathcal{D}s_l^T]$.
Using the eigendecompositions of $A$ and $S$:

$$A=Q_A\Lambda_A Q_A^T$$

$$S=Q_S\Lambda_S Q_S^T$$

we can write a more accurate expression for $G$:

$$G\approx(Q_A\otimes Q_S)\,\Lambda\,(Q_A\otimes Q_S)^T$$
where the diagonal matrix Λ is defined as:
$$\Lambda_{ii}=\mathbb{E}\left[\left((Q_A\otimes Q_S)^T\mathcal{D}\theta\right)_i^2\right]=\mathbb{E}\left[\mathrm{vec}\left(Q_S^T\,\overline{\mathcal{D}\theta}\,Q_A\right)_i^2\right]$$

(where $\overline{\mathcal{D}\theta}$ denotes $\mathcal{D}\theta$ reshaped to the shape of $W_l$)
which “captures the variances of the pseudo-gradient projected onto each eigenvector of the K-FAC approximation”.
We can get the damped inverse Gauss-Newton Hessian vector product approximation by adding $\lambda$ to the eigenvalues, obtaining:

$$(G+\lambda I)^{-1}v\approx(Q_A\otimes Q_S)(\Lambda+\lambda I)^{-1}(Q_A\otimes Q_S)^T v$$

$$=\mathrm{vec}\left(Q_S\left[\left(Q_S^T V Q_A\right)\oslash\mathrm{unvec}\!\left(\mathrm{diag}(\Lambda+\lambda I)\right)\right]Q_A^T\right)$$

where $\oslash$ denotes elementwise division and $\mathrm{unvec}(\cdot)$ reshapes the damped eigenvalues to the shape of $V$.
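Under the matrix-shaped convention used above, the damped, eigenvalue-corrected inverse vector product reduces to two basis changes and an elementwise division. A hedged sketch (names are placeholders, not the paper's code):

```python
import torch

def ekfac_damped_inverse_vp(V, Q_A, Q_S, Lam, damp):
    """Sketch of (G_l + lambda*I)^{-1} v_l with eigenvalue correction.

    V:    (P, M) vector v_l reshaped to match the weight matrix W_l
    Q_A:  (M, M) eigenvectors of A_{l-1}
    Q_S:  (P, P) eigenvectors of S_l
    Lam:  (P, M) fitted eigenbasis second moments (Lambda), reshaped to match V
    damp: scalar damping lambda
    """
    W = Q_S.T @ V @ Q_A        # project into the Kronecker eigenbasis
    W = W / (Lam + damp)       # divide elementwise by the damped eigenvalues
    return Q_S @ W @ Q_A.T     # project back to the original basis
```

With `Lam` set to the outer product of the eigenvalues of $S$ and $A$ and `damp` set to zero, this reduces to the uncorrected `kfac_inverse_vp` sketch above.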
Influence functions for autoregressive models
A few details change when we want to calculate $\mathcal{I}_{f,\theta^s}(z_m)=-\nabla_\theta f(\theta^s)^T(G+\lambda I)^{-1}\nabla_\theta L(z_m,\theta^s)$ for a Transformer language model trained with an autoregressive loss function.

In this case, the property of interest $f$ (the thing we are calculating the influence on) is the log-likelihood of a particular token string completion $z_c$, given a token string prompt $z_p$[4]:

$$\log p(z_c\mid z_p;\theta)$$
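As a concrete sketch, for a Hugging-Face-style causal LM this query can be computed by concatenating prompt and completion and summing the log-probabilities of the completion tokens only (the model interface and names here are assumptions, not the paper's code):

```python
import torch
import torch.nn.functional as F

def query_log_likelihood(model, prompt_ids, completion_ids):
    """log p(z_c | z_p; theta) for a causal LM, summed over completion tokens.

    prompt_ids, completion_ids: 1-D LongTensors of token ids (placeholders).
    Assumes model(input_ids) returns an object with a .logits field of shape
    (batch, seq_len, vocab), as in Hugging Face transformers.
    """
    input_ids = torch.cat([prompt_ids, completion_ids]).unsqueeze(0)
    logits = model(input_ids).logits
    log_probs = F.log_softmax(logits[0, :-1], dim=-1)      # predicts tokens 1..T-1
    targets = input_ids[0, 1:]
    token_ll = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
    # Only positions that predict completion tokens contribute to log p(z_c | z_p).
    return token_ll[len(prompt_ids) - 1:].sum()

# Ingredient 1 is then the gradient of this query wrt the parameters, e.g.:
# f = query_log_likelihood(model, prompt_ids, completion_ids)
# grad_f = torch.autograd.grad(f, [p for p in model.parameters() if p.requires_grad])
```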
The paper only computes influence with respect to a subset of the Transformer's parameters (the MLP layers), so the MLP approximation of $G$ derived above applies almost exactly.
However, the parameter gradients are now summed over token indices:

$$\mathcal{D}\theta_l=\sum_{t=1}^{T}\mathcal{D}\theta_{l,t}=\sum_{t=1}^{T}a_{l-1,t}\otimes\mathcal{D}s_{l,t}$$

Each diagonal block of $G$ is still given by $\mathbb{E}[\mathcal{D}\theta_l\mathcal{D}\theta_l^T]$; however, this second moment is now affected by inter-token correlations, so it can no longer be approximated as accurately by $\mathbb{E}[a_{l-1}a_{l-1}^T]\otimes\mathbb{E}[\mathcal{D}s_l\mathcal{D}s_l^T]$ as before.
The paper presents the following middle-ground between efficiency and accuracy:
We first fit the covariance factors A and S as if the tokens were fully independent, and compute their respective eigendecompositions. Then, when fitting the diagonal matrix Λ, we use the exact pseudo-gradients Dθl which are summed over tokens. This way, at least the estimated diagonal entries of the moments in the Kronecker eigenbasis are unbiased.
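In code, that amounts to something like the following (an illustrative sketch for a single MLP layer; the tensor layout and names are mine):

```python
import torch

def fit_ekfac_factors_autoregressive(acts, grads):
    """Fit A and S per token, but fit Lambda from per-sequence summed pseudo-gradients.

    acts:  (B, T, M) per-token layer inputs a_{l-1, t}
    grads: (B, T, P) per-token pre-nonlinearity pseudo-gradients Ds_{l, t}
    """
    B, T, _ = acts.shape
    # A and S: treat every token as an independent sample.
    A = torch.einsum('btm,btn->mn', acts, acts) / (B * T)
    S = torch.einsum('btp,btq->pq', grads, grads) / (B * T)
    _, Q_A = torch.linalg.eigh(A)
    _, Q_S = torch.linalg.eigh(S)
    # Lambda: use the exact per-sequence pseudo-gradients, summed over tokens.
    Dtheta = torch.einsum('btp,btm->bpm', grads, acts)      # sum_t Ds_t a_t^T, per sequence
    Lam = ((Q_S.T @ Dtheta @ Q_A) ** 2).mean(dim=0)         # (P, M)
    return A, S, Q_A, Q_S, Lam
```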
Implementing in PyTorch
As described above, the key ingredients for If are:
1. The gradient of the property $f$ of interest with respect to the parameters $\theta$, evaluated at $\theta^s$
2. The inverse damped Gauss-Newton Hessian, which we can calculate from the expectations of the following quantities:
   a. $a_{l-1}$ - the MLP layer inputs
   b. $\mathcal{D}s_l$ - the MLP pre-nonlinearity gradients (gradients of the loss wrt the output of the linear transformation $W_l a_{l-1}$)
3. The gradient of the loss $L$ on the training data points (which we want to calculate the influence of) with respect to the parameters $\theta$, evaluated at $\theta^s$
We can get 1) and 3) by simply fetching parameter.grad[5] after performing a backward pass of the loss on some input, target pair.
We can get 2a) using a forward hook that saves the input to a layer during the forward pass. We can get 2b) using a backward hook on the linear layer that saves the gradient wrt the linear layer’s output.
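A minimal sketch of this hook plumbing for every nn.Linear in a model (names and structure are mine, not the paper's or the repo's exact code):

```python
import torch.nn as nn

captured_inputs, captured_output_grads = {}, {}   # keyed by module name

def make_forward_hook(name):
    def hook(module, inputs, output):
        captured_inputs[name] = inputs[0].detach()             # a_{l-1}
    return hook

def make_backward_hook(name):
    def hook(module, grad_input, grad_output):
        captured_output_grads[name] = grad_output[0].detach()  # Ds_l
    return hook

def register_kfac_hooks(model):
    handles = []
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            handles.append(module.register_forward_hook(make_forward_hook(name)))
            handles.append(module.register_full_backward_hook(make_backward_hook(name)))
    return handles   # keep these so the hooks can be .remove()'d later
```

The uncentered covariances $A_{l-1}=\mathbb{E}[a a^T]$ and $S_l=\mathbb{E}[\mathcal{D}s\,\mathcal{D}s^T]$ can then be accumulated over batches from these captured tensors.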
You can find my implementation attempt on GitHub here[6]. It includes code applying influence-function analysis to a vanilla MLP trained on MNIST and a 2-layer Transformer trained on a basic next-character-prediction task.
Results of small experiment on MNIST
I trained an MLP on MNIST (with flattened images) and then used the influence function approximation code to extract influential training examples for particular predicted test set labels.
I found that influential training digits were usually sloppier and less clear than the average MNIST digit, and shared some resemblance with the query image. Not all of the top influential images shared the same label as the query. For efficiency, I only searched a subset of the training corpus.
Here are some examples, filtered to cases where the influence was non-negligible (some queries returned ~0 for all sampled training data points). The first image on the left is the query, followed by the most influential training images given by the approximation:
[1] The warm-start problem referenced by Bae et al. refers to the fact that, for a not strictly convex objective, the influence of a training example in the neighborhood of a minimum $\theta^*$ may be different from the influence at a different initialization point.

[2] The paper uses homogeneous vector notation to account for biases / affine transformations: you can assume there is a 1 appended to the activations $a$ and a bias vector appended to $W$ to cover this case.

[3] The paper refers to these as "pseudo-gradients" since they are sampled from the final output distribution and are distinct from gradients during training.

[4] The $z_p$, $z_c$ pair is referred to as the "query" in the paper, as we are "querying" which training examples were most influential for the model producing $z_c$ given $z_p$.

[5] Specifically, concatenate a linear layer's .weight and .bias grads.

[6] If you look through the code and find any bugs (quite possible) or performance improvements (definitely findable; e.g. more batching and splitting of GPU ops, which is a work in progress), I'd be super happy to merge PRs and/or hear from you! I hope to gradually improve this codebase and run larger experiments.