Hard-Coding Neural Computation
Previously: Teaser: Hard-coding Transformer Models
Introduction
Transformer models are incredibly powerful for natural language tasks (and they are starting to find uses in many other fields of machine learning). Unfortunately, it is nigh-impossible to interpret what goes on inside them. OR IS IT???
I have found that I can, with a fair amount of effort, hard-code the weights of a transformer model in order to perform some very crude versions of linguistic tasks. So far I have achieved English-to-French translation (on a toy corpus of about 150 sentences), text classification (is a sentence grammatical or not? on a toy corpus of a couple hundred sentences), and sentiment analysis (again on a limited corpus). These results are obviously not impressive compared to the state of the machine learning field, but I am pretty sure that they can all be drastically scaled up with the investment of some time and energy. Unfortunately, I have a fairly demanding day job, and haven’t found the time and energy yet.
All of this is done by inspection (no gradient descent!). The process is a lot like programming, although it is more difficult than programming, at least right now for me. I am fairly certain that better tools and better notation can be developed to make the process easier. It is also almost certainly possible to combine hard-coding with gradient descent approaches to be able to scale these methods up in a slightly less labor-intensive way.
I think that these ideas could prove useful in alignment research—if we understand how a language model works in excruciating detail, it seems drastically more likely that we will be able to reason about and predict various misunderstandings rooted in the ambiguity of language. Given that language is (arguably) a fully general means of interacting with an artificial intelligence, it seems plausible to me that this work is on the critical path to alignment.
Doneness Status
This post is a work-in-progress. I will be editing it as I go, mostly appending more content to the end, but I will also try to fix any errors or unclear parts as I notice them or commenters point them out.
So let’s hard-code some neural computation! I have a very, very messy GitHub repository where I’ve done my initial experiments, if you prefer to just jump into semi-working code. Otherwise, I will do my very best to explain the ideas from scratch in this post. I’m aiming this post at anyone who’s willing to put in the work to understand it, so I’ll try to at least give pointers to necessary background material, of which there is a fair amount.
What Can We Do Already?
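Some primitive sentiment analysis using a vanilla RNN:
[Figure: sentiment scores produced by the hard-coded vanilla RNN on example sentences]
Some very simple translation with a Transformer model:
[Figure: example translations produced by the hard-coded Transformer model]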
How Should We Measure Success?
Before we explain how we get results, it seems worthwhile to talk about how to measure the performance of such a system. As in traditional machine learning, this can only be measured with respect to some dataset of inputs labeled with desired outputs. For a crude metric, we can look at the fraction of inputs that receive the desired output, versus the fraction that receive some other output. We can also look at more complex metrics, such as BLEU or ROUGE for sequence-to-sequence tasks.
In traditional machine learning, performance can only be (meaningfully) measured on a holdout set that was not used to train the algorithm. This is because performance tends to be much, much higher (if you’re using the right architecture and hyperparameters for your task) on the training set (the set of data that was used to train the algorithm) than it will be for the test set (the set of data that has been held out). The whole purpose and challenge of machine learning, of course, is to build models that generalize to unseen data.
A similar phenomenon occurs in this work: data that the programmer has examined, run through the algorithm, and used to inform updates to the rules will typically be data that the resulting network does disproportionately well on. After all, if you miss some edge case, but then see it in your testing, you have the opportunity to fix it.
On the other hand, the programmer will presumably have a native command of at least one language, so it is at least possible for the programmer to anticipate some phenomena before seeing them in the “training data”. Thus, it seems unfair to gradient descent and deep learning to compare accuracies in the low-data regime where I have been stuck so far by my lack of free time.
The ultimate ambition of this work would be to go toe-to-toe with a comparably-sized Transformer model trained in the traditional way on a modern-sized data set. This might require several people-years of focused effort though.
Some Useful Notation
The first thing we are going to do is introduce some very unconventional notation for vectors and matrices. (We won’t need any more information about linear algebra than is contained in the Wikipedia article, but we will assume that you are either familiar with them or have paused and read that article.)
We will pick a set of “axes” that we will call “semes”. (This word comes from semiotics, as do a few other terms we will use. I believe I’m using them in a way that is compatible with their technical meaning in semiotics, but feel free to think of this as a nonsense word that we are coining.) Each seme will be identified with a short string, often a word. So, we might have semes “wombat”, “peregrine”, and “pig”. These play a role very similar to variable names in traditional programming, so we will generally choose them to be meaningful. Common semes that I actually use are “noun”, “verb”, etc.
We then will write vectors using these semes, for example $\text{pig} - \text{wombat}$ for the vector that is 1 in the pig direction and −1 in the wombat direction. We can also use coefficients, so that $2.1\,\text{pig} - 3.2\,\text{wombat}$ denotes the vector that is 2.1 in the pig direction and −3.2 in the wombat direction. There are two ways of thinking about this. You can either think of the various semes as being completely orthogonal to each other, forming an orthonormal basis of whatever vector space we are in, or you can think of them as arbitrary vectors that we are using as a (possibly overcomplete) basis. In general, both will be useful; I generally think of syntactic information as being best represented in a fully orthonormal basis, while semantic information makes much more sense as being drawn from a very overcomplete basis.
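Matrices will be written in the form $1.1\,\text{pig}\to\text{wombat} + 2.3\,\text{wombat}\to\text{pig} - 4.5\,\text{pig}\to\text{peregrine} + 0.9\,\text{peregrine}\to\text{peregrine}$ for a matrix that would be conventionally represented (rows and columns in the order pig, wombat, peregrine) as
$$\begin{pmatrix} 0 & 1.1 & -4.5 \\ 2.3 & 0 & 0 \\ 0 & 0 & 0.9 \end{pmatrix}.$$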
In code, we will write them like this:
vec1: 2.1 pig -3.2 wombat
mat1: 1.1 pig>wombat +2.3 wombat>pig -4.5 pig>peregrine + 0.9 peregrine>peregrine
for the above vector and matrix.
As with vectors, there are two ways to think of the matrix notation. In the first way, the semes form an orthonormal basis, and we are just using them to identify which pairs of coordinates get which coefficient. But we can also think of $1.1\,\text{pig}\to\text{wombat}$ as being 1.1 times the outer product of pig and wombat. This second view will not be necessary for the contents of this post, but it is necessary to understand some of the ways I envision being able to combine this work with gradient descent-based learning.
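It is also worth pointing out that, if we multiply matrices and vectors with the vector on the left, then $\text{pig}\,(\text{pig}\to\text{wombat}) = \text{wombat}$; that is, the matrix actually maps the vector pig to the vector wombat. (Although, potentially confusingly, $(\text{pig}\to\text{wombat})\,\text{wombat} = \text{pig}$, so with the vector on the right the same matrix maps wombat to pig.) For this reason, we will prefer left-multiplication in our neural networks later, because it makes this particular notation way easier to think with.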
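The following Python sketch makes the notation concrete. It turns seme strings like the vec1 and mat1 lines into numpy arrays. The helpers parse_terms, to_vector, and to_matrix are hypothetical and are not taken from the post's repository. The parser assumes that seme names start with a letter, so semes such as 1st or 3rd would need a different tokenizer. The sketch checks two things described above: the difference between left- and right-multiplication, and the outer-product reading of a single matrix term.

```python
import re
import numpy as np

# Hypothetical parser for the seme notation (not the author's code).
# Each term is an optional sign, an optional coefficient, then a seme name,
# optionally followed by ">name" for matrix terms. Seme names must start with a letter.
TERM = re.compile(r"([+-]?)\s*(\d+(?:\.\d+)?)?\s*([A-Za-z_]\w*)(?:>([A-Za-z_]\w*))?")

def parse_terms(text):
    """Yield (coefficient, seme) for vector terms or (coefficient, (row, col)) for matrix terms."""
    for sign, number, first, second in TERM.findall(text):
        coef = float(number) if number else 1.0
        if sign == "-":
            coef = -coef
        yield coef, ((first, second) if second else first)

def to_vector(text, semes):
    index = {s: i for i, s in enumerate(semes)}
    v = np.zeros(len(semes))
    for coef, seme in parse_terms(text):
        v[index[seme]] += coef
    return v

def to_matrix(text, semes):
    # "row>col" sets entry [row, col], so with the vector on the left (v @ M)
    # the seme `row` is sent to the seme `col`.
    index = {s: i for i, s in enumerate(semes)}
    M = np.zeros((len(semes), len(semes)))
    for coef, (row, col) in parse_terms(text):
        M[index[row], index[col]] += coef
    return M

semes = ["pig", "wombat", "peregrine"]
vec1 = to_vector("2.1 pig -3.2 wombat", semes)
mat1 = to_matrix("1.1 pig>wombat +2.3 wombat>pig -4.5 pig>peregrine + 0.9 peregrine>peregrine", semes)
pig, wombat = to_vector("pig", semes), to_vector("wombat", semes)

print(vec1)           # 2.1 along pig, -3.2 along wombat
print(pig @ mat1)     # vector on the left: pig goes to 1.1 wombat - 4.5 peregrine
print(mat1 @ wombat)  # vector on the right: wombat goes to 1.1 pig
```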
Tokenization and Word Embeddings
In deep NLP, the first couple steps are about getting rid of words and replacing them with inputs that can actually be understood by a deep network. The first step is to take a string and break it into some number of discrete chunks called “tokens”. In principle, we could feed things in letter-by-letter, and people have gotten semi-decent results doing that in the past, but it’s a lot less labor-intensive in this context to use full words as the unit of tokenization. This is actually a mild break from most Transformer models used today, which generally make use of a “subword vocabulary” which contains a mixture of whole words and parts of words like “ing” or “particul”.
Let’s take an example sentence and tokenize it, just to be sure that we understand this process. Consider
The rain in Spain is mainly on the plain, while treefuls of weevils are gleefully evil.
["The", "rain", "in", "Spain", "is", "mainly", "on", "the", "plain", ",", "while", "treefuls", "of", "weevils", "are", "gleefully", "evil", "."]
Some things worth emphasizing:
We don’t use tokens for whitespace (spaces, tabs, etc.)
Punctuation such as commas and periods will get a token of its own
Additionally, we will case-normalize our inputs by making everything lower-case. This cuts down on some repetitive work and is relatively common with deep models. We also include a special SOS (start of sentence) token and an EOS (end of sentence) token. So the above example should really look like this:
The rain in Spain is mainly on the plain, while treefuls of weevils are gleefully evil.
["SOS", "the", "rain", "in", "spain", "is", "mainly", "on", "the", "plain", ",", "while", "treefuls", "of", "weevils", "are", "gleefully", "evil", ".", "EOS"]
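The tokenization step can be sketched minimally with a simple regular-expression tokenizer (the function name tokenize is hypothetical). It lower-cases the input, treats every run of word characters as a token, gives each other non-space character its own token, and adds the SOS and EOS markers. It reproduces the list above. Note that it would split contractions like “I’ll” and runs like “!!!”, which the sentiment program later in the post keeps as single, case-sensitive tokens.

```python
import re

def tokenize(text):
    """Lower-case the text, split off punctuation, and wrap the result in SOS/EOS markers.

    Words are runs of letters, digits, or underscores; every other non-whitespace
    character becomes its own token. Whitespace produces no tokens.
    """
    return ["SOS"] + re.findall(r"\w+|[^\w\s]", text.lower()) + ["EOS"]

tokens = tokenize("The rain in Spain is mainly on the plain, while treefuls of weevils are gleefully evil.")
print(tokens)
```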
The second step in transforming a string of text into the sort of inputs that a deep network prefers is to do a “word embedding lookup”. Here, each token is replaced by a fixed vector, so that we get a matrix of shape [num_tokens, word_embedding_dim]. Because the first axis (the “sequence dimension”) is not semantically the same as the second axis (the “embedding dimension”), we will not use our special matrix notation, but will instead think of this as a list of vectors, one for each token.
So let’s look at some word embeddings! Here are some pronouns (note that we’re describing here a fragment of a flavor of English that includes the gender-neutral singular “they” in addition to the plural “they”):
i: +nom +sg +1st +pro
you: +nom +sg +2nd +pro
he: +masc +nom +sg +3rd +pro
she: +fem +nom +sg +3rd +pro
it: +neut +sg +3rd +pro +expletive
me: +acc +sg +1st +pro
you: +sg +pl +2nd +pro
him: +masc +acc +sg +3rd +pro
we: +nom +pl +1st +pro
they: +enby +nom +sg +pl +3rd +pro
them: +enby +acc +sg +pl +3rd +pro
us: +acc +pl +1st +pro
them: +acc +pl +3rd +pro
my: +gen +sg +1st +pro
our: +gen +pl +1st +pro
his: +masc +gen +sg +3rd +pro
her: +fem +gen +acc +sg +3rd +pro
its: +neut +gen +sg +3rd +pro
their: +enby +gen +sg +pl +3rd +pro
myself: +1st +reflexive +sg +pro
ourselves: +1st +reflexive +pl +pro
yourself: +2nd +reflexive +sg +pro
yourselves: +2nd +reflexive +pl +pro
himself: +3rd +reflexive +sg +masc +pro
herself: +3rd +reflexive +sg +fem +pro
itself: +3rd +reflexive +sg +neut +pro
themselves: +3rd +reflexive +pl +enby +pro
oneself: +3rd +reflexive +sg +pro
Here are some verbs:
is: +be +verb +3rdsg +copula
be: +be +verb +plain +copula
was: +be +verb +preterite +copula +helper
did: +do +helper +verb +preterite +agentlack +themeposs
do: +do +helper +verb +plain +agentlack +themeposs
does: +do +helper +verb +3rdsg +agentlack +themeposs
have: +have +plain +helper +verb +agentposs +themeposs
has: +have +3rdsg +helper +verb +agentposs +themeposs
can: +can +plain +helper +modal
could: +can +preterite +helper +modal
may: +may +plain +3rdsg +helper +modal
might: +may +helper +modal
must: +must +plain +helper +modal
shall: +shall +plain +helper +modal
should: +shall +preterite +helper +modal
will: +will +plain +3rdsg +helper +modal
would: +will +preterite +helper +modal
ought: +ought +modal +helper +modal
dare: +dare +modal +helper +modal
accuse: +accuse +verb +plain +agentlack +themelack
accused: +accuse +verb +preterite +agentlack +themelack
accuses: +accuse +verb +3rdsg +agentlack +themelack
appear: +appear +verb +plain +agentlack +complementposs
appeared: +appear +verb +preterite +agentlack +complementposs
appears: +appear +verb +3rdsg +agentlack +complementposs
ate: +eat +verb +preterite +agentlack +patientposs
beam: +beam +verb +plain +agentlack
beamed: +beam +verb +preterite +agentlack
beams: +beam +verb +3rdsg +agentlack
bend: +bend +verb +plain +agentlack +patientposs
bent: +bend +verb +preterite +agentlack +patientposs
bends: +bend +verb +3rdsg +agentlack +patientposs
bled: +bleed +verb +preterite +agentlack +patientposs
bleed: +bleed +verb +plain +agentlack +patientposs
bleeds: +bleed +verb +3rdsg +agentlack +patientposs
blew: +blow +verb +preterite +agentlack +patientposs
blow: +blow +verb +plain +agentlack +patientposs
blows: +blow +verb +3rdsg +agentlack +patientposs
braid: +braid +verb +plain +agentlack +patientlack
braided: +braid +verb +preterite +agentlack +patientlack
braids: +braid +verb +3rdsg +agentlack +patientlack
breathe: +breathe +verb +plain +agentlack
breathed: +breathe +verb +preterite +agentlack
breathes: +breathe +verb +3rdsg +agentlack
break: +break +verb +plain +agentlack
breaks: +break +verb +3rdsg +agentlack
broke: +break +verb +preterite +agentlack
brush: +brush +verb +plain +agentlack +patientlack
brushed: +brush +verb +preterite +agentlack +patientlack
brushes: +brush +verb +3rdsg +agentlack +patientlack
carve: +carve +verb +plain +agentlack +patientposs
carved: +carve +verb +preterite +agentlack +patientposs
carves: +carve +verb +3rdsg +agentlack +patientposs
chase: +chase +verb +plain +agentlack +patientlack
chased: +chase +verb +preterite +agentlack +patientlack
chases: +chase +verb +3rdsg +agentlack +patientlack
chuckle: +chuckle +verb +plain +agentlack
chuckled: +chuckle +verb +preterite +agentlack
chuckles: +chuckle +verb +3rdsg +agentlack
came: +come +verb +preterite +agentlack
come: +come +verb +plain +agentlack
comes: +come +verb +3rdsg +agentlack
cook: +cook +verb +plain +agentlack +patientposs
cooked: +cook +verb +preterite +agentlack +patientposs
cooks: +cook +verb +3rdsg +agentlack +patientposs
cough: +cough +verb +plain +agentlack +patientposs
coughed: +cough +verb +preterite +agentlack +patientposs
coughs: +cough +verb +3rdsg +agentlack +patientposs
cried: +cry +verb +preterite +agentlack
cries: +cry +verb +3rdsg +agentlack
cry: +cry +verb +plain +agentlack
cut: +cut +verb +plain +preterite +agentlack +patientlack
cuts: +cut +verb +3rdsg +agentlack +patientlack
You’ll note something strange about verbs (and nouns and other “content” words): they almost all have one seme that’s just themselves again! What gives? These are semantic semes, which are much harder to reason about than the other semes, which are syntactic. As we said earlier, syntactic semes should be thought of as an orthonormal basis of whatever size, but semantic semes are more usefully thought of as living in a low-dimensional space, where they aren’t mutually orthogonal. (But! All semantic semes should be thought of as perfectly orthogonal to all syntactic semes, and vice versa.) For the time being, I restrict myself to a relatively limited vocabulary and just use the semantic semes as if they were orthogonal. For grammaticality classification, which is the domain I have worked hardest on, the semantic semes are not particularly relevant. For translation, the only really important thing is that they are able to pick out the corresponding word in the target language. (This assumes there is a straightforward single-word translation in the target language, which has mostly been the case in the toy examples I have considered thus far, but which in general is not the case.)
Let’s embed a sentence! Consider:
The cat sat on the mat.
This tokenizes to
["SOS", "the", "cat", "sat", "on", "the", "mat", ".", "EOS"]
Which then embeds to
SOS: +sos
the: +det
cat: +cat +sg +noun
sat: +sit +verb +preterite +agentlack
on: +on +prep
the: +det
mat: +mat +sg +noun
.: +punct +period
EOS: +eos
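In mathematical notation, we would write this as the list of vectors
$$\left[\ \text{sos},\ \text{det},\ \text{cat} + \text{sg} + \text{noun},\ \text{sit} + \text{verb} + \text{preterite} + \text{agentlack},\ \text{on} + \text{prep},\ \text{det},\ \text{mat} + \text{sg} + \text{noun},\ \text{punct} + \text{period},\ \text{eos}\ \right]$$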
For further clarity, let’s give a gloss for each seme we’re using:
semes: sos # start of sentence
eos # end of sentence
det # determiner, a linguistic class that contains articles, demonstratives, and some other stuff
cat # meowing animal (semantic seme)
sg # singular in number
noun # nouns, the class of object/concept words
sit # sitting down (semantic seme)
verb # verbs, the class of action words
preterite # one of the past tenses in English
agentlack # to be grammatical, this verb needs an agent
on # the preposition (semantic seme)
prep # prepositions, the class of words denoting relationships
mat # something to sit on (semantic seme)
punct # punctuation
period # specifically this guy: .
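The lookup itself is just a dictionary from tokens to semes. Here is a rough numpy sketch for the example sentence. It assumes, as the post does for now, that every seme is its own orthonormal axis with coefficient +1. LEXICON, SEMES, and embed are illustrative names, not the repository's.

```python
import numpy as np

# Lexicon for "The cat sat on the mat." as given above; each token lists its semes.
LEXICON = {
    "SOS": ["sos"], "the": ["det"], "cat": ["cat", "sg", "noun"],
    "sat": ["sit", "verb", "preterite", "agentlack"], "on": ["on", "prep"],
    "mat": ["mat", "sg", "noun"], ".": ["punct", "period"], "EOS": ["eos"],
}
# Treat every distinct seme as one orthonormal axis.
SEMES = sorted({s for semes in LEXICON.values() for s in semes})
INDEX = {s: i for i, s in enumerate(SEMES)}

def embed(tokens):
    """Word-embedding lookup: returns an array of shape [num_tokens, num_semes]."""
    X = np.zeros((len(tokens), len(SEMES)))
    for t, token in enumerate(tokens):
        for seme in LEXICON[token]:
            X[t, INDEX[seme]] += 1.0  # each listed seme contributes +1 along its axis
    return X

tokens = ["SOS", "the", "cat", "sat", "on", "the", "mat", ".", "EOS"]
X = embed(tokens)
print(X.shape)
for token, row in zip(tokens, X):
    print(token, [SEMES[i] for i in np.flatnonzero(row)])
```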
Sentiment Analysis
Sentiment analysis refers to the task of extracting from a piece of natural language the overall sentiment that the speaker has towards whatever thing they’re talking about. For instance, in a movie or product review, does the author recommend the movie or product? This is generally considered a pretty straightforward task for machine learning algorithms.
An extremely interpretable algorithm for sentiment analysis is given in VADER: A Parsimonious Rule-based Model for Sentiment Analysis of Social Media Text [PDF], by C. Hutto and E. Gilbert (2014). We have implemented a similar algorithm inside of a two-layer vanilla RNN, which we will describe below. However, we wanted to first note that this algorithm is only a crude sketch of VADER, and its shortcomings should not be held against Hutto and Gilbert.
Why a vanilla RNN rather than the OMG SO MUCH BETTER LSTM? Well, vanilla RNNs are significantly simpler to understand, and their disadvantages (vanishing and exploding gradients, primarily) are only really relevant when you’re actually using gradient descent! So let’s do this the easy way and stick to a vanilla RNN.
First, a few preliminaries about the architecture we will be using:
Let $h^\ell_t$ denote the output of layer $\ell$ at time-step $t$. Note that the superscript $\ell$ isn’t an exponent; it’s just a convenient place to put another index. There will be no exponents anywhere in this network; they are all superscripts. (Here time-step just means the index of the token. So time-step 1 will be the first token, time-step 2 will be the second token, and so on.) $h^\ell_0$ will be a special initial state before we read any tokens, and we will set it to be the zero vector. $h^0_t$ will be the output of the word-embedding layer, so it will just be the embedding of the $t$-th token after we look it up.
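We then define the recurrence (for $\ell \ge 1$):
$$h^\ell_t = \sigma\left(h^{\ell-1}_t A^\ell + h^\ell_{t-1} B^\ell + b^\ell\right) + h^{\ell-1}_t I$$
Here $\sigma$ is the good old logistic sigmoid function, and $I$ is the identity matrix; $A^\ell$, $B^\ell$, and $b^\ell$ correspond to the A, B, and bias entries of each rnn_layer in the program below. (Practitioners will note that the use of the identity matrix here is some sort of residual-like connection.)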
We further define a pooling layer, and then a fully-connected or dense layer.
# our set of semes
semes: stop # "stop words", in this case anything not needed for sentiment analysis
positive negative negation contrastive intensifier
lessener intensepunctuation
xa xb xc xd xe ya yb yc yd ye # a bunch of anonymous variables to represent intermediate computations
lexicon:
",": stop
".": stop
"I'll": stop
At: stop
It: stop
The: stop
Today: stop
VADER: stop
a: stop
all: stop
and: stop
are: stop
at: stop
book: stop
by: stop
characters: stop
dialog: stop
get: stop
is: stop
it: stop
of: stop
plot: stop
the: stop
was: stop
FUNNY: positive
GOOD: positive
GREAT: positive
HANDSOME: positive
LOL: positive
SMART: positive
funny: positive
good: positive
great: positive
handsome: positive
lol: positive
smart: positive
SUX: negative
bad: negative
horrible: negative
sux: negative
uncompelling: negative
"!": intensepunctuation
"!!!": intensepunctuation
very: intensifier
VERY: intensifier
uber: intensifier
FRIGGIN: intensifier
only: lessener
kinda: lessener
not: negation
nor: stop
isnt: negation
Not: negation
But: contrastive
but: contrastive
rnn_layer1:
A: positive>xa negative>ya positive>xb negative>yb
negation>negation intensifier>intensifier lessener>lessener
B: intensifier>xa intensifier>ya lessener>xb lessener>yb
negation>negation 0.5intensifier>intensifier 0.5lessener>lessener
bias: -xa -xb -ya -yb
rnn_layer2:
A: xa>xc ya>yc xb>xd yb>yd negation>xc negation>xd negation>yc
negation>yd positive>xe negation>xe negative>ye negation>ye
B: '' # in yaml, which is the formatting language I use to type these programs, you need to do this to specify an empty string, which corresponds to the zero matrix
bias: -xc -xd -yc -yd -xe -ye
dense1:
C: positive>positive negative>negative 2xa>positive 0.25xb>positive
ya>negative 0.25yb>negative xc>negative xd>negative yc>positive
yd>positive -2xc>positive -2xd>positive -2yc>negative
-2yd>negative xe>negative ye>positive -xe>positive -ye>negative
c: ''
examples: # examples modified from https://github.com/cjhutto/vaderSentiment
- VADER is smart , handsome , and funny .
- VADER is smart , handsome , and funny !
- VADER is very smart , handsome , and funny .
- VADER is VERY SMART , handsome , and FUNNY .
- VADER is VERY SMART , handsome , and FUNNY !!!
- VADER is VERY SMART , uber handsome , and FRIGGIN FUNNY !!!
- VADER is not smart , handsome , nor funny .
- The book was good .
- It isnt a horrible book .
- The book was only kinda good .
- The plot was good , but the characters are uncompelling and the dialog is not great .
- Today SUX !
- Today only kinda sux ! But I'll get by , lol
- Not bad at all
This is sufficient to generate the scores at the beginning of this post. The scores on the given examples are not all that inaccurate. Lots more work could obviously be done on this network, and I’d love it if people feel like working on this network or other later-discussed networks in the comments.
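Here is a rough numpy sketch of the architecture, for readers who want to run something. It assumes that the recurrence has the form h_t = σ(x_t A + h_{t-1} B + b) + x_t I with vectors on the left and h_0 = 0. It also assumes sum pooling over time, followed by the dense layer p C + c. The names A, B, b, C, and c are meant to mirror the A, B, bias, C, and c keys in the program above. The pooling choice and the exact placement of the residual term are assumptions, so this sketch will not necessarily reproduce the scores at the top of the post.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def rnn_layer(X, A, B, b):
    """Vanilla RNN layer with a residual-like identity term (row-vector convention).

    ASSUMED form (a sketch, not necessarily the post's exact equation):
        h_t = sigmoid(x_t @ A + h_{t-1} @ B + b) + x_t @ I
    X holds the previous layer's outputs, shape [num_tokens, d]; h_0 is the zero vector.
    """
    H = np.zeros_like(X, dtype=float)
    h_prev = np.zeros(X.shape[1])
    for t in range(X.shape[0]):
        h_prev = sigmoid(X[t] @ A + h_prev @ B + b) + X[t]  # "+ X[t]" is the identity term
        H[t] = h_prev
    return H

def classify(X, layers, C, c):
    """Stack RNN layers, pool over time, then apply the dense layer p @ C + c."""
    H = X
    for A, B, b in layers:
        H = rnn_layer(H, A, B, b)
    pooled = H.sum(axis=0)  # ASSUMED: sum pooling over time steps
    return pooled @ C + c

# Sanity check with all-zero weights: each layer then adds sigmoid(0) = 0.5
# to every coordinate of every token before pooling.
d = 3
X = np.eye(d)[[0, 1]]  # two one-hot "tokens"
zero_layer = (np.zeros((d, d)), np.zeros((d, d)), np.zeros(d))
out = classify(X, [zero_layer, zero_layer], np.eye(d), np.zeros(d))
print(out)
```

To try the actual sentiment program, fill A, B, b, C, and c from the rnn_layer1, rnn_layer2, and dense1 entries using a seme-to-index mapping, and embed each token from the lexicon.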
Transformer Overview
We give here a brief overview of the transformer architecture, for those unfamiliar with it. This will essentially be an accelerated recap of Jay Alammar’s Illustrated Transformer, which I consider to be the best friendly introduction to the architecture.
Transformer in its full sequence-to-sequence glory has an encoder stack and a decoder stack. For text classification purposes, one generally just uses the encoder stack with a few simple layers at the end. The encoder stack is made up of a bunch of Transformer encoder blocks, each of which is the same architecturally, but each of which has its own learnable/settable weights that allow it to specialize and do its own particular task in the grand scheme of the network.
The decoder stack is also made up of a series of architecturally identical Transformer layers, again each with their own learnable/settable weights that allow them to specialize into their own unique role. The decoder layers are similar to the encoder layers, but a little bit more complex.
So now let’s dive inside the Transformer layer and see how they tick!
So we have self-attention layers, encoder-decoder attention layers, and the feed-forward layers. The two types of attention layers are generally considered to be the innovative, “important” part of Transformer, but I have found in trying to hard-code weights for transformer (and years of research by many people have also found) that the feed-forward layers are crucial to being able to learn complex functions. (Well, that’s not entirely true. I think someone in some paper managed to sort of smuggle the feed-forward layer into a computation that looks like self-attention over a learnable set of parameters, without losing any accuracy. But for our purposes the feed-forward layer is important.)
We’ll next dive into the various layers and see how to hard-code their parameters. We’ll start with the easiest layer to understand: the feed-forward layer.
Transformer Feed-Forward Layers
The standard transformer feed-forward layer can be described pretty simply as:
$$\mathrm{FFN}(x) = \mathrm{ReLU}(x W_1 + b_1)\, W_2 + b_2$$
Here ReLU is a new kind of non-linearity, the “rectified linear unit”. $W_1$ and $W_2$ are matrices called “weights”, and $b_1$ and $b_2$ are vectors called “biases”. “Parameters” refers to either weights or biases, although people will often refer to biases as “weights” also.
Like most non-linearities in deep learning, ReLU is an element-wise function, i.e., you apply it to each coordinate of the vector independently. $\mathrm{ReLU}(x) = \max(0, x)$, so if $x$ is negative, $\mathrm{ReLU}(x) = 0$, and otherwise $\mathrm{ReLU}(x) = x$.
So let’s hard-code a feed-forward layer! This just requires picking values for $W_1$, $b_1$, $W_2$, and $b_2$ (mat1, bias1, mat2, and bias2 below).
semes: apple banana cherry durian yum yuck
mat1: apple>apple apple>yum banana>banana banana>yum cherry>yuck durian>yuck
bias1: -yum -yuck
mat2: apple>yum banana>yum -yum>yum yuck>yuck
bias2: '' # a zero vector
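The semantics we are trying to encode here is that apple OR banana should be mapped to yum, while cherry AND durian should be mapped to yuck. This results in the following output:
[Table: output of this feed-forward layer for combinations of input semes]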
You could argue that yuck should displace yum. This can be done, but it seems to require a third layer (I haven’t given it much thought just now, so maybe it’s doable in two layers), so it would have to be split across two different Transformer layers.
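As a check on the logic, here is a small numpy sketch that builds the matrices from the program above and runs the layer on a few inputs. The helper names (vec, mat, ffn, run) are illustrative. The sketch shows only the layer's own output. In a real Transformer block, the residual connection would add the input semes back on top. The last case, where both yum and yuck fire, is the situation the previous paragraph is about.

```python
import numpy as np

SEMES = ["apple", "banana", "cherry", "durian", "yum", "yuck"]
IX = {s: i for i, s in enumerate(SEMES)}

def vec(*terms):
    """Vector from (coefficient, seme) pairs."""
    v = np.zeros(len(SEMES))
    for coef, s in terms:
        v[IX[s]] += coef
    return v

def mat(*terms):
    """Matrix from (coefficient, row, col) triples; (1, "a", "b") is a>b in the post's notation."""
    M = np.zeros((len(SEMES), len(SEMES)))
    for coef, row, col in terms:
        M[IX[row], IX[col]] += coef
    return M

# mat1, bias1, mat2, bias2 from the program above
W1 = mat((1, "apple", "apple"), (1, "apple", "yum"), (1, "banana", "banana"),
         (1, "banana", "yum"), (1, "cherry", "yuck"), (1, "durian", "yuck"))
b1 = vec((-1, "yum"), (-1, "yuck"))
W2 = mat((1, "apple", "yum"), (1, "banana", "yum"), (-1, "yum", "yum"), (1, "yuck", "yuck"))
b2 = np.zeros(len(SEMES))

def ffn(x):
    # ReLU(x W1 + b1) W2 + b2, with the vector on the left
    return np.maximum(0.0, x @ W1 + b1) @ W2 + b2

def run(*names):
    """Feed in the sum of the named semes, each with coefficient 1."""
    return ffn(vec(*[(1, n) for n in names]))

for case in [("apple",), ("banana",), ("apple", "banana"), ("cherry",),
             ("durian",), ("cherry", "durian"), ("apple", "cherry", "durian")]:
    y = run(*case)
    print(" + ".join(case), "->", f"yum={y[IX['yum']]:g}, yuck={y[IX['yuck']]:g}")
```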
Preparing for Self-Attention: Positional Encodings
Before we can really dive into self-attention layers, it is useful to talk about positional encodings. This is a fairly technical aspect of the Transformer architecture, but I’ve found that it can be made fairly interpretable by thinking about it in the right way.
The traditional way to think about positional embeddings is just to read the following code, and then add some handwaving around “relative positional offsets can be encoded as linear combinations of sines and cosines”. This is all correct, but it doesn’t really yield enough understanding (for me at least) to hard-code things around the positional embeddings.
# Code from https://www.tensorflow.org/tutorials/text/transformer
import numpy as np
def get_angles(pos, i, d_model):
    angle_rates = 1 / np.power(10000, (2 * (i//2)) / np.float32(d_model))
    return pos * angle_rates
def positional_encoding(position, d_model):
    angle_rads = get_angles(np.arange(position)[:, np.newaxis],
                            np.arange(d_model)[np.newaxis, :],
                            d_model)
    # apply sin to even indices in the array; 2i
    angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
    # apply cos to odd indices in the array; 2i+1
    angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])
    pos_encoding = angle_rads[np.newaxis, ...]
    return pos_encoding
I prefer to think of the various sines and cosines as the hands of a bunch of clocks running at different speeds. If we have 512 dimensions in our positional embeddings, then there will be 256 clocks. The sine is just the y-coordinate of the clock hand, and the cosine is the x-coordinate of the clock hand.
At time t=0, we look at the first word in the sequence. All clocks point in the same direction.
At time t=1, we look at the second word in the sequence. The slowest clock (all the way on the left) has advanced a tiny bit, while the fastest clock (all the way on the right) has advanced more.
At time t=4, the slowest clock has advanced a decent amount, while the fastest clock has advanced about a quarter rotation from where it started.
Now that we have a more grounded understanding, we can ask questions like, if I have the positional encoding of “very”, how do I get the positional encoding of the next word? Well, the next word is one time-step further, so each clock should advance by the amount that that particular clock advances over one time step. We can “point” to that positional embedding by using a particular angle offset for each clock, which then translates into the specific linear combination of sines and cosines referenced earlier. Thus, I have an intentionally quite redundant way to refer to a specific number of time steps in the future.
The different speeds also provide me the ability to point to ranges of time-steps relatively easily. If I want to point to a very narrow range, I can use the fastest clock, which will pick out a very specific time-step, with little room for error. If I want to refer to a broader time-range, I can use a slower clock, which will have the effect that my pointing will be somewhat evenly spread over a wide range of time-steps. Since I have so many clocks to choose from, I can be quite precise in what sort of time interval I point at.
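Finally, as a technical note, the range of clock speeds is chosen so that the slowest clocks never complete a full loop over any realistic sequence length. (With d_model = 512 in the code above, the slowest clock needs on the order of 60,000 positions for one rotation.) The faster clocks do wrap around, but because the slow ones don’t, the full set of clock readings never repeats. Otherwise, our pointing might accidentally land on a position that merely looks the same because the clocks have done a full loop.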
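The clock picture also makes the “linear combinations of sines and cosines” remark concrete. Moving every clock forward by k steps is a fixed rotation of each (sin, cos) pair, so a single matrix maps the encoding of any position to the encoding k positions later. Below is a numpy sketch. It uses the same formula as the TensorFlow code above, recomputed in float64. The name shift_matrix is illustrative.

```python
import numpy as np

def positional_encoding(num_positions, d_model):
    # Same sinusoidal scheme as the TensorFlow code above:
    # columns 2c and 2c+1 hold the sine and cosine of clock c.
    pos = np.arange(num_positions)[:, np.newaxis]
    i = np.arange(d_model)[np.newaxis, :]
    angles = pos / np.power(10000.0, (2 * (i // 2)) / d_model)
    pe = np.zeros((num_positions, d_model))
    pe[:, 0::2] = np.sin(angles[:, 0::2])  # y-coordinate of each clock hand
    pe[:, 1::2] = np.cos(angles[:, 1::2])  # x-coordinate of each clock hand
    return pe

def shift_matrix(k, d_model):
    """Matrix R with pe[p] @ R == pe[p + k] for every position p.

    Block-diagonal: each 2x2 block advances one clock by k time steps.
    """
    R = np.zeros((d_model, d_model))
    for c in range(d_model // 2):
        theta = k / np.power(10000.0, (2 * c) / d_model)  # how far clock c turns in k steps
        s, co = 2 * c, 2 * c + 1
        # sin(a + theta) = sin(a) cos(theta) + cos(a) sin(theta)
        R[s, s], R[co, s] = np.cos(theta), np.sin(theta)
        # cos(a + theta) = cos(a) cos(theta) - sin(a) sin(theta)
        R[s, co], R[co, co] = -np.sin(theta), np.cos(theta)
    return R

pe = positional_encoding(64, 16)
R3 = shift_matrix(3, 16)
# The same matrix moves every position three steps forward.
print(np.allclose(pe[:-3] @ R3, pe[3:]))
```

Because R3 does not depend on the starting position, “three tokens later” is one fixed linear map. That is the kind of offset a query or key projection can point at.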
The notion of “pointing” will prove to be a very apt way of thinking about the mechanism of self-attention. Basically, words that interact with each other in the self-attention layer can be thought of as pointing at the word they interact with.
Self-Attention
Finally, we come to the heart of the Transformer model, the self-attention layer. This can be mathematically expressed as
$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d}}\right) V$$
where $Q$, $K$, $V$ are matrices called the queries, keys, and values, respectively, and $d$ is the dimension of the keys. The $\sqrt{d}$ factor, often called a “temperature”, is important for gradient descent training, where it improves stability. (I would guess that it improves stability by making the attention matrix relatively balanced early on in training.) For us, we will have the opposite problem; it’s much easier to think about approximately sparse matrices, so we will actually use something like $\mathrm{softmax}(\beta QK^\top)V$, with $\beta$ being a fairly large scalar.
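The following numpy sketch shows why a large β is convenient. With the standard 1/√d scaling (here d = 64), a query that matches one key only slightly better than the others spreads its attention around. A large β makes softmax(βQKᵀ) behave almost like a hard lookup. The query, keys, and values are made up for illustration.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def attention(Q, K, V, beta):
    """softmax(beta * Q K^T) V, one softmax per query row; beta = 1/sqrt(d) is the standard scaling."""
    return softmax(beta * (Q @ K.T)) @ V

# One query that matches the second of three keys (dot product 1) and no others (0).
Q = np.array([[1.0, 0.0]])
K = np.array([[0.0, 1.0], [1.0, 0.0], [0.0, 1.0]])
V = np.array([[0.0], [5.0], [0.0]])

for beta in (1 / np.sqrt(64), 1.0, 10.0, 100.0):
    weights = softmax(beta * (Q @ K.T))[0]
    print(f"beta={beta:g}: weights={np.round(weights, 4)}, output={attention(Q, K, V, beta)[0, 0]:.4f}")
```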
There’s a lot to unpack here, so let’s come up with a very simple example and work through it together.
Suppose we want to distinguish between these two sentences:
She saw a red apple. # grammatical
She saw a red. # not grammatical (ish, you can provide contexts in which it is natural)
We will try not to go too far down the rabbit hole of what makes something grammatical or not—here we’re just trying to encode the simple rule (which does have limited exceptions) that adjectives can’t just hang out without modifying anything, or being used with a linking verb, or in some other way being “licensed” by the other words in the sentence.
So we would like to be able to detect when an adjective seems to be modifying a noun in the usual way, versus when it is not. In my special programming language for self-attention layers, this can be done like this:
H1a: # name of the head
docstring: Modification layer. Specifically pairs of the form
Q K, where Q comes before K and Q modifies K. Q
must be an adjective or an adverb to use this rule.
pos: # special notation for interacting with the positional encodings
Q: 0
K: +1
x1: # adjective modifies noun: red apple
Q: adjective
K: noun
x2: # adverb modifies verb: quickly write
Q: adverb
K: verb
x3: # adverb modifies adjective: very slow
Q: adverb
K: adjective
x4: # adverb modifies adverb: very slightly
Q: adverb
K: adverb
x5: # everything else hits filler
Q: verb noun det filler pro verb noun det filler pro
K: filler
int: noun>licensed verb>licensed adverb>licensed
adjective>licensed
The queries and keys live in a low-dimensional embedding space that I call key-space. (Typical size in a standard transformer is 64 dimensions per attention head, much smaller than the 512/768/1024 hidden size.) There is one query vector for each token, and one key vector for each token. We take the dot product of every query vector with every key vector, and that gives us what are called “attention logits”. Applying the softmax function to each query’s row of attention logits gives us attention probabilities, which make up what is generally called the “attention matrix”. So let’s compute all of these values for the above program and our example sentences to get a sense of how all of this works.
Let’s suppose that we have just run the word embedding lookup layer, so that we have the default embedding for each word, but no information yet about how the words are interacting with each other. That might look like this (omitting the positional embeddings for the time being):
SOS: +filler +sos
she: +pro +fem +sg +nom
saw: +saw +verb +agentlack +perceptlack
a: +det +sg
red: +red +adjective
apple: +apple +noun +sg
EOS: +filler +eos
------------
SOS: +filler +sos
she: +pro +fem +sg +nom
saw: +saw +verb +agentlack +perceptlack
a: +det +sg
red: +red +adjective
EOS: +filler +eos
The queries will then be computed as follows (again, omitting the positional components):
SOS: +2 x5
she: +2 x5
saw: +2 x5
a: +2 x5
red: +x1
apple: +2 x5
EOS: +2 x5
------------
SOS: +2 x5
she: +2 x5
saw: +2 x5
a: +2 x5
red: +x1
EOS: +2 x5
The keys will look like this (again, omitting the positional components):
SOS: +x5
she: 0 # 0-vector
saw: +x2
a: 0
red: +x3
apple: +x1
EOS: +x5
------------
SOS: +x5
she: 0 # 0-vector
saw: +x2
a: 0
red: +x3
EOS: +x5
Supposing that x1, x2, x3, x4, and x5 are orthonormal, this creates the following attention logits, where we omit zeros for brevity (and we’re still omitting positional encodings):
# written in the form Q>K
SOS>SOS: 2
SOS>EOS: 2
she>SOS: 2
she>EOS: 2
saw>SOS: 2
saw>EOS: 2
a>SOS: 2
a>EOS: 2
red>apple: 1
apple>SOS: 2
apple>EOS: 2
EOS>SOS: 2
EOS>EOS: 2
------------
SOS>SOS: 2
SOS>EOS: 2
she>SOS: 2
she>EOS: 2
saw>SOS: 2
saw>EOS: 2
a>SOS: 2
a>EOS: 2
EOS>SOS: 2
EOS>EOS: 2
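The first logit table can be checked numerically. This sketch copies the query and key vectors from the tables above, with positional parts omitted. It represents the orthonormal x1 to x5 as one-hot vectors and prints every nonzero entry of QKᵀ in the same Q>K form.

```python
import numpy as np

# Key-space basis x1..x5, assumed orthonormal, so one-hot vectors suffice.
X1, X2, X3, X4, X5 = np.eye(5)
ZERO = np.zeros(5)

tokens = ["SOS", "she", "saw", "a", "red", "apple", "EOS"]
# Query and key vectors copied from the tables above (positional parts omitted).
Q = np.array([2 * X5, 2 * X5, 2 * X5, 2 * X5, X1, 2 * X5, 2 * X5])
K = np.array([X5, ZERO, X2, ZERO, X3, X1, X5])

logits = Q @ K.T  # logits[i, j] = query of token i . key of token j

for i, j in zip(*np.nonzero(logits)):
    print(f"{tokens[i]}>{tokens[j]}: {logits[i, j]:g}")
```

This prints the same 13 nonzero entries as the first table. Dropping the apple row and column (index 5) from Q and K gives the second table in the same way.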
We will continue to ignore the positional encodings for the rest of this example, since they’re not needed and don’t drastically change the attention matrix. (They would be needed if there were two nouns. The above weights would then produce a tie between red pointing to apple and red pointing to the other noun. We would want to resolve that tie using the rule that, in English, an adjective comes shortly before the noun it modifies, with some limited counter-examples.)
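This is a small sketch of the tie-breaking described in the parenthesis above, for a sentence with two nouns. The positional term is an assumption: the `pos` block (`Q: 0`, `K: +1`) is read here as a small bonus `eps` on the key one position after the query. The sentence, `eps`, and `beta` are all illustrative values, not taken from the post.

```python
import numpy as np


def softmax(z):
    e = np.exp(z - z.max())  # shift by the max for numerical stability
    return e / e.sum()


tokens = ["SOS", "she", "gave", "the", "dog", "a", "red", "apple", "EOS"]
query = tokens.index("red")
dog, apple = tokens.index("dog"), tokens.index("apple")

# Content logits for the query "red": its x1 query matches every noun key equally.
content = np.array([1.0 if t in ("dog", "apple") else 0.0 for t in tokens])

# Hypothetical positional term: a bonus eps for the key one position after the query.
eps = 0.25
positional = np.zeros(len(tokens))
positional[query + 1] = eps

beta = 20.0  # large logit scale, so the softmax is close to a hard max
without_pos = softmax(beta * content)
with_pos = softmax(beta * (content + positional))
print(without_pos[dog], without_pos[apple])  # a tie: about 0.5 each
print(with_pos[dog], with_pos[apple])        # "apple" now gets almost all of it
```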
Now let’s look at the values! For simplicity in our programs, we actually combine the value projection and what is usually called the out-projection into a single map. We call the combined quantity the “interpretant” (a term from semiotics). The interpretants for our two sentences are as follows:
SOS: 0
she: 0
saw: +licensed
a: 0
red: +licensed
apple: +licensed
EOS: 0
------------
SOS: 0
she: 0
saw: +licensed
a: 0
red: +licensed
EOS: 0
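Multiplying the interpretants V by the attention matrix, we get the following outputs. These assume the attention logits are scaled up by a large factor before the softmax, so that each token’s attention concentrates on its highest-scoring key or keys: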
SOS: 0
she: 0
saw: 0
a: 0
red: +licensed
apple: 0
EOS: 0
------------
SOS: 0
she: 0
saw: 0
a: 0
red: 0
EOS: 0
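This sketch reproduces the row for “red” in both sentences, under a few stated assumptions. `beta` = 50 stands in for “large”, and the interpretants are reduced to their +licensed coefficient. `eps` is the same hypothetical reading of the `pos` block (a small bonus for the next key). The sketch also shows a subtlety the tables gloss over. In the second sentence every logit in red’s row is 0, so without any positional term “red” attends uniformly and picks up 2/6 of +licensed (from “saw” and from itself). With the next-token bonus it attends to EOS, whose interpretant is 0.

```python
import numpy as np


def softmax(z):
    e = np.exp(z - z.max())  # shift by the max for numerical stability
    return e / e.sum()


def licensed_written_into(query_index, logits, licensed, beta=50.0, eps=0.0):
    """+licensed coefficient that one query token receives from the attention layer.

    logits:   attention logits for this query over every key (positions omitted).
    licensed: +licensed coefficient of each key token's interpretant.
    eps:      hypothetical positional bonus for the key right after the query.
    """
    z = np.array(logits, dtype=float)
    if query_index + 1 < len(z):
        z[query_index + 1] += eps
    weights = softmax(beta * z)
    return float(weights @ np.array(licensed, dtype=float))


# Sentence 1: SOS she saw a red apple EOS ("red" is token 4).
s1_logits = [0, 0, 0, 0, 0, 1, 0]    # red>apple: 1
s1_licensed = [0, 0, 1, 0, 1, 1, 0]  # saw, red, apple carry +licensed
# Sentence 2: SOS she saw a red EOS.
s2_logits = [0, 0, 0, 0, 0, 0]       # no noun for "red" to modify
s2_licensed = [0, 0, 1, 0, 1, 0]

print(licensed_written_into(4, s1_logits, s1_licensed))           # close to 1
print(licensed_written_into(4, s2_logits, s2_licensed))           # 1/3: uniform attention
print(licensed_written_into(4, s2_logits, s2_licensed, eps=0.5))  # close to 0: attends to EOS
```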
Using the residual connection that surrounds the self-attention layer, which adds the layer’s input back onto its output, we obtain the final outputs:
SOS: +filler +sos
she: +pro +fem +sg +nom
saw: +saw +verb +agentlack +perceptlack
a: +det +sg
red: +red +adjective +licensed
apple: +apple +noun +sg
EOS: +filler +eos
------------
SOS: +filler +sos
she: +pro +fem +sg +nom
saw: +saw +verb +agentlack +perceptlack
a: +det +sg
red: +red +adjective
EOS: +filler +eos
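In the seme notation, the residual step is plain vector addition. Here is a tiny sketch that assumes seme vectors are stored as {seme: coefficient} mappings. That representation is chosen for illustration and is not taken from the post.

```python
from collections import Counter


def add(u, v):
    """Add two seme vectors stored as {seme: coefficient} mappings."""
    out = Counter(u)
    out.update(v)  # unlike `u + v`, update() keeps zero and negative coefficients
    return out


# Word embeddings and self-attention outputs for two tokens of sentence 1.
embedding = {
    "red": {"red": 1, "adjective": 1},
    "apple": {"apple": 1, "noun": 1, "sg": 1},
}
attention_out = {"red": {"licensed": 1}, "apple": {}}  # {} is the 0-vector

# Residual connection: add the layer's output back onto its input.
final = {tok: add(embedding[tok], attention_out[tok]) for tok in embedding}
print(sorted(final["red"]))  # ['adjective', 'licensed', 'red']
```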
Thus, this self-attention layer has managed to compute the fact that the word “red” is licensed in the first, grammatical sentence, but not licensed in the second, questionable sentence. This can be used by downstream layers to declare the second sentence to be ungrammatical.
Phew!
More to Come!
We’ll next see how to hard-code the weights of an actual Transformer. This will involve explaining the structures of Transformer layers, which will take a fair amount of time. In the meantime, please check out The Illustrated Transformer and Transformers from Scratch to get a head start on understanding them, or dive face-first into my 1800-line, poorly commented grammaticality classifier.
(Moderation note: added to the Alignment Forum from LessWrong.)
I’m confused by your notation for feed-forward layers.
What justifies re-using the same labels (“apple” etc.) for
(1) the coordinates of x, and
(2) the coordinates of x⋅A, i.e. the basis in which the nonlinearity operates?
If we want to express what the individual components of basis (2) mean in terms of the original space, we can either talk about which vectors/semes are mapped to them by A, or which vectors/semes they get mapped to by B.
But your labels don’t correspond to either of these interpretations. Instead, it looks like you are following rules of the form “the 4th component of every basis is called ‘yum’,” which leads you to label a coordinate “yum” even though it’s neither mapped from “yum” by A, nor mapped to “yum” by B.
This notation also seems to require the basis (2) to have the same number of elements as (1), which generally will not be the case. In transformers, (2) is typically larger by a factor of 4. The logic of your example, meanwhile, can be expressed using a smaller nonlinearity basis of 3 elements:
neuron1=ReLU(cherry+durian−1)
neuron2=ReLU(apple+banana−1)
neuron3=ReLU(apple+banana)
yum=neuron3−neuron2
yuck=−1∗neuron1
with some arbitrary choices about which multiplicative constants to absorb into A and a vs. which to absorb into B.
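The three-neuron layer above can be checked directly. The sketch writes it as x⋅A + a, then a ReLU, then B, using one arbitrary split of the constants. It also assumes the inputs are 0/1 seme indicators. Under those assumptions, yum behaves like apple OR banana, and yuck is −1 only when cherry AND durian are both present.

```python
import numpy as np

# Rows of A: input semes apple, banana, cherry, durian; columns: neuron1..neuron3.
A = np.array([
    [0, 1, 1],  # apple
    [0, 1, 1],  # banana
    [1, 0, 0],  # cherry
    [1, 0, 0],  # durian
], dtype=float)
a = np.array([-1, -1, 0], dtype=float)  # biases of neuron1..neuron3

# Rows of B: neuron1..neuron3; columns: output semes yum, yuck.
B = np.array([
    [0, -1],   # neuron1 -> yuck = -1 * neuron1
    [-1, 0],   # neuron2 -> yum = neuron3 - neuron2
    [1, 0],    # neuron3
], dtype=float)


def feed_forward(x):
    """Dense -> ReLU -> dense with only 3 hidden neurons."""
    return np.maximum(x @ A + a, 0.0) @ B


for apple, banana in [(0, 0), (1, 0), (0, 1), (1, 1)]:
    yum, _ = feed_forward(np.array([apple, banana, 0, 0], dtype=float))
    print(f"apple={apple} banana={banana} -> yum={yum:g}")  # yum = apple OR banana
for cherry, durian in [(1, 0), (1, 1)]:
    _, yuck = feed_forward(np.array([0, 0, cherry, durian], dtype=float))
    print(f"cherry={cherry} durian={durian} -> yuck={yuck:g}")  # -1 only when both
```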
Thanks for your comments/questions, they’re very insightful.
In general, there are as many encoding spaces in a Transformer as there are computational nodes, and a traditional Transformer will have little incentive to use the same semantics for any two of the spaces. (There’s a little bit of an incentive because of the residual connections, which will (I think?) kind of tie together the semantics of the various hidden-size-sized embedding spaces.)
In particular, the middle layer of the dense-relu-dense feedforward layer is usually chosen to be significantly larger (4x) than the hidden size, and so it’s not even theoretically possible to represent it using the same basis. I’ve found that it sometimes makes sense to use anonymous seme names like x1 x2 x3 etc in the feed-forward layer for this reason. In my experience so far I’ve found the feed-forward layers to be most useful for conjunctions and disjunctions—and there are a quadratic number of possible conjunctions and disjunctions of even two neurons, let alone 3 or 4. So it seems to me that this might give a tiny hint as to why people have found that the intermediate embedding space of the feed-forward layer needs to be so large.
Of course, there is a potentially huge gap between what I am clever enough to think of as a use for them and what good old gradient descent is clever enough to think of. We can only easily lower-bound the potential uses of them; upper-bounding the capabilities of a component will prove much more challenging.
I don’t fully understand how the embeddings are done.
Can you spell out one of the examples?
It would be helpful for me to see how the semes map to the actual matrix.
Added an example sentence and its embeddings. Will add more examples overall. Thanks for commenting!
Re: how this interacts with Alignment Research:
Another use is for sanity checking existing interpretability techniques. For example, to check if particular neurons identified as curve detectors via interpretability techniques were indeed curve detectors, Chris Olah spent a few hours replacing the curve-detecting neurons with handwritten curve detector neurons. (He found that the interpretability techniques were able to give qualitatively similar results for both the original neurons and the handwritten neurons. More impressively, he also found that replacing the curve detecting neurons with his handwritten neurons was able to recover ~60% of the drop in accuracy compared to removing the original neurons entirely [reported in footnote 9].)
Very nice post. It is certainly useful to do this exercise of manually encoding language rules into the weights of a transformer in order to better understand the machinery involved.
There is a long history of attempting to parse natural language with hand-designed rules and heuristics. The general consensus now is that hand engineering is insufficient, and that some learning from data is necessary. To me it seems that this direction inherits the problems of these old-fashioned language systems, since you are codifying your own hand-designed heuristics and rules into the network weights.
Do you see a way to introduce learning from data without sacrificing the interpretability that your approach provides?
There are a number of ways to combine this approach with learning, but I haven’t had time to try any of them yet. Some ideas I have thought of:
Use hard-coded weights, plus some random noise, to initialize the weights of a transformer that you then train in the traditional fashion. This doesn’t really help with interpretability or alignment, but might(???) help with performance.
Write out all the weight and bias parameters as combinations of semes and outer products of semes, then learn seme embeddings by gradient descent. Semantic seme embeddings could be initialized from something like WordNet relationships, or learned with word2vec, to automate those guys.
You could do smallish amounts of gradient descent to suggest new rules to add, but then add them by hand. This would still be very slow.
Perhaps it is possible to start with a strong learned transformer, gradually identify human-legible rules that it is using, and replace those specific parts with hard-coding. This could prove very difficult!!!
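Here is one hedged reading of the idea above about writing weights as “combinations of semes and outer products of semes”. A hand-written rule table R stays fixed, and the seme embeddings are the trainable parameters, so each weight matrix is a sum of outer products. The seme names, sizes, and `weight_from_semes` are illustrative assumptions. Training would then update only `E_in` and `E_out`, using any autodiff framework.

```python
import numpy as np

# Hypothetical sketch: fixed, human-readable rule table R; learnable seme embeddings.
semes_in = ["adjective", "adverb", "noun"]
rules_out = ["x1", "x2"]
R = np.array([
    [1.0, 0.0],  # adjective -> x1
    [0.0, 1.0],  # adverb    -> x2
    [0.0, 0.0],  # noun      -> nothing
])

d_model, d_key = 4, 3
rng = np.random.default_rng(0)
E_in = rng.normal(size=(len(semes_in), d_model))  # learnable seme embeddings
E_out = rng.normal(size=(len(rules_out), d_key))  # learnable key-space directions


def weight_from_semes(E_in, R, E_out):
    """Equal to the sum over (s, t) of R[s, t] * np.outer(E_in[s], E_out[t])."""
    return E_in.T @ R @ E_out


W = weight_from_semes(E_in, R, E_out)
W_loop = sum(R[s, t] * np.outer(E_in[s], E_out[t])
             for s in range(R.shape[0]) for t in range(R.shape[1]))
print(W.shape, np.allclose(W, W_loop))  # (4, 3) True
```

With one-hot embeddings (`E_in = np.eye(3)`, `E_out = np.eye(2)`), W is exactly R, i.e. the hard-coded weights. With learned embeddings, the rule structure remains readable in R.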
It seems almost certain to me that hard-coding weights would at least help us build the muscles needed to recognize what is going on, to the extent that we are able to