If you want, you can slightly refactor my proposal to include a reporter module that takes the primary model’s hidden representations as input and outputs more interpretable representations for the student models to use. That would leave the primary model’s training objective unchanged. However, I don’t think this is a good idea for much the same reason that training just the classification head of a pretrained language model isn’t a good idea.
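For concreteness, here is a minimal sketch of what that reporter-module refactoring could look like, assuming a PyTorch setup; the `PrimaryModel` and `Reporter` classes, the dimensions, and the optimizer settings are hypothetical stand-ins rather than part of the original proposal.

```python
import torch
import torch.nn as nn

# Hypothetical dimensions; a real primary model's hidden size would differ.
d_in, d_hidden, d_report = 512, 1024, 64

class PrimaryModel(nn.Module):
    """Stand-in for the pretrained primary model; its training objective is untouched."""
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(d_in, d_hidden)

    def forward(self, x):
        return torch.relu(self.encoder(x))  # hidden representations

class Reporter(nn.Module):
    """Reads the primary model's hidden states and re-expresses them in a
    (hopefully) more interpretable form for the student models to consume."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_hidden, 256), nn.ReLU(), nn.Linear(256, d_report)
        )

    def forward(self, hidden):
        return self.net(hidden)

primary, reporter = PrimaryModel(), Reporter()

# Freeze the primary model: only the reporter is trained, which is exactly
# the "train just the head" setup the paragraph above is skeptical of.
for p in primary.parameters():
    p.requires_grad_(False)
opt = torch.optim.Adam(reporter.parameters(), lr=1e-4)  # stepped by whatever student-side objective is used

x = torch.randn(32, d_in)        # a batch of inputs
with torch.no_grad():
    hidden = primary(x)          # primary model receives no gradient
report = reporter(hidden)        # what the student models would see
```

The worry above, restated in these terms: because the primary model is frozen, it is never pressured to make `hidden` easy for the reporter to decode, just as a frozen language model is never pressured to make its features easy for a new classification head to use.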
That said, I think training the primary model to be interpretable to other systems may actually improve economic competitiveness. The worth of a given approach depends on the ratio of capabilities to the compute required. If you have a primary model whose capabilities are more easily distilled into smaller models, that’s an advantage from a competitiveness standpoint: you can get better performance out of cheaper models than your competitors can.
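As a concrete reference point for that distillation claim, here is the standard knowledge-distillation loss (soft teacher targets plus hard labels); the temperature and weighting values are illustrative and nothing here is specific to my proposal.

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.5):
    """Standard knowledge distillation: match the temperature-softened teacher
    distribution while still fitting the hard labels. T and alpha are illustrative."""
    soft_targets = F.softmax(teacher_logits / T, dim=-1)
    soft_preds = F.log_softmax(student_logits / T, dim=-1)
    # The KL term is scaled by T^2 so its gradients stay comparable across temperatures.
    kd = F.kl_div(soft_preds, soft_targets, reduction="batchmean") * (T * T)
    ce = F.cross_entropy(student_logits, labels)
    return alpha * kd + (1 - alpha) * ce

# Toy usage: a cheaper student fit to a larger teacher's outputs.
teacher_logits = torch.randn(32, 10)
student_logits = torch.randn(32, 10, requires_grad=True)
labels = torch.randint(0, 10, (32,))
loss = distillation_loss(student_logits, teacher_logits, labels)
loss.backward()
```

The cheaper the student that can fit the teacher’s behavior this way, the stronger the competitiveness advantage described above.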
I think people are FAR too eager to assume a significant capabilities/interpretability tradeoff. In a previous post, I used analogies to the brain to argue that there’s enormous room to improve the interpretability of existing ML systems with little capabilities penalty.
To go even further, more interpretable internal representations may actually improve learning. ML systems face their own internal interpretability problems. To optimize a system, gradient descent needs to be able to disentangle which changes will benefit vs harm the system’s performance. This is a form of interpretability, though not one we often consider.
Being “interpretable to gradient descent” is very different from being “interpretable to humans”. However, most of my proposal focuses on making the primary model generally interpretable to many different systems, with humans as a special case. I think being more interpretable may directly lead to being easier to optimize. Intuitively, it seems easier to improve a system with simple, self-contained, modular internal components that interact with each other in a straightforward, consistent manner.
Imagine the primary model as a collection of interacting circuits. If those circuits have many complex, nonlinear interdependencies, then gradient descent will have trouble modifying them quickly: significant changes to one circuit will break the downstream circuits that depend on its output. If, instead, the circuits are more modular, with robust, consistent interfaces, then it’s easier for gradient descent to make significant changes quickly without messing up the rest of the network. This is closely related to the fact that documentation and good APIs become vastly more important as software projects increase in size.
(Note that you can view vanishing/exploding gradients as an interpretability issue preventing gradient descent from figuring out how it should change certain weights, and vanishing/exploding gradients certainly hurt performance.)
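As a toy illustration of that parenthetical, here is a deep sigmoid network in which the gradient norms of the early layers collapse by many orders of magnitude; the depth, width, and activation choice are arbitrary and only meant to make the effect visible.

```python
import torch
import torch.nn as nn

# A deep sigmoid MLP: gradient norms shrink layer by layer, so gradient descent
# gets almost no signal about how it should change the early weights.
depth, width = 30, 64
layers = [nn.Linear(width, width) for _ in range(depth)]
model = nn.Sequential(*[m for layer in layers for m in (layer, nn.Sigmoid())])

x = torch.randn(16, width)
loss = model(x).pow(2).mean()
loss.backward()

for i in (0, depth // 2, depth - 1):
    g = layers[i].weight.grad.norm().item()
    print(f"layer {i:2d} grad norm: {g:.2e}")  # early layers: orders of magnitude smaller
```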