If you want to see exactly what a neural network is doing, one approach is to visualize the entire computation graph of multiplies, additions, and nonlinearities. While this graph is extremely complex even for an MNIST classifier, we can make it much simpler by making the networks _sparse_, since any zero weights can be removed from the computation graph. Previous work has shown that we can remove well over 95% of weights from a model without degrading accuracy too much, so the authors do this to make the computation graph easier to understand.
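To give a feel for how much sparsity shrinks the graph, here is a minimal sketch that prunes a weight matrix and counts the remaining edges. It assumes simple magnitude pruning, which may not be the sparsification method the authors actually use; the function name `magnitude_prune`, the layer shape, and the random weights are all hypothetical.

```python
import numpy as np

def magnitude_prune(weights, sparsity):
    """Return a copy of `weights` with the smallest-magnitude entries set to zero.

    Illustrative only: magnitude pruning is an assumed stand-in for
    whatever sparsification procedure is actually used.
    """
    flat = np.abs(weights).ravel()
    k = int(round(sparsity * flat.size))  # number of weights to remove
    if k == 0:
        return weights.copy()
    # k-th smallest magnitude; entries at or below it are removed.
    threshold = np.partition(flat, k - 1)[k - 1]
    mask = np.abs(weights) > threshold
    return weights * mask

rng = np.random.default_rng(0)
W = rng.normal(size=(784, 300))  # hypothetical dense first-layer weights
W_sparse = magnitude_prune(W, sparsity=0.95)

# Every nonzero weight is one multiply edge in the computation graph.
edges_before = W.size
edges_after = int(np.count_nonzero(W_sparse))
print(f"edges: {edges_before} -> {edges_after}")
```

In a real setting the pruned network would also need to be retrained or fine-tuned to recover accuracy; this sketch only shows the effect on the number of edges.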
They apply this to an MLP trained to classify MNIST digits and to a DQN agent trained to play CartPole. In the MNIST case, the computation graph can be drastically simplified by visualizing the first layer of the net as a list of 2D images, where the kth activation is given by the dot product of the kth 2D image with the input image. This accounts for the vast majority of the weights in the neural net.
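The "first layer as a list of 2D images" view can be sketched as follows. The layer sizes, the random sparse weights, and the random input are hypothetical stand-ins, and any bias or nonlinearity applied after the dot product is left out.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sparse first layer: 784 input pixels -> 4 hidden units.
W1 = rng.normal(size=(4, 784))
W1[rng.random(W1.shape) < 0.95] = 0.0  # stand-in for a pruned layer

x = rng.random((28, 28))  # stand-in for a 28x28 MNIST input image

# Each row of W1 is reshaped into a 28x28 "weight image".
weight_images = W1.reshape(4, 28, 28)

# The kth value is the dot product of the kth weight image with the input:
# multiply pixel by pixel, then sum.
from_images = np.array([(img * x).sum() for img in weight_images])

# This matches the usual matrix-vector form of the layer.
from_matmul = W1 @ x.ravel()
print(np.allclose(from_images, from_matmul))
```

Plotting each entry of `weight_images` (for example with a diverging colormap) then shows which pixels push each hidden unit up or down.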
Planned opinion:
This method has the nice property that it visualizes exactly what the neural net is doing: it isn’t “rationalizing” an explanation, or eliding potentially important details. It is possible to gain interesting insights about the model: for example, the logit for digit 2 is always −2.39, implying that everything else is computed relative to −2.39. Looking at the images for digit 7, it seems like the model strongly believes that sevens must have blank pixels in the top few rows, which I found a bit surprising. (I chose to look at the digit 7 somewhat arbitrarily.)
Of course, since the technique doesn’t throw away any information about the model, it becomes very complicated very quickly, and wouldn’t scale to larger models.
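One way a logit can be the same for every input, as with the digit 2 logit mentioned above, is if all of its incoming weights have been pruned, leaving only the bias. The sketch below illustrates that mechanism with hypothetical shapes and random values; whether this is what actually happens in the authors' model is an assumption.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical output layer: 16 hidden units -> 10 digit logits.
W_out = rng.normal(size=(10, 16))
b_out = rng.normal(size=10)

# Assumption for illustration: every weight into the "2" logit was pruned,
# so only its bias remains.
W_out[2, :] = 0.0
b_out[2] = -2.39

# Five arbitrary hidden activation vectors (post-ReLU).
hidden = np.maximum(rng.normal(size=(5, 16)), 0.0)
logits = hidden @ W_out.T + b_out

# The "2" logit is the bias for every input; the other logits vary.
print(logits[:, 2])
```

Since softmax only depends on differences between logits, a constant logit acts as a fixed reference point that the other nine logits are measured against.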
Planned summary for the Alignment Newsletter:
Planned opinion: