Lesson 7 of 17

BPE Tokenization

Byte Pair Encoding (BPE)

Character-level tokenization produces long sequences and a tiny vocabulary. Real LLMs use Byte Pair Encoding (BPE), which iteratively merges the most frequent pair of tokens into a new token.
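As a tiny illustration of the "most frequent pair" idea (the string here is chosen purely for demonstration), we can count adjacent byte pairs with `collections.Counter`:

```python
from collections import Counter

text = "aaabab"
tokens = list(text.encode("utf-8"))        # [97, 97, 97, 98, 97, 98]
pair_counts = Counter(zip(tokens, tokens[1:]))
print(pair_counts)                          # (97, 97) and (97, 98) each occur twice
```

BPE would merge one of the most frequent pairs into a single new token and repeat.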

The Algorithm

Starting from raw UTF-8 bytes:

  1. Count every adjacent pair of tokens
  2. Merge the most frequent pair into a new token ID
  3. Repeat for num_merges steps
tokens = list(text.encode("utf-8"))     # start with raw bytes
merges = {}                              # (pair) → new_id
vocab  = {i: bytes([i]) for i in range(256)}
next_id = 256

for _ in range(num_merges):
    pairs = {}
    for i in range(len(tokens) - 1):
        pair = (tokens[i], tokens[i+1])
        pairs[pair] = pairs.get(pair, 0) + 1
    best = max(pairs, key=pairs.get)
    merges[best] = next_id
    vocab[next_id] = vocab[best[0]] + vocab[best[1]]
    # replace all occurrences of best pair
    ...
    next_id += 1
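The elided replacement step might look like the following helper. The name `merge_pair` is illustrative (it is not part of the lesson's required API), and this is one possible sketch, not the only correct one:

```python
def merge_pair(tokens, pair, new_id):
    """Replace every non-overlapping occurrence of `pair` with `new_id`."""
    out = []
    i = 0
    while i < len(tokens):
        if i < len(tokens) - 1 and (tokens[i], tokens[i + 1]) == pair:
            out.append(new_id)
            i += 2          # skip both members of the merged pair
        else:
            out.append(tokens[i])
            i += 1
    return out

print(merge_pair([97, 97, 97, 98], (97, 97), 256))  # [256, 97, 98]
```

Scanning left to right and skipping two positions after a match ensures overlapping occurrences (e.g. `97, 97, 97`) are merged greedily, matching how the pair counts were taken.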

Encoding New Text

To encode unseen text, start from raw bytes and repeatedly apply merges in the order they were learned (lowest ID first):

tokens = list(text.encode("utf-8"))
while True:
    # find the pair with the smallest merge ID
    best = min(applicable_pairs, key=lambda p: merges[p])
    # merge all occurrences
    ...
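One way to complete this loop is sketched below, assuming the `merges` dict produced by training; this is a reference sketch under those assumptions, and the inner replacement loop is inlined so the function is self-contained:

```python
def bpe_encode(text, merges):
    """Encode text by applying learned merges, lowest merge ID first."""
    tokens = list(text.encode("utf-8"))
    while len(tokens) >= 2:
        pairs = set(zip(tokens, tokens[1:]))
        applicable = [p for p in pairs if p in merges]
        if not applicable:
            break                            # no learned merge applies
        best = min(applicable, key=lambda p: merges[p])
        new_id, out, i = merges[best], [], 0
        while i < len(tokens):
            if i < len(tokens) - 1 and (tokens[i], tokens[i + 1]) == best:
                out.append(new_id)
                i += 2
            else:
                out.append(tokens[i])
                i += 1
        tokens = out
    return tokens

# toy merges table: byte pair (104, 105) = "hi" becomes token 256
print(bpe_encode("hihi", {(104, 105): 256}))  # [256, 256]
```

Applying merges in training order (smallest new ID first) is what makes encoding deterministic and consistent with the vocabulary built during training.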

Decoding

Decoding is simple: look up each token ID in the vocab, concatenate the byte strings, and decode as UTF-8:

def bpe_decode(tokens, vocab):
    return b"".join(vocab[t] for t in tokens).decode("utf-8")
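A quick check of the lookup-and-concatenate idea with a toy vocab (token 256 and its value are chosen here purely for illustration):

```python
# byte-level base vocab plus one merged token
vocab = {i: bytes([i]) for i in range(256)}
vocab[256] = b"hi"                                    # a learned merge

decoded = b"".join(vocab[t] for t in [256, 33]).decode("utf-8")
print(decoded)  # hi!
```

Because merged tokens store the concatenation of their parts' bytes, decoding never needs the merge rules, only the vocab.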

Your Task

Implement bpe_train(text, num_merges) returning (merges, vocab), bpe_encode(text, merges), and bpe_decode(tokens, vocab).
