I like this visualization tool. There are some very interesting things going on when you look into the details of the network in the second-to-last MNIST figure. One is that it seems to identify each digit mostly by ruling out the others. For instance, the first two dot-product boxes (on the lower left) could be described as “not-a-0” detectors: they give a positive result if they detect pixels in the center, near the corners, or at the extreme edges. The next two boxes could loosely be called “not-a-9” detectors (though they also contribute to the 0 and 4 classifications), and the three after that are “not-a-4” detectors. (The use of ReLU makes this functionally different from having a “4” detector with a negative weight.)
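To make the ReLU remark concrete, here is a minimal sketch of one way to read it. The pre-activation `z` and the output weight `a` are made-up values, not numbers taken from the network in the figure. Without the nonlinearity, negating a unit's incoming and outgoing weights changes nothing; with ReLU, the two units are active on opposite halves of the input range.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

# Hypothetical pre-activation: positive means "looks like a 4",
# negative means "has pixels a 4 would not have".
z = np.array([-2.0, -0.5, 0.0, 0.5, 2.0])
a = 1.5  # hypothetical magnitude of the weight into output 4

# A "4" detector: fires on evidence for a 4 and pushes output 4 up.
four_detector = a * relu(z)

# A "not-a-4" detector: fires on evidence against a 4 and pushes output 4 down.
not_four_detector = -a * relu(-z)

# Without ReLU both units would contribute a * z. With ReLU each one
# covers only half of the range, so they are not interchangeable.
print(four_detector)
print(not_four_detector)
```

Only the sum of the two contributions recovers the linear term `a * z`, so a network that keeps just the “not-a-4” half has committed to penalizing evidence against a 4 while staying silent about evidence for it.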
Now, take a look at the two boxes that contribute to output 1 (I would call these “not-a-1” detectors). If you select the “7” input and see how those two boxes respond to it, they both react pretty strongly to the very bottom of the 7 (the fact that it’s near the edge), and that’s what allows the network to distinguish a 7 from a 1. Intuitively, this seems like a shared feature, so why is the network so sure that anything near the bottom cannot be part of a 1?
It looks to me like it’s taking advantage of the way the MNIST images are preprocessed, with each digit’s center-of-mass translated to the center of the image. Because the 7 has more of its mass near the top, its lower extremity can reach farther from the center. The 1, on the other hand, is not top-heavy or bottom-heavy, so it won’t be translated by any significant amount in preprocessing and its extremities can’t get near the edges.
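The effect can be sketched with two toy shapes. The block below is an illustration under stated assumptions, not the actual MNIST pipeline: the shapes are hand-drawn rectangles on a 28x28 grid, and `shift_image`, `center_by_mass`, and `bottom_row` are hypothetical helpers written for this example. Both shapes start with the same height and the same bottom row, and only the top-heavy one gets pushed toward the bottom edge.

```python
import numpy as np

def shift_image(img, dy, dx):
    """Translate img by (dy, dx) pixels, filling vacated pixels with zeros."""
    h, w = img.shape
    out = np.zeros_like(img)
    src = img[max(0, -dy):min(h, h - dy), max(0, -dx):min(w, w - dx)]
    y0, x0 = max(0, dy), max(0, dx)
    out[y0:y0 + src.shape[0], x0:x0 + src.shape[1]] = src
    return out

def center_by_mass(img):
    """Translate img so its center of mass lands on the image center."""
    h, w = img.shape
    rows, cols = np.indices(img.shape)
    total = img.sum()
    cy = (rows * img).sum() / total
    cx = (cols * img).sum() / total
    dy = int(round((h - 1) / 2 - cy))
    dx = int(round((w - 1) / 2 - cx))
    return shift_image(img, dy, dx)

def bottom_row(img):
    """Index of the lowest row that contains any ink."""
    return int(np.nonzero(img.any(axis=1))[0].max())

# Toy "1": a vertical bar, neither top-heavy nor bottom-heavy.
one = np.zeros((28, 28))
one[4:24, 13:15] = 1.0

# Toy "7": a wide bar on top plus a thin stem, so it is top-heavy.
seven = np.zeros((28, 28))
seven[4:7, 8:20] = 1.0
seven[7:24, 17:19] = 1.0

print("bottom row of 1 before/after:", bottom_row(one), bottom_row(center_by_mass(one)))
print("bottom row of 7 before/after:", bottom_row(seven), bottom_row(center_by_mass(seven)))
```

Since the toy 1 is symmetric about its own middle, centering leaves it where it is, while the toy 7 moves down and its stem ends up closer to the bottom edge than the 1 can reach.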
The same thing happens with the “not-a-3” detector box when you select input 2. The “not-a-3” detector triggers quite strongly because of the tail that stretches out to the right. That area could never be occupied by a 3, because the 3 has most of its pixel mass near its right edge and will be translated left to get its center of mass centered in the image. The “7” detector (an exception to the pattern of ruling digits out) mostly identifies a 7 by the fact that it does not have any pixels near the top of the image (and to a lesser extent, does not have pixels in the lower-right corner).
What does this pattern tell us? First, that a different preprocessing technique (centering a bounding box in the image, for instance, instead of the digit’s center of mass) would require a very different strategy. I don’t know offhand what it would look like; maybe there’s some trick for making the problems equivalent, maybe not. Second, that the network can succeed without noticing most of what humans would consider the key features of these digits. For the most part it doesn’t need to know the difference between straight lines and curved lines, or whether the parts connect the way they’re supposed to, or even whether lines are horizontal or vertical. It can use simple cues like how far each digit extends in different directions from its center of mass. Maybe with so few layers (and no convolution) it has to use those simple cues.
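As a rough illustration of why the preprocessing choice matters, the sketch below computes the translation each scheme would apply to the same top-heavy toy shape. The shape and the helper names `mass_shift` and `bbox_shift` are assumptions made up for this example, not part of any MNIST tooling.

```python
import numpy as np

def mass_shift(img):
    """(dy, dx) that moves the center of mass to the image center."""
    h, w = img.shape
    rows, cols = np.indices(img.shape)
    total = img.sum()
    cy = (rows * img).sum() / total
    cx = (cols * img).sum() / total
    return int(round((h - 1) / 2 - cy)), int(round((w - 1) / 2 - cx))

def bbox_shift(img):
    """(dy, dx) that moves the bounding-box center to the image center."""
    h, w = img.shape
    ys = np.nonzero(img.any(axis=1))[0]
    xs = np.nonzero(img.any(axis=0))[0]
    cy = (ys.min() + ys.max()) / 2
    cx = (xs.min() + xs.max()) / 2
    return int(round((h - 1) / 2 - cy)), int(round((w - 1) / 2 - cx))

# Toy top-heavy "7": a wide bar on top plus a thin stem on the right.
seven = np.zeros((28, 28))
seven[4:7, 8:20] = 1.0
seven[7:24, 17:19] = 1.0

print("center-of-mass shift:", mass_shift(seven))
print("bounding-box shift:  ", bbox_shift(seven))
```

Under bounding-box centering this shape stays put, so "how far the digit reaches from the center" would no longer say anything about where its mass sits, and the edge-based cues described above would presumably stop working.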
As far as interpretability goes, this seems difficult to generalize to non-visual data, since humans won’t intuitively grasp what’s going on as easily. But it certainly seems worthwhile to explore ideas for how it could work.