From our implementation in the notebook, we can train a 3-layer ReLU network on 5 datapoints, and it tends to land on a function that looks something like this:
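For concreteness, here is a minimal stand-in for that setup: a 3-layer ReLU MLP trained with full-batch gradient descent on five made-up 1-d points. The data, width, learning rate, and step count are all hypothetical; the notebook's actual values may differ.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in for the notebook's 5 training points.
X = np.array([[-2.0], [-1.0], [0.0], [1.0], [2.0]])
y = np.array([[0.5], [-0.5], [0.0], [0.5], [-0.5]])

# 3-layer ReLU MLP: 1 -> H -> H -> 1, trained with full-batch gradient descent.
H, lr, steps = 32, 1e-2, 5000
W1 = rng.normal(0.0, 1.0, (1, H)); b1 = np.zeros(H)
W2 = rng.normal(0.0, 1.0 / np.sqrt(H), (H, H)); b2 = np.zeros(H)
W3 = rng.normal(0.0, 1.0 / np.sqrt(H), (H, 1)); b3 = np.zeros(1)

def forward(x):
    h1 = np.maximum(x @ W1 + b1, 0.0)
    h2 = np.maximum(h1 @ W2 + b2, 0.0)
    return h1, h2, h2 @ W3 + b3

for _ in range(steps):
    h1, h2, out = forward(X)
    g = 2.0 * (out - y) / len(X)          # d(mse)/d(out)
    gW3 = h2.T @ g; gb3 = g.sum(0)
    g2 = (g @ W3.T) * (h2 > 0)            # backprop through second ReLU
    gW2 = h1.T @ g2; gb2 = g2.sum(0)
    g1 = (g2 @ W2.T) * (h1 > 0)           # backprop through first ReLU
    gW1 = X.T @ g1; gb1 = g1.sum(0)
    W3 -= lr * gW3; b3 -= lr * gb3
    W2 -= lr * gW2; b2 -= lr * gb2
    W1 -= lr * gW1; b1 -= lr * gb1

# Training MSE after fitting; plot forward(grid)[2] to see the learned function.
mse = float(np.mean((forward(X)[2] - y) ** 2))
```

Plotting the trained function on a dense grid is what produces the piecewise-linear shape with the bumps discussed below.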
I was curious whether the NTK prior would predict the two slightly odd bumps on either side of 0 on the x-axis. I usually think of neural networks as doing linear interpolation when trained on tiny 1-d datasets like this, and this shape doesn’t really fit that story. I tested this by adding three possible test datapoints and computing the log-prob of each under the NTK prior. As you can see, the blue one has the lowest negative log-prob (i.e. the highest probability), so the NTK does predict that a higher data point is more likely at that location:
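To make the log-prob computation concrete, here is a from-scratch sketch: the analytic infinite-width NTK for a fully-connected ReLU net (the standard layerwise recursion, assuming unit weight and bias variances), plugged into the multivariate Gaussian log-density. The data, depth, and noise jitter are stand-ins, not the notebook's actual values.

```python
import numpy as np

def relu_ntk(x1, x2, depth=3):
    """Analytic infinite-width NTK for a fully-connected ReLU network.

    Inputs are (n, d) arrays; unit weight/bias variances are assumed.
    """
    sig = x1 @ x2.T + 1.0                       # Sigma^1 (with bias term)
    sig11 = np.sum(x1 * x1, axis=1) + 1.0       # Sigma^1(x, x) diagonals
    sig22 = np.sum(x2 * x2, axis=1) + 1.0
    theta = sig.copy()                          # Theta^1 = Sigma^1
    for _ in range(depth - 1):
        norm = np.sqrt(np.outer(sig11, sig22))
        cos = np.clip(sig / norm, -1.0, 1.0)
        ang = np.arccos(cos)
        sig = norm / (2 * np.pi) * (np.sin(ang) + (np.pi - ang) * cos)
        sig_dot = (np.pi - ang) / (2 * np.pi)   # derivative kernel for ReLU
        theta = sig + theta * sig_dot
        sig11 = sig11 / 2.0                     # diagonal recursion (ang = 0)
        sig22 = sig22 / 2.0
    return theta

def gp_log_prob(K, y, noise=1e-3):
    """Log-density of y under N(0, K + noise * I)."""
    n = len(y)
    L = np.linalg.cholesky(K + noise * np.eye(n))
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y))
    return float(-0.5 * y @ alpha - np.log(np.diag(L)).sum()
                 - 0.5 * n * np.log(2 * np.pi))

# Hypothetical stand-in training data.
X = np.array([[-2.0], [-1.0], [0.0], [1.0], [2.0]])
y = np.array([0.5, -0.5, 0.0, 0.5, -0.5])

# Score a candidate test point by the joint prior log-prob of train + test.
def score(x_star, y_star):
    Xa = np.vstack([X, [[x_star]]])
    ya = np.append(y, y_star)
    return gp_log_prob(relu_ntk(Xa, Xa), ya)

# Three candidate y-values at the same x, mirroring the three colored points.
candidates = {name: score(0.5, v)
              for name, v in [("low", 0.0), ("mid", 0.5), ("high", 1.0)]}
```

Whichever candidate has the highest log-prob (lowest negative log-prob) is the one the NTK prior considers most likely at that location.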
Unfortunately, if I push the blue test point even higher, it keeps getting more probable, with the log-prob peaking around y=3:
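There is a useful sanity check on this "slide the test point up and down" procedure: for any jointly Gaussian prior, the test value that maximizes the joint log-prob of train + test is exactly the GP posterior mean at that input. Here is a sketch demonstrating that, with a generic RBF kernel standing in for the NTK (all data and hyperparameters hypothetical):

```python
import numpy as np

def rbf(a, b, ls=1.0):
    """Squared-exponential kernel on 1-d inputs (stand-in for the NTK)."""
    return np.exp(-0.5 * (a[:, None] - b[None, :]) ** 2 / ls ** 2)

# Hypothetical stand-in training data and a test location.
X = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])
y = np.array([0.5, -0.5, 0.0, 0.5, -0.5])
x_star, noise = 0.5, 1e-3

def joint_log_prob(y_star):
    """Joint Gaussian log-density of (train targets, candidate test value)."""
    Xa = np.append(X, x_star)
    ya = np.append(y, y_star)
    K = rbf(Xa, Xa) + noise * np.eye(len(Xa))
    L = np.linalg.cholesky(K)
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, ya))
    return -0.5 * ya @ alpha - np.log(np.diag(L)).sum()

# Scan candidate y-values, as in the experiment above.
ys = np.linspace(-3.0, 3.0, 601)
best = ys[np.argmax([joint_log_prob(v) for v in ys])]

# Closed-form GP posterior mean at x_star -- should match the scan's argmax.
K = rbf(X, X) + noise * np.eye(len(X))
k_star = rbf(np.array([x_star]), X)[0]
post_mean = float(k_star @ np.linalg.solve(K, y))
```

So if the notebook scores a candidate by the joint log-prob of train + test, the y-value where that score peaks is just the NTK-GP posterior mean at that x, rather than a direct prediction of what a single trained network does.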
I’m confused by this. If the NTK predicts y=3 as the most likely value there, why doesn’t the trained neural network’s function have a big spike at that point?
Another test at y=0.7, to see whether the other bump is predicted, gives a really weird result:
The yellow test point is by far the most a priori likely, which just seems wrong, considering the bump in the learned function points in the other direction.
Before publishing this post I’d only tested a few of these, and the results seemed to fit well with what I expected (classic). Now that I’ve tested more test points, I’m deeply confused by these results, and they make me think I’ve misunderstood something in the theory or implementation.