Softmax: Logits to Probabilities
The model produces raw logits — one per vocabulary token. Softmax converts them to probabilities that sum to 1.
Definition
softmax(x)[i] = exp(x[i]) / sum(exp(x[j]) for j)
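As a quick sanity check, the definition can be evaluated on plain floats (a standalone sketch that skips the Value machinery; softmax_floats is a name invented here):

```python
import math

def softmax_floats(xs):
    # Direct translation of the definition: exponentiate each logit,
    # then normalize by the sum so the outputs form a distribution.
    exps = [math.exp(x) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

probs = softmax_floats([2.0, 1.0, 0.1])
# → roughly [0.659, 0.242, 0.099]: all positive, summing to 1,
# with the largest logit getting the largest probability.
```

Note this naive version exponentiates the raw logits, which is exactly what the stability fix below addresses.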
Numerical Stability
Raw exponents overflow for large inputs. The standard fix: subtract the maximum value before exponentiating. This doesn't change the result (the constant cancels in the division) but keeps numbers finite:
def softmax(logits):
    # Shift by the max logit so exp() never sees a large positive input.
    max_val = max(val.data for val in logits)
    exps = [(val - max_val).exp() for val in logits]
    total = sum(exps)
    return [e / total for e in exps]
Note: this version works with Value objects — val - max_val uses Value.__sub__ and .exp() uses Value.exp(). The operations record the computation graph for backpropagation.
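To see concretely why the shift matters, compare the two approaches on plain floats (standalone sketch; the logit values are made up):

```python
import math

def softmax_stable(xs):
    # Shifting by the max makes every exponent <= 0, so exp() cannot overflow.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

logits = [1000.0, 1001.0, 1002.0]
# math.exp(1000.0) raises OverflowError, so the unshifted version fails here.
# After the shift the exponents are just -2.0, -1.0, 0.0:
probs = softmax_stable(logits)
```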
Log-Sum-Exp Trick
When computing log(softmax(x)) — which cross-entropy loss requires — the naive approach log(exp(x_i - max) / sum(exp(x_j - max))) can still lose precision. The log-sum-exp trick computes log-probabilities directly:
log_softmax(x)[i] = (x[i] - max) - log(sum(exp(x[j] - max)))
The term log(sum(exp(x[j] - max))) is the "log-sum-exp" and is computed once for all elements. This avoids ever materializing very small probabilities that would underflow to zero before the log:
def log_softmax(logits):
    # Same max shift as softmax, but stay in log space throughout.
    max_val = max(val.data for val in logits)
    shifted = [val - max_val for val in logits]
    log_sum = sum(s.exp() for s in shifted).log()
    return [s - log_sum for s in shifted]
In practice, frameworks like PyTorch fuse softmax + log into log_softmax for exactly this reason.
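The same trick works on plain floats, and it shows exactly the underflow being avoided (standalone sketch; the logit values are made up):

```python
import math

def log_softmax_floats(xs):
    m = max(xs)
    shifted = [x - m for x in xs]
    # The log-sum-exp term, computed once and shared by every element.
    lse = math.log(sum(math.exp(s) for s in shifted))
    return [s - lse for s in shifted]

logps = log_softmax_floats([-1000.0, 0.0])
# → [-1000.0, 0.0]: the tiny probability keeps a finite log-probability,
# whereas softmax would underflow it to 0.0 and log(0) would give -inf.
```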
Cross-Entropy Connection
The model outputs logits, softmax turns them into probabilities, and the loss is -log(p_target). Using log-softmax, the whole chain becomes simply -log_softmax(logits)[target], which is both numerically stable and efficient.
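On plain floats the whole chain collapses to a few lines (the logit values and target index are invented for illustration; cross_entropy is a name used only in this sketch):

```python
import math

def cross_entropy(logits, target):
    # -log_softmax(logits)[target], written out with the max shift:
    # loss = log(sum_j exp(x_j)) - x_target
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return lse - logits[target]

loss = cross_entropy([2.0, 1.0, 0.1], target=0)
# → about 0.417, i.e. -log(0.659), where 0.659 is the probability
# softmax assigns to index 0.
```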
During inference we apply softmax again, this time dividing the logits by a temperature to control how peaked the distribution is:
probs = softmax([l / temperature for l in logits])
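A float-only example of the effect (logit values invented for illustration; temperatures below 1 sharpen the distribution, above 1 flatten it):

```python
import math

def softmax_floats(xs):
    m = max(xs)  # max shift for stability, as above
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.1]
sharp = softmax_floats([l / 0.5 for l in logits])  # T = 0.5: peakier
flat = softmax_floats([l / 2.0 for l in logits])   # T = 2.0: closer to uniform
# sharp gives the top logit ~0.86 of the mass, flat only ~0.50.
```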
Your Task
Implement softmax(logits) and log_softmax(logits) where logits is a list of Value objects. Use the max-subtraction trick for stability in both.