Thanks for the nice tutorial.
I have a problem understanding your code (I am new to PyTorch). In the part where you calculate the attention activations:
def forward(self, *args, **kwargs):
    output = self.attn(*args, **kwargs)
    if self.add_tensor is not None:
        output = (output[0] + self.add_tensor,) + output[1:]
    self.activations = output[0]
    return output
What is the argument that is passed to the self.attn function?
I tried passing the following, but cannot reproduce your results:
model.layers.layers[0].self_attn(past_key_values[0][0].reshape(1, 10, 32 * 128))[0]
model.model.embed_tokens(inputs.input_ids.to(device))
Neither of these can reproduce your results. Can you clarify this?
The wrapper modules simply wrap existing submodules of the model: they call whatever they are wrapping (in this case self.attn) with the same arguments, then save some state and/or manipulate the output. It's just the syntax I chose so that I can both save state from submodules and modify some intermediate values. If you want to see exactly how that submodule is being called, you can look at the Llama source code in Hugging Face Transformers. In the code you gave, I am adding a vector to the hidden_states returned by that attention submodule.
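To make that concrete, here is a minimal sketch of such a wrapper. The forward method matches the snippet quoted above; the constructor and the usage comment are illustrative assumptions rather than the exact code from the post:

import torch.nn as nn

class AttnWrapper(nn.Module):
    # Wraps an existing attention submodule, saves its output, and can
    # optionally add a vector to the hidden_states it returns.
    def __init__(self, attn):
        super().__init__()
        self.attn = attn          # the original layer's self_attn module
        self.add_tensor = None    # set to a tensor to modify the output
        self.activations = None   # hidden_states saved from the last forward

    def forward(self, *args, **kwargs):
        output = self.attn(*args, **kwargs)  # call the wrapped module unchanged
        if self.add_tensor is not None:
            # output[0] is hidden_states; keep the rest of the output tuple intact
            output = (output[0] + self.add_tensor,) + output[1:]
        self.activations = output[0]
        return output

# Illustrative usage: swap a layer's attention for the wrapped version
# model.model.layers[0].self_attn = AttnWrapper(model.model.layers[0].self_attn)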
Thanks, Nina, for sharing the Hugging Face forward pass. I now realize I was skipping the input layer norm calculations. Now I can reproduce your numbers :)
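For anyone else trying to reproduce the layer-0 attention activations by hand, a rough sketch of the fix described above (apply the decoder layer's input_layernorm before calling self_attn) could look like this. The exact self_attn signature varies across transformers versions (newer releases expect rotary position embeddings instead of position_ids), so treat it as an illustration rather than a drop-in snippet:

import torch

with torch.no_grad():
    hidden_states = model.model.embed_tokens(inputs.input_ids.to(device))
    layer0 = model.model.layers[0]
    # The step that was being skipped: each decoder layer normalizes its
    # input before the attention submodule sees it.
    normed = layer0.input_layernorm(hidden_states)
    # position_ids kwarg assumes an older transformers LlamaAttention signature
    position_ids = torch.arange(normed.shape[1], device=device).unsqueeze(0)
    attn_hidden = layer0.self_attn(normed, position_ids=position_ids)[0]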