Lesson 15 of 17
Multi-Head Attention
Instead of one attention computation, the Transformer runs H independent heads in parallel, each attending to a different subspace of the representation.
Why Multiple Heads?
Different heads can specialize:
- Head 1 might track syntactic dependencies
- Head 2 might track semantic similarity
- Head 3 might track positional patterns
With n_embd=16 and n_head=4, each head operates on a head_dim = 16/4 = 4-dimensional slice.
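This slicing can be checked directly. The following is a small illustration (the vector contents are made up for the demo):

```python
# Slice a 16-dim vector into 4 heads of 4 dims each.
n_embd, n_head = 16, 4
head_dim = n_embd // n_head  # 4

x = list(range(n_embd))  # stand-in embedding vector [0, 1, ..., 15]
slices = [x[h * head_dim : (h + 1) * head_dim] for h in range(n_head)]
print(slices[0])  # → [0, 1, 2, 3]
print(slices[3])  # → [12, 13, 14, 15]
```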
The Split
Each head h operates on a slice of the Q, K, V vectors:
```python
head_size = n_embd // n_head
hs = h * head_size
q_h = q[hs : hs + head_size]
k_h = [k[hs : hs + head_size] for k in all_keys]
v_h = [v[hs : hs + head_size] for v in all_values]
```
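Once sliced, each head runs ordinary scaled dot-product attention on its own slice. Here is a minimal sketch of one head's computation; the function name and list-based representation are ours, not the course's:

```python
import math

def single_head_attention(q_h, keys_h, values_h):
    # Scaled dot-product attention for one head (a sketch).
    # q_h: query slice of length head_dim
    # keys_h, values_h: one head_dim slice per position
    d = len(q_h)
    # Dot-product scores, scaled by sqrt(head_dim)
    scores = [sum(qi * ki for qi, ki in zip(q_h, k)) / math.sqrt(d)
              for k in keys_h]
    # Numerically stable softmax
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # Weighted sum of value slices
    return [sum(w * v[i] for w, v in zip(weights, values_h))
            for i in range(d)]
```

With two identical keys the weights are 0.5 each, so the output is the average of the two value slices.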
Concatenate and Project
The outputs of all H heads are concatenated back into a single vector, then projected with attn_wo:
```python
x_attn = []
for h in range(n_head):
    hs = h * head_dim
    # ... compute head_out for this head ...
    x_attn.extend(head_out)    # concatenate head outputs
x = linear(x_attn, attn_wo)    # project back to n_embd
```
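The concatenate-and-project step can be sketched end to end with toy per-head outputs. The `linear` helper below is a minimal stand-in for the lesson's projection function, and the weights are dummies (assumptions for the demo, not the course's actual code):

```python
def linear(x, w):
    # Minimal matrix-vector product: w is a list of output rows,
    # each of length len(x). Stand-in for the lesson's `linear`.
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]

n_head, head_dim = 4, 4
n_embd = n_head * head_dim

# Toy per-head outputs: head h emits a constant vector of value h.
head_outs = [[float(h)] * head_dim for h in range(n_head)]

x_attn = []
for head_out in head_outs:
    x_attn.extend(head_out)    # concatenate: 4 heads * 4 dims = 16

attn_wo = [[1.0] * n_embd for _ in range(n_embd)]  # dummy projection weights
x = linear(x_attn, attn_wo)    # project back to n_embd
print(len(x_attn), len(x))     # → 16 16
```

Note that concatenation restores the full `n_embd` dimension, so the output projection maps a 16-dim vector back to a 16-dim vector.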
Your Task
Implement multi_head_attention(q, keys, values, n_head, head_dim) that runs H heads and concatenates the results.