Looks like you use gradient magnitude as your saliency score. I’ve looked at using saliency to guide counterfactual modifications to a text, though my focus was on aiding interpretability rather than adversarial robustness. (Paper).
I’ve found that the normgrad saliency score works well for highlighting important tokens, i.e., saliency = torch.sum(torch.pow(embedding * gradient, 2)).
For more details on normgrad, see: https://arxiv.org/pdf/2004.02866.pdf
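In case it helps, here is a minimal sketch of that score computed per token. It assumes the embeddings have shape (num_tokens, embedding_dim) and that the sum runs over the embedding dimension (dim=-1); that reduction is my assumption, since the expression above sums over everything. The function name normgrad_saliency, the toy values, and the stand-in loss are all hypothetical, not taken from either codebase.

```python
import torch


def normgrad_saliency(embeddings: torch.Tensor, gradients: torch.Tensor) -> torch.Tensor:
    """Per-token score: sum over the embedding dimension of (embedding * gradient)^2."""
    # Assumption: the last dimension is the embedding dimension.
    return torch.sum(torch.pow(embeddings * gradients, 2), dim=-1)


# Toy input: 3 tokens with 2-dimensional embeddings (made-up values).
embeddings = torch.tensor(
    [[1.0, 2.0], [0.5, 0.0], [3.0, 1.0]], requires_grad=True
)

# Stand-in for a real model loss: a fixed linear function of the embeddings,
# so the gradient w.r.t. the embeddings is just `weights`.
weights = torch.tensor([[0.1, -0.2], [0.4, 0.3], [0.0, 1.0]])
loss = (embeddings * weights).sum()
loss.backward()

# One saliency value per token; higher means more influential under this score.
saliency = normgrad_saliency(embeddings.detach(), embeddings.grad)
most_salient_token = int(saliency.argmax())
```

With a real model you would take the gradient of the loss (or the target logit) with respect to the input embeddings instead of the toy linear loss used here.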
Neat! We tried one or two other saliency scores but there’s definitely a lot more experimentation to be done.