Lesson 10 of 17

Softmax

Softmax: Logits to Probabilities

The model produces raw logits — one per vocabulary token. Softmax converts them to probabilities that sum to 1.

Definition

softmax(x)[i] = exp(x[i]) / sum(exp(x[j]) for j)
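To make the definition concrete, here is a direct plain-float translation (independent of the Value class used later) on a tiny example:

```python
import math

def softmax_floats(xs):
    # Direct translation of the definition, on plain floats.
    exps = [math.exp(x) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

probs = softmax_floats([1.0, 2.0, 3.0])
print([round(p, 4) for p in probs])  # [0.09, 0.2447, 0.6652]
# The probabilities sum to 1 (up to float rounding).
```

Note how a gap of 1 in logit space roughly triples the probability: softmax is an exponential weighting, not a linear one.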

Numerical Stability

Raw exponentials overflow for large inputs — in float64, exp(x) already exceeds the representable range around x ≈ 710. The standard fix: subtract the maximum value before exponentiating. This doesn't change the result (the constant cancels in the division) but keeps numbers finite:

def softmax(logits):
    max_val = max(val.data for val in logits)
    exps = [(val - max_val).exp() for val in logits]
    total = sum(exps)
    return [e / total for e in exps]

Note: this version works with Value objects — val - max_val uses Value.__sub__ and .exp() uses Value.exp(). The operations record the computation graph for backpropagation.
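To see why the shift matters, here is a quick check on plain floats (separate from the Value-based version above): naive exponentiation overflows for large logits, while the shifted computation stays finite and gives the same probabilities.

```python
import math

logits = [1000.0, 1001.0, 1002.0]

# Naive: math.exp(1000) overflows a 64-bit float.
try:
    [math.exp(x) for x in logits]
except OverflowError:
    print("naive softmax overflows")

# Shifted: the largest exponent is exp(0) = 1, so everything stays finite.
m = max(logits)
exps = [math.exp(x - m) for x in logits]
total = sum(exps)
print([round(e / total, 4) for e in exps])  # [0.09, 0.2447, 0.6652]
```

Shift invariance means these are exactly the probabilities of softmax([0, 1, 2]) — only the gaps between logits matter, not their absolute scale.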

Log-Sum-Exp Trick

When computing log(softmax(x)) — which cross-entropy loss requires — the naive approach log(exp(x[i] - max) / sum(exp(x[j] - max))) can still lose precision. The log-sum-exp trick computes log-probabilities directly:

log_softmax(x)[i] = (x[i] - max) - log(sum(exp(x[j] - max)))

The term log(sum(exp(x[j] - max))) is the "log-sum-exp" and is computed once for all elements. This avoids ever materializing very small probabilities that would underflow to zero before the log:

def log_softmax(logits):
    max_val = max(val.data for val in logits)
    shifted = [val - max_val for val in logits]
    log_sum = sum(s.exp() for s in shifted).log()
    return [s - log_sum for s in shifted]

In practice, frameworks like PyTorch fuse softmax + log into log_softmax for exactly this reason.
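A plain-float sketch shows the failure mode concretely: with widely spread logits, the smaller probabilities underflow to exactly 0.0, so taking their log would blow up, while the direct log-softmax recovers the true log-probability.

```python
import math

def log_softmax_floats(xs):
    # Same structure as the Value-based version above, on plain floats.
    m = max(xs)
    shifted = [x - m for x in xs]
    log_sum = math.log(sum(math.exp(s) for s in shifted))
    return [s - log_sum for s in shifted]

xs = [0.0, 1000.0]

# Naive route: the small probability underflows to exactly 0.0,
# and log(0.0) is undefined (-inf).
p0 = math.exp(0.0 - 1000.0) / (math.exp(0.0 - 1000.0) + math.exp(0.0))
print(p0)                         # 0.0 -- underflow

# Direct log-softmax keeps the true value.
print(log_softmax_floats(xs)[0])  # -1000.0
```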

Cross-Entropy Connection

Training follows the pipeline logits → softmax → probabilities, and the loss is -log(p_target). Using log-softmax, this collapses to -log_softmax(logits)[target], which is both numerically stable and efficient.
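As a plain-float sketch of that pipeline (the helper below mirrors the log_softmax defined earlier; the example logits are illustrative):

```python
import math

def cross_entropy(logits, target):
    # loss = -log_softmax(logits)[target], computed stably.
    m = max(logits)
    shifted = [x - m for x in logits]
    log_sum = math.log(sum(math.exp(s) for s in shifted))
    return -(shifted[target] - log_sum)

# Confident and correct: low loss (about 0.0134).
print(cross_entropy([5.0, 0.0, 0.0], target=0))
# Confident and wrong: high loss (about 5.0134).
print(cross_entropy([5.0, 0.0, 0.0], target=1))
```

The asymmetry is the point of the loss: confidently wrong predictions are penalized far more than confidently right ones are rewarded.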

During inference, softmax is applied to temperature-scaled logits to control the sharpness of the sampling distribution:

probs = softmax([l / temperature for l in logits])
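A quick float sketch of the effect (the example logits are illustrative): temperatures below 1 sharpen the distribution toward the argmax, temperatures above 1 flatten it toward uniform.

```python
import math

def softmax_floats(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.0]
for t in (0.5, 1.0, 2.0):
    probs = softmax_floats([l / t for l in logits])
    # Lower t -> peakier distribution; higher t -> flatter.
    print(t, [round(p, 3) for p in probs])
```

As t → 0 sampling degenerates to greedy argmax; as t → ∞ it approaches uniform random choice over the vocabulary.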

Your Task

Implement softmax(logits) and log_softmax(logits) where logits is a list of Value objects. Use the max-subtraction trick for stability in both.
