Here’s a synthetic version of this experiment:
Give a model a graph with two distinguished vertices s and t. Train it to estimate the length of the shortest path between them, d(s, t). Do this on graphs of size 1 to N.
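Training data for this first step can be generated with breadth-first search. The sketch below makes several assumptions that the proposal leaves open. Graphs are undirected and unweighted, with vertices 1..n, and "size" means the number of vertices. Graphs are random, with each edge present independently with a hypothetical probability `p`, and they are resampled until t is reachable from s. s and t are drawn independently, so a size-1 graph gives s = t and d(s, t) = 0. The names `bfs_distances`, `sample_example` and `make_dataset` are illustrative.

```python
import random
from collections import deque


def bfs_distances(n, edges, source):
    """Shortest-path lengths (in edges) from source to every reachable vertex
    of an undirected, unweighted graph on vertices 1..n."""
    adj = {v: [] for v in range(1, n + 1)}
    for a, b in edges:
        adj[a].append(b)
        adj[b].append(a)
    dist = {source: 0}
    queue = deque([source])
    while queue:
        v = queue.popleft()
        for w in adj[v]:
            if w not in dist:  # first visit is along a shortest path
                dist[w] = dist[v] + 1
                queue.append(w)
    return dist


def sample_example(n, p=0.3, rng=random):
    """One (edges, s, t, d(s, t)) example on n vertices.
    Hypothetical choices: G(n, p) random graphs, resampled until t is reachable."""
    while True:
        edges = [(a, b) for a in range(1, n + 1)
                 for b in range(a + 1, n + 1) if rng.random() < p]
        s, t = rng.randint(1, n), rng.randint(1, n)
        dist = bfs_distances(n, edges, s)
        if t in dist:
            return edges, s, t, dist[t]


def make_dataset(N, per_size, seed=0):
    """per_size examples for every graph size 1..N (step one of the experiment)."""
    rng = random.Random(seed)
    return [sample_example(n, rng=rng)
            for n in range(1, N + 1) for _ in range(per_size)]
```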
Fine-tune the model to output d(s, u) for an arbitrary input vertex u on the unique shortest path from s to t. Hopefully this is much faster than the original training. Do this only for graphs of size 1 to n, where n << N.
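Fine-tuning labels need the extra condition that the shortest s–t path is unique, so that "the" path is well defined. A single BFS can count shortest paths, using count[w] = the sum of count[v] over neighbors v one level closer to s. The path is unique exactly when count[t] = 1. The sketch below assumes the same graph conventions as before. The function name `shortest_path_labels` is hypothetical. It returns the (u, d(s, u)) targets, or `None` when the graph should be skipped.

```python
from collections import deque


def shortest_path_labels(n, edges, s, t):
    """If the shortest s-t path is unique, return [(u, d(s, u)), ...] for every
    vertex u on it (s first, t last); otherwise return None.
    Assumes an undirected, unweighted graph on vertices 1..n without duplicate edges."""
    adj = {v: [] for v in range(1, n + 1)}
    for a, b in edges:
        adj[a].append(b)
        adj[b].append(a)
    dist, count = {s: 0}, {s: 1}  # count[v] = number of shortest s-v paths
    queue = deque([s])
    while queue:
        v = queue.popleft()  # BFS order: count[v] is final when v is popped
        for w in adj[v]:
            if w not in dist:
                dist[w], count[w] = dist[v] + 1, 0
                queue.append(w)
            if dist[w] == dist[v] + 1:
                count[w] += count[v]
    if t not in dist or count[t] != 1:
        return None  # t unreachable, or several shortest paths tie
    # Walk back from t; uniqueness means exactly one predecessor at each step.
    path = [t]
    while path[-1] != s:
        v = path[-1]
        path.append(next(w for w in adj[v] if dist[w] == dist[v] - 1))
    path.reverse()
    return [(u, dist[u]) for u in path]
```

Each returned pair becomes one fine-tuning example: the model sees the graph, s and t, is queried with u, and is trained to output d(s, u).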
Check whether the d(s, u) head generalizes to larger graphs (sizes n+1 to N). If it doesn't, try to understand what it does instead, and maybe experiment with simple consistency conditions (for example, that the d(s, u) head agrees with the d(s, t) output when u = t).
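Consistency conditions can be made concrete as identities that any correct pair of heads must satisfy on a graph with a unique shortest path:

- d(s, s) = 0.
- The d(s, u) head evaluated at u = t agrees with the d(s, t) head.
- The prediction increases by exactly 1 at each step along the path.

The sketch below assumes the predictions are plain numbers and uses a hypothetical tolerance `tol`. The function name and argument layout are illustrative. Running it on held-out sizes n+1 to N would show which conditions break first.

```python
def consistency_violations(path, pred_u, pred_st, tol=0.5):
    """Check simple consistency conditions on a model's outputs.

    path    -- vertices of the unique shortest s-t path, s first, t last
    pred_u  -- dict mapping each u on the path to the d(s, u) head's output
    pred_st -- the d(s, t) head's output for the same graph
    Returns a list of human-readable violations (empty if all conditions hold).
    """
    s, t = path[0], path[-1]
    problems = []
    # Condition 1: the distance from s to itself is zero.
    if abs(pred_u[s]) > tol:
        problems.append(f"d(s, s) predicted as {pred_u[s]}, expected 0")
    # Condition 2: the two heads agree at u = t.
    if abs(pred_u[t] - pred_st) > tol:
        problems.append(f"d(s, u) head gives {pred_u[t]} at u = t, "
                        f"d(s, t) head gives {pred_st}")
    # Condition 3: each step along the path adds exactly one edge.
    for a, b in zip(path, path[1:]):
        step = pred_u[b] - pred_u[a]
        if abs(step - 1) > tol:
            problems.append(f"step {a}->{b} changes prediction by {step}, expected 1")
    return problems
```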
(Shortest path was somewhat arbitrary; there are probably better tasks. The key thing is that the task has intermediate results that you can train the model to output.)
(The result may depend on architecture. I’m imagining something like: giving each vertex a unique id from [1...N], presenting the graph in edge-list format, and using an encoder-decoder architecture where the d(s, u) head is given u as input and can attend to all of the encoder's outputs over the edge list.)
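The input format described above can be sketched as a tokenizer for the encoder.

The vocabulary layout is a hypothetical choice:

- 1..N are vertex ids.
- N+1 marks s.
- N+2 marks t.
- N+3 separates edges.
- 0 is left free for padding.

One further assumption is worth flagging. A size-n training graph gets n distinct ids drawn at random from [1...N], rather than always 1..n. That way, small training graphs still exercise every id embedding the model will need on larger graphs. The name `encode_example` is illustrative.

```python
import random


def encode_example(edges, s, t, n, N, rng=random):
    """Encoder tokens for a graph on vertices 1..n, plus the id relabeling.

    Hypothetical vocabulary: 1..N vertex ids, N+1 marks s, N+2 marks t,
    N+3 separates edges, 0 reserved for padding.
    """
    assert 1 <= n <= N
    ids = rng.sample(range(1, N + 1), n)  # n distinct ids from [1..N]
    relabel = {v: ids[v - 1] for v in range(1, n + 1)}
    S_TOK, T_TOK, SEP = N + 1, N + 2, N + 3
    tokens = [S_TOK, relabel[s], T_TOK, relabel[t]]
    for a, b in edges:
        tokens += [SEP, relabel[a], relabel[b]]  # one edge-list entry
    return tokens, relabel
```

The d(s, u) head would then receive `relabel[u]` as its decoder input and attend over the encoder's outputs for these tokens.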