N = #params, D = #training tokens
Training compute = const. * N * D
Forward pass: costs c * N compute and reveals R bits of information about the weights; assume R = Ω(1) on average
Now, thinking purely information-theoretically: recovering the 16 * N bits of fp16 weights at R bits per query takes 16 * N / R forward passes, each costing c * N, so
Model stealing compute = (c * N) * (16 * N / R) ~ const. * c * N^2
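As a sanity check, a minimal sketch of this formula in Python. The constants are assumptions, not from the source: c = 2 FLOPs per parameter per token for a forward pass, and R = 1 bit per query.

```python
def stealing_compute(n_params: float, r_bits: float = 1.0, c: float = 2.0) -> float:
    """FLOPs to extract all 16 * N fp16 weight bits at R bits per query."""
    queries = 16 * n_params / r_bits   # forward passes needed
    return queries * c * n_params      # each forward pass costs c * N

# Example: a hypothetical 70B-parameter model.
N = 70e9
print(f"{stealing_compute(N):.1e} FLOPs")  # ~1.6e23, quadratic in N
```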
If training is compute-optimal and α = β in the Chinchilla scaling law, then D ~ N, so Training compute ~ N^2 and:
Model stealing compute ~ Training compute
For significantly overtrained models (D >> N):
Model stealing compute << Training compute
Typically:
Total inference compute ~ Training compute
=> Model stealing compute << Total inference compute
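To make the comparison concrete, a sketch under common rules of thumb (assumptions, not from the source): training compute ≈ 6 * N * D, a forward pass ≈ 2 * N FLOPs, R = 1 bit per query, and D ≈ 20 * N at the Chinchilla-optimal point.

```python
def stealing_over_training(n_params: float, n_tokens: float,
                           r_bits: float = 1.0, c: float = 2.0) -> float:
    """Ratio of model-stealing compute to training compute."""
    stealing = c * n_params * 16 * n_params / r_bits  # (c*N) * (16*N/R)
    training = 6 * n_params * n_tokens                # ~6 FLOPs per param per token
    return stealing / training

N = 70e9
print(stealing_over_training(N, 20 * N))    # compute-optimal: ~0.27, same order
print(stealing_over_training(N, 300 * N))   # heavily overtrained: ~0.02, << 1
```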
Caveats:
- A prior on the weights reduces the stealing compute; the same holds if you only want to recover partial information about the model (e.g. enough to create an equally capable one)
- Of course, if the model produces much fewer than one token per forward pass, then the model stealing compute becomes very large (both effects are sketched after this list)
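Both caveats amount to rescaling the same formula: a prior (or a weaker recovery goal) shrinks the bits to extract per parameter, while hidden forward passes shrink the tokens observed per unit of compute. A sketch where all parameter names and values are illustrative assumptions:

```python
def stealing_compute(n_params: float, bits_per_param: float = 16.0,
                     r_bits: float = 1.0, tokens_per_fwd: float = 1.0,
                     c: float = 2.0) -> float:
    """FLOPs to extract bits_per_param * N weight bits at R bits per output token."""
    tokens_needed = bits_per_param * n_params / r_bits  # output tokens to observe
    forward_passes = tokens_needed / tokens_per_fwd     # hidden compute inflates this
    return forward_passes * c * n_params

N = 70e9
print(stealing_compute(N))                          # baseline
print(stealing_compute(N, bits_per_param=1.0))      # strong prior: 16x cheaper
print(stealing_compute(N, tokens_per_fwd=0.1))      # 10 fwd passes per token: 10x costlier
```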