Can I ask what you used to implement the MOE routing? Did you use megablocks? I would love to expand on this research but I can’t find any straightforward implementation of efficient pytorch MOE routing online.
Do you simply iterate over each max probability expert every time you feed in a batch?
wait a minute… could you just...
you don’t just literally do this do you?
This must in some way be horrifically inefficient, right?
Just to close the loop on this one: the official Hugging Face transformers library just uses a for-loop over experts to implement MoE. I also implemented a version myself using a for loop, and it's much more efficient than either vanilla matrix multiplication or that weird batch matmul I wrote up above, at large latent and batch sizes.
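For anyone landing here later, the for-loop pattern is roughly the one below: loop over *experts* (not tokens), mask out the tokens routed to each expert, run them through that expert in one batched matmul, and scatter the results back. This is a minimal NumPy sketch of the idea, not the actual transformers code; all names (`moe_forward`, `gate_w`, `expert_ws`) are made up for illustration, and it does top-1 routing with plain weight matrices as "experts".

```python
import numpy as np

def moe_forward(x, gate_w, expert_ws):
    """For-loop MoE dispatch sketch (top-1 routing).

    x:         (num_tokens, d) token activations
    gate_w:    (d, num_experts) router weights
    expert_ws: list of (d, d) per-expert weight matrices
    """
    # Router: softmax over expert logits, then pick the max-probability expert.
    logits = x @ gate_w
    probs = np.exp(logits - logits.max(-1, keepdims=True))
    probs /= probs.sum(-1, keepdims=True)
    expert_idx = probs.argmax(-1)              # (num_tokens,) top-1 expert ids

    out = np.zeros_like(x)
    # The "inefficient-looking" loop: one iteration per expert, not per token.
    for e, w in enumerate(expert_ws):
        mask = expert_idx == e                 # tokens routed to expert e
        if mask.any():
            # One dense matmul over just this expert's tokens,
            # scaled by the router probability, scattered back in place.
            out[mask] = (x[mask] @ w) * probs[mask, e][:, None]
    return out
```

The key point is that the loop length is the number of experts (typically 8-64), while each iteration is a dense matmul over a contiguous slice of tokens, so for large batches almost all the work stays in efficient GEMMs.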