Recipe: Hessian eigenvector computation for PyTorch models
The idea/description of this method is fully taken from John Wentworth’s Applied Linear Algebra lecture series, specifically Lecture 2.
Training deep neural networks involves navigating high-dimensional loss landscapes. Understanding the curvature of these landscapes via the Hessian of the loss function can provide insight into the optimization dynamics. However, computing the full Hessian is prohibitively expensive for all but the smallest models: for a network with n parameters, the Hessian has n² entries. In this post, I describe a method (from John Wentworth’s lecture series) for efficiently computing the top eigenvalues and eigenvectors of the loss Hessian using PyTorch’s autograd and SciPy’s sparse linear algebra utilities.
Hessian-vector product
The core idea hinges on the Hessian-vector product (HVP). Given a vector v, the HVP is H⋅v, where H is the Hessian of the loss with respect to the model parameters. This product can be computed efficiently using automatic differentiation, without ever forming the full Hessian. The process can be outlined as follows (a minimal code sketch follows the list):
Compute the gradient of the loss with respect to model parameters: g=∇L
Compute the dot product of g and v: c=g⋅v
Compute the gradient of c with respect to the model parameters, which gives the HVP
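As a concrete illustration, here is a minimal sketch of these three steps on a toy model; the model, data, and vector below are illustrative placeholders rather than part of the method itself.
import torch as t
from torch.autograd import grad

# Toy setup (illustrative): a tiny linear model and some random data
model = t.nn.Linear(4, 1)
x, y = t.randn(8, 4), t.randn(8, 1)
loss = t.nn.functional.mse_loss(model(x), y)

params = list(model.parameters())
v = t.randn(sum(p.numel() for p in params))  # the vector to multiply by

# Step 1: g = ∇L, keeping the graph so that g can itself be differentiated
g = grad(loss, params, create_graph=True)
flat_g = t.cat([gi.reshape(-1) for gi in g])

# Step 2: c = g ⋅ v (a scalar)
c = t.dot(flat_g, v)

# Step 3: ∇c = H ⋅ v, the Hessian-vector product
hvp = grad(c, params)
flat_hvp = t.cat([h.reshape(-1) for h in hvp])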
Lanczos Iteration and eigsh
eigsh from scipy.sparse.linalg wraps ARPACK’s implicitly restarted Lanczos iteration, which finds a few extremal eigenvalues and the corresponding eigenvectors of a symmetric matrix. The only access to the matrix it requires is matrix-vector multiplication, making it ideal for large matrices where full matrix factorizations are infeasible.
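As a toy illustration (the random symmetric matrix here is just a stand-in, not a Hessian), eigsh can pull out a handful of largest-magnitude eigenpairs directly:
import numpy as np
from scipy.sparse.linalg import eigsh

rng = np.random.default_rng(0)
A = rng.standard_normal((500, 500))
A = (A + A.T) / 2                       # symmetrize so the Lanczos method applies
vals, vecs = eigsh(A, k=5, which='LM')  # 5 eigenvalues of largest magnitude
print(vals.shape, vecs.shape)           # (5,) and (500, 5); eigenvectors are the columns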
Using LinearOperator
To interface with eigsh, we need a mechanism to represent our Hessian as a linear operator that supports matrix-vector multiplication. SciPy’s LinearOperator serves this purpose, allowing us to define a matrix implicitly by its action on vectors without forming the matrix explicitly.
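Below is a minimal sketch of this pattern. The implicit matrix here is just a large diagonal matrix defined only by its action on a vector; in our setting, matvec will instead be the Hessian-vector product.
import numpy as np
from scipy.sparse.linalg import LinearOperator, eigsh

n = 10_000
d = np.linspace(0.1, 5.0, n)   # diagonal of an implicit n x n matrix

def matvec(v):
    return d * v               # the matrix itself is never materialized

op = LinearOperator((n, n), matvec=matvec, dtype=np.float64)
vals, vecs = eigsh(op, k=3, which='LM')  # three largest-magnitude eigenvalues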
Implementation
Given a PyTorch model, loss function, and training data, the approach is to:
Accumulate a subset of the training data (as many batches as specified)
Define the HVP using PyTorch’s autograd
Construct a LinearOperator using the HVP
Call eigsh with this linear operator to compute the top eigenvalues and eigenvectors
Appendix: Python code
You can find this code as a GitHub gist here also.
import torch as t
from torch.autograd import grad
from scipy.sparse.linalg import LinearOperator, eigsh
import numpy as np

def get_hessian_eigenvectors(model, loss_fn, train_data_loader, num_batches, device, n_top_vectors, param_extract_fn):
    """
    model: a pytorch model
    loss_fn: a pytorch loss function
    train_data_loader: a pytorch data loader
    num_batches: number of batches to use for the hessian calculation
    device: the device to use for the hessian calculation
    n_top_vectors: number of top eigenvalues / eigenvectors to return
    param_extract_fn: a function that takes a model and returns a list of parameters
        to compute the hessian with respect to (pass None to use all parameters)
    returns: a tuple of (eigenvalues, eigenvectors)
    eigenvalues: a numpy array of the top eigenvalues, arranged in increasing order
    eigenvectors: a numpy array of the top eigenvectors, arranged in increasing order, shape (n_top_vectors, num_params)
    """
    param_extract_fn = param_extract_fn or (lambda x: x.parameters())
    num_params = sum(p.numel() for p in param_extract_fn(model))
    # Accumulate the first num_batches batches of training data on the target device
    subset_images, subset_labels = [], []
    for batch_idx, (images, labels) in enumerate(train_data_loader):
        if batch_idx >= num_batches:
            break
        subset_images.append(images.to(device))
        subset_labels.append(labels.to(device))
    subset_images = t.cat(subset_images)
    subset_labels = t.cat(subset_labels)

    def compute_loss():
        output = model(subset_images)
        return loss_fn(output, subset_labels)

    # Hessian-vector product via double backward: differentiate (gradient ⋅ vector)
    def hessian_vector_product(vector):
        model.zero_grad()
        grad_params = grad(compute_loss(), param_extract_fn(model), create_graph=True)
        flat_grad = t.cat([g.view(-1) for g in grad_params])
        grad_vector_product = t.sum(flat_grad * vector)
        hvp = grad(grad_vector_product, param_extract_fn(model), retain_graph=True)
        return t.cat([g.contiguous().view(-1) for g in hvp])

    # Bridge between SciPy (numpy arrays) and PyTorch (tensors on the target device)
    def matvec(v):
        v_tensor = t.tensor(v, dtype=t.float32, device=device)
        return hessian_vector_product(v_tensor).cpu().detach().numpy()

    linear_operator = LinearOperator((num_params, num_params), matvec=matvec)
    eigenvalues, eigenvectors = eigsh(linear_operator, k=n_top_vectors, tol=0.001, which='LM', return_eigenvectors=True)
    eigenvectors = np.transpose(eigenvectors)
    return eigenvalues, eigenvectors
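A hypothetical usage, on a toy classifier with synthetic data (the architecture, batch size, and hyperparameters below are placeholders, not from the original recipe):
from torch.utils.data import DataLoader, TensorDataset

device = 'cpu'
model = t.nn.Sequential(t.nn.Flatten(), t.nn.Linear(28 * 28, 10)).to(device)
data = TensorDataset(t.randn(256, 1, 28, 28), t.randint(0, 10, (256,)))
loader = DataLoader(data, batch_size=64)

eigenvalues, eigenvectors = get_hessian_eigenvectors(
    model, t.nn.CrossEntropyLoss(), loader,
    num_batches=4, device=device, n_top_vectors=5, param_extract_fn=None,
)
print(eigenvalues)          # five largest-magnitude eigenvalues, in increasing order
print(eigenvectors.shape)   # (5, num_params)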