Lesson 7 of 17
BPE Tokenization
Byte Pair Encoding (BPE)
Character-level tokenization produces long sequences and a tiny vocabulary. Real LLMs use Byte Pair Encoding (BPE), which iteratively merges the most frequent pair of tokens into a new token.
The Algorithm
Starting from raw UTF-8 bytes:
- Count every adjacent pair of tokens
- Merge the most frequent pair into a new token ID
- Repeat for num_merges steps
tokens = list(text.encode("utf-8"))  # start with raw bytes
merges = {}  # (pair) → new_id
vocab = {i: bytes([i]) for i in range(256)}
next_id = 256
for _ in range(num_merges):
    pairs = {}
    for i in range(len(tokens) - 1):
        pair = (tokens[i], tokens[i+1])
        pairs[pair] = pairs.get(pair, 0) + 1
    best = max(pairs, key=pairs.get)
    merges[best] = next_id
    vocab[next_id] = vocab[best[0]] + vocab[best[1]]
    # replace all occurrences of the best pair
    ...
    next_id += 1
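The replacement step elided above can be sketched as a small helper. The name merge_pair is illustrative, not part of the lesson's required API:

```python
def merge_pair(tokens, pair, new_id):
    """Replace every non-overlapping occurrence of `pair` with `new_id`."""
    out = []
    i = 0
    while i < len(tokens):
        if i < len(tokens) - 1 and (tokens[i], tokens[i+1]) == pair:
            out.append(new_id)
            i += 2  # skip both members of the merged pair
        else:
            out.append(tokens[i])
            i += 1
    return out
```

Scanning left to right and skipping two positions after a match ensures overlapping occurrences (e.g. three identical tokens in a row) merge only once.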
Encoding New Text
To encode unseen text, start from raw bytes and repeatedly apply merges in the order they were learned (lowest ID first):
tokens = list(text.encode("utf-8"))
while True:
    # find the applicable pair with the smallest merge ID
    best = min(applicable_pairs, key=lambda p: merges[p])
    # merge all occurrences
    ...
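The loop above can be fleshed out as follows; this is one possible sketch of bpe_encode, stopping when no learned merge applies:

```python
def bpe_encode(text, merges):
    """Encode text by repeatedly applying learned merges, lowest ID first."""
    tokens = list(text.encode("utf-8"))
    while len(tokens) >= 2:
        pairs = {(tokens[i], tokens[i+1]) for i in range(len(tokens) - 1)}
        applicable = [p for p in pairs if p in merges]
        if not applicable:
            break  # no learned merge applies anymore
        best = min(applicable, key=lambda p: merges[p])
        new_id = merges[best]
        # merge all occurrences of best, left to right
        out, i = [], 0
        while i < len(tokens):
            if i < len(tokens) - 1 and (tokens[i], tokens[i+1]) == best:
                out.append(new_id)
                i += 2
            else:
                out.append(tokens[i])
                i += 1
        tokens = out
    return tokens
```

Applying merges in learned order matters: later merges may only become applicable after earlier ones have produced their token IDs.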
Decoding
Decoding is simple: look up each token ID's byte string in the vocab and concatenate the results.
def bpe_decode(tokens, vocab):
    return b"".join(vocab[t] for t in tokens).decode("utf-8")
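For instance, with the base byte vocabulary plus a single hypothetical merge (token 256 standing for the bytes of "hi"):

```python
def bpe_decode(tokens, vocab):
    return b"".join(vocab[t] for t in tokens).decode("utf-8")

vocab = {i: bytes([i]) for i in range(256)}
vocab[256] = vocab[104] + vocab[105]  # 104 = "h", 105 = "i"

print(bpe_decode([256, 33], vocab))  # → "hi!"
```

Because merged tokens store the concatenation of their parts' bytes, decoding needs no knowledge of the merge order.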
Your Task
Implement bpe_train(text, num_merges) returning (merges, vocab), bpe_encode(text, merges), and bpe_decode(tokens, vocab).