Lesson 15 of 17

Multi-Head Attention

Instead of one attention computation, the Transformer runs H independent heads in parallel, each attending to a different subspace of the representation.

Why Multiple Heads?

Different heads can specialize:

  • Head 1 might track syntactic dependencies
  • Head 2 might track semantic similarity
  • Head 3 might track positional patterns

With n_embd=16 and n_head=4, each head operates on a head_dim = 16/4 = 4-dimensional slice.
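Concretely, the arithmetic above partitions the embedding into contiguous per-head slices:

```python
n_embd, n_head = 16, 4
head_dim = n_embd // n_head  # 4 dims per head

# Head h owns dimensions [h*head_dim, (h+1)*head_dim) of the embedding.
bounds = [(h * head_dim, (h + 1) * head_dim) for h in range(n_head)]
```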

The Split

Each head h operates on a slice of the Q, K, V vectors:

head_dim = n_embd // n_head
hs = h * head_dim
q_h = q[hs : hs + head_dim]
k_h = [k[hs : hs + head_dim] for k in all_keys]
v_h = [v[hs : hs + head_dim] for v in all_values]
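A runnable sketch of the split on toy data (the 8-dim vectors and the sizes here are illustrative, not the lesson's actual values):

```python
n_embd, n_head = 8, 2
head_dim = n_embd // n_head  # 4

q = [float(i) for i in range(n_embd)]          # toy query vector
all_keys = [[1.0] * n_embd, [2.0] * n_embd]    # two toy key vectors
all_values = [[3.0] * n_embd, [4.0] * n_embd]  # two toy value vectors

h = 1                          # pick the second head
hs = h * head_dim              # its slice starts at offset 4
q_h = q[hs : hs + head_dim]
k_h = [k[hs : hs + head_dim] for k in all_keys]
v_h = [v[hs : hs + head_dim] for v in all_values]
```

Each head sees only its own 4-dim slice of every query, key, and value.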

Concatenate and Project

The outputs of all H heads are concatenated back into a single vector, then projected with attn_wo:

x_attn = []
for h in range(n_head):
    hs = h * head_dim
    # ... compute head_out ...
    x_attn.extend(head_out)   # concatenate

x = linear(x_attn, attn_wo)  # project
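The `# ... compute head_out ...` step is ordinary single-head attention on the sliced vectors. A minimal sketch of what it might look like, using the standard scaled dot-product formulation (`head_attention` and `softmax` are illustrative helper names, not from the lesson):

```python
import math

def softmax(xs):
    # Numerically stable softmax over a list of floats.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

def head_attention(q_h, keys_h, values_h, head_dim):
    # Scores: dot(q_h, k) / sqrt(head_dim) for each key slice.
    scores = [sum(qi * ki for qi, ki in zip(q_h, k)) / math.sqrt(head_dim)
              for k in keys_h]
    weights = softmax(scores)
    # Output: attention-weighted sum of the value slices.
    out = [0.0] * head_dim
    for w, v in zip(weights, values_h):
        for i in range(head_dim):
            out[i] += w * v[i]
    return out
```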

Your Task

Implement multi_head_attention(q, keys, values, n_head, head_dim) that runs H heads and concatenates the results.
