
Numerical stability of binary cross entropy loss and the log-sum-exp trick


When training a binary classifier, cross entropy (CE) loss is usually used because squared error loss cannot distinguish bad predictions from extremely bad predictions. The CE loss is defined as follows:

    \[L_{CE}(y,t) = -t\log y - (1-t)\log(1-y)\]

where y is the predicted probability of the sample falling in the positive class (t=1), with y = \sigma(z), where \sigma is the logistic (sigmoid) function and z is the logit.
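To make this concrete, here is a minimal NumPy sketch that translates the definition directly into code (the function names sigmoid and naive_bce are my own, not from any library). As discussed below, for a confidently misclassified positive example the sigmoid underflows to exactly zero and the loss blows up:

    import numpy as np

    def sigmoid(z):
        return 1.0 / (1.0 + np.exp(-z))

    def naive_bce(z, t):
        # Plug y = sigmoid(z) directly into the definition of the CE loss.
        y = sigmoid(z)
        return -(t * np.log(y) + (1 - t) * np.log(1 - y))

    # A positive example (t = 1) that is confidently misclassified (z << 0).
    # Here exp(-z) overflows, so sigmoid(z) rounds to exactly 0; in float32,
    # which deep learning frameworks typically default to, this already
    # happens around z = -90.
    z, t = -800.0, 1.0
    print(sigmoid(z))       # 0.0  (NumPy also warns about overflow in exp)
    print(naive_bce(z, t))  # inf  -- log(0) = -inf, so the loss is infinite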

When implementing the CE loss, we could calculate \sigma(z) first and then plug y into the definition. However, as the sketch above illustrates, there is a problem with this in practice. At the beginning of training, a positive example might be confidently classified as a negative example (z << 0, implying y \approx 0). If y is small enough, it can be smaller than the smallest representable positive floating-point value, i.e. numerically zero. We then get -\infty when taking the log of 0, and the cross entropy becomes infinite. To tackle this numerical stability issue, the logistic function and the cross entropy are usually combined into a single operation in TensorFlow and PyTorch:

    \[L_{CE}(z,t) = L_{CE}(\sigma(z),t) = -t\log\frac{1}{1+e^{-z}} - (1-t)\log\frac{e^{-z}}{1+e^{-z}}\]

where we used \sigma(z) = 1/(1+e^{-z}) and hence 1-\sigma(z) = e^{-z}/(1+e^{-z}).

Still, the numerical stability issue is not completely under control, since e^{-z} can overflow when z is a large negative number. To tackle this, the “log-sum-exp” trick is used to shift the center of the exponential sum. The log-sum-exp trick is described as follows:

    \[\log\sum_{i=1}^n e^{x_i} = a + \log\sum_{i=1}^n e^{x_i-a}\]

Using this formula with a = \max(x_i), we force the largest exponent to be zero, so the sum can no longer overflow, and any terms that underflow after the shift contribute only negligibly to the result.
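As a quick check, here is a minimal NumPy sketch of the shifted computation (the function name logsumexp is my own choice): the naive evaluation overflows to infinity, while the shifted version returns a finite answer.

    import numpy as np

    def logsumexp(x):
        # log(sum(exp(x))) computed with the shift a = max(x): the largest
        # shifted exponent is 0, so np.exp can no longer overflow.
        a = np.max(x)
        return a + np.log(np.sum(np.exp(x - a)))

    x = np.array([1000.0, 1000.5])
    print(np.log(np.sum(np.exp(x))))  # inf -- exp(1000) overflows
    print(logsumexp(x))               # ~1000.97, computed without overflow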

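Applying the same idea to the combined loss above ties everything together. The loss simplifies to z - tz + \log(1+e^{-z}); shifting the two-term sum 1 + e^{-z} by a = \max(0, -z) gives the stable form \max(z,0) - tz + \log(1+e^{-|z|}), whose exponent is never positive. This matches the expression given in TensorFlow's documentation for tf.nn.sigmoid_cross_entropy_with_logits, and PyTorch's BCEWithLogitsLoss relies on the same log-sum-exp trick. A minimal NumPy sketch (function name my own):

    import numpy as np

    def stable_bce_with_logits(z, t):
        # Binary cross entropy computed directly from the logit z:
        #   -t*log(sigmoid(z)) - (1-t)*log(1 - sigmoid(z))
        #     = z - t*z + log(1 + exp(-z))
        #     = max(z, 0) - t*z + log(1 + exp(-|z|))   (shift by a = max(0, -z))
        # The argument of exp is never positive, so nothing can overflow.
        return np.maximum(z, 0) - t * z + np.log1p(np.exp(-np.abs(z)))

    # The confidently misclassified positive example from before now gives a
    # large but finite loss instead of inf.
    print(stable_bce_with_logits(-800.0, 1.0))  # 800.0
    # A correctly classified positive example gives a small loss, as expected.
    print(stable_bce_with_logits(5.0, 1.0))     # ~0.0067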

1 Comment

  • But the log(0) issue still exists when you calculate the logit x = logit(y) = log(y/(1-y)) as y -> 0.

