Numerical Stability in Logistic Regression


Introduction

In this article we look at numerical stability in logistic regression. By "numerically stable" we mean the computation does not overflow or underflow. Common symptoms of instability are division by zero, or results that come back as nan.

Overflow and Underflow

Representing real numbers on a computer almost always introduces some approximation error. When many operations are composed, even an algorithm that is sound in theory can fail in practice.

One particularly destructive form of rounding error is underflow, which occurs when numbers close to zero are rounded down to zero. We usually want to avoid dividing by zero or taking the logarithm of zero.

Another form of numerical error is overflow, which occurs when numbers of large magnitude are approximated as infinity or negative infinity.
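As a quick illustration (a minimal NumPy sketch of my own, not from the original text), both failure modes are easy to trigger:

import numpy as np

print(np.exp(1000.0))           # inf  -- overflow; NumPy emits a RuntimeWarning
print(np.exp(-1000.0))          # 0.0  -- underflow; the result is rounded to exactly zero
print(np.log(np.exp(-1000.0)))  # -inf -- taking the log of that underflowed zero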

Numerical Stability of the Sigmoid Function

We know the sigmoid function is defined as:

$$\sigma(x) = \frac{1}{1 + \exp(-x)} \tag{1}$$

Its graph looks like this:

(Figure: graph of the sigmoid function)

It contains an $\exp(-x)$ term, so let's look at the graph of $e^x$:

(Figure: graph of $e^x$)

From the graph we can see that when $x$ is large, $e^x$ becomes extremely large, whereas when $x$ is very negative nothing bad happens (no overflow): $e^x$ simply approaches $0$.

When the $x$ fed into the sigmoid is very negative, $\exp(-x)$ blows up toward $\infty$, and we get an overflow.

So how do we solve this problem? $\sigma(x)$ can be written in two equivalent forms:

$$\sigma(x) = \frac{1}{1 + \exp(-x)} = \frac{\exp(x)}{1 + \exp(x)} \tag{2}$$

x0x \geq 0时,我们根据exe^{x}的图像,我们取11+exp(x)\frac{1}{1 + \exp(-x)}的形式;

x<0x < 0时,我们取exp(x)1+exp(x)\frac{\exp(x)}{1 + \exp(x)}

import math

# Naive implementation
def sigmoid_naive(x):
    return 1 / (1 + math.exp(-x))

# Numerically stable implementation
def sigmoid(x):
    if x < 0:
        return math.exp(x) / (1 + math.exp(x))
    else:
        return 1 / (1 + math.exp(-x))

Then we test it with different values:

> sigmoid_naive(2000)
1.0
> sigmoid(2000)
1.0
> sigmoid_naive(-2000)
OverflowError: math range error
> sigmoid(-2000)
0.0

If we pass in -2000, the naive implementation overflows, while the optimized version does not.

However, this implementation relies on an if statement and only handles a scalar rather than a vector.
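For reference, one way to vectorize the same trick in NumPy (a sketch of my own, not the approach taken next) is to clip the exponents so that both branches of np.where stay safe:

import numpy as np

def sigmoid_vec(x):
    # Where x >= 0 use 1 / (1 + exp(-x)); where x < 0 use exp(x) / (1 + exp(x)).
    # np.where evaluates both branches, so clip the exponents to keep them <= 0.
    pos = 1 / (1 + np.exp(-np.clip(x, 0, None)))
    e = np.exp(np.clip(x, None, 0))
    neg = e / (1 + e)
    return np.where(x >= 0, pos, neg)

sigmoid_vec(np.array([-2000.0, 0.0, 2000.0]))  # array([0. , 0.5, 1. ])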

An even better approach is to have the logistic regression model output only the logit and pass that logit directly into the loss function. Here the logit means the output of the linear transformation in logistic regression (the weighted sum plus the bias).

A Numerically Stable BCE Loss

There is a piece of code on PyTorch's GitHub: github.com/pytorch/pyt…

class StableBCELoss(nn.modules.Module):
    def __init__(self):
        super(StableBCELoss, self).__init__()

    def forward(self, input, target):
        neg_abs = - input.abs()
        loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
        return loss.mean()

I also found an explanation of this code in another article[^1].

Let's first analyze why it is numerically stable.
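Written out as a formula (my own reading of the code above, not taken from the PyTorch source), the loss it computes for a logit $z$ and label $y$ is

$$L(y, z) = \max(z, 0) - z\,y + \log\left(1 + e^{-|z|}\right)$$

The only exponent that ever appears is $-|z| \le 0$, so $e^{-|z|}$ lies in $(0, 1]$ and cannot overflow, and the argument of the log is at least $1$, so we never take the log of zero. Below we re-implement both the naive and the stable loss in NumPy and compare them.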

import numpy as np

# First define a numerically stable sigmoid (scalar version)
def sigmoid(x):
    if x < 0:
        return np.exp(x) / (1 + np.exp(x))
    else:
        return 1 / (1 + np.exp(-x))

# Naive implementation of the logistic regression loss
def bce_loss_naive(y, z):
    return -y * np.log(sigmoid(z)) - (1 - y) * np.log(1 - sigmoid(z))

# Numerically stable logistic regression loss
def bce_loss(y, z):
    neg_abs = -np.abs(z)
    return np.clip(z, a_min=0, a_max=None) - y * z + np.log(1 + np.exp(neg_abs))

Next, let's test it. Suppose $z$ is a large number, say $2000$. We know $\sigma(2000)$ outputs $1$, so $L(1, \sigma(2000))$ should be $0$:

> z = 2000
> y = 1
> sigmoid(z)
1.0
> bce_loss(y, z) # 数值稳定版本的
0.0
> bce_loss_naive(y, z)
/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:2: RuntimeWarning: divide by zero encountered in log
/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:2: RuntimeWarning: invalid value encountered in double_scalars
nan

The numerically stable version holds up, while the naive implementation produces nan.

Now suppose $z$ is a large negative number, say $-2000$. Then $\sigma(-2000) = 0$, so $L(0, \sigma(-2000))$ should also be $0$:

> z = -2000
> y = 0
> sigmoid(z)
0.0
> bce_loss(y, z)
0.0
> bce_loss_naive(y, z)
/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:2: RuntimeWarning: divide by zero encountered in log
/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:2: RuntimeWarning: invalid value encountered in double_scalars
nan

Looks good. Now let's carefully verify that this implementation is actually correct, so that we can adopt it with confidence.

Substituting $\hat{y} = \sigma(z)$ into the logistic regression loss function gives:

$$L_{CE}(y,z) = -y \cdot \log(\sigma(z)) - (1-y) \cdot \log(1-\sigma(z)) \tag{3}$$

That is, when the logit $z \geq 0$, we have $\max(z, 0) = z$ and $-|z| = -z$, so the loss computed by the code becomes $z - zy + \log(1 + e^{-z})$. Let's verify that this matches formula $(3)$.

There are two cases, depending on whether the true label is $y=0$ or $y=1$.

(1) When $z \geq 0$ and $y = 0$, the code computes

$$
\begin{aligned}
z + \log(1+e^{-z}) &= \log e^z + \log(1 + e^{-z}) \\
&= \log\left((1 + e^{-z}) \cdot e^z\right) \\
&= \log(1 + e^z) \\
&= -\log \frac{1}{1 + e^z} \\
&= -\log \frac{1 + e^z - e^z}{1 + e^z} \\
&= -\log \left(1 - \frac{e^z}{1+e^z} \right) \\
&= -\log \left(1 - \frac{1}{1 + e^{-z}} \right)
\end{aligned}
$$

Formula $(3)$ gives $-\log(1 - \sigma(z))$, which is the same result.

(2) When $z \geq 0$ and $y = 1$, the code computes

$$
\begin{aligned}
z - z + \log(1 + e^{-z}) &= \log (1 + e^{-z}) \\
&= -\log \left( \frac{1}{1 + e^{-z}} \right)
\end{aligned}
$$

Formula $(3)$ gives $-\log(\sigma(z))$, which is also the same.

If $z < 0$, then $\max(z, 0) = 0$ and $-|z| = z$, so the loss computed by the code becomes $-zy + \log(e^z + 1)$. Let's verify this as well.

(3) When $z < 0$ and $y = 1$, the code computes

$$
\begin{aligned}
-z + \log(e^z + 1) &= \log e^{-z} + \log(e^z + 1) \\
&= \log\left((e^z + 1) \cdot e^{-z}\right) \\
&= \log(1 + e^{-z}) \\
&= -\log \left( \frac{1}{1 + e^{-z}} \right)
\end{aligned}
$$

Formula $(3)$ gives $-\log(\sigma(z))$, so the results agree.

(4) When $z < 0$ and $y = 0$, the code computes

$$
\begin{aligned}
\log(e^z + 1) &= -\log \frac{1}{1 + e^z} \\
&= -\log \frac{1 + e^z - e^z}{1 + e^z} \\
&= -\log \left(1 - \frac{e^z}{1+e^z} \right) \\
&= -\log \left(1 - \frac{1}{1 + e^{-z}} \right)
\end{aligned}
$$

Formula $(3)$ gives $-\log(1 - \sigma(z))$, which also agrees.

This proves that the code implementation is correct.
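As an extra sanity check (my own addition), we can also verify numerically that the stable and naive implementations defined earlier agree wherever the naive one does not overflow:

# For moderate logits the two implementations should agree to floating point precision
for z in [-5.0, -1.0, 0.0, 1.0, 5.0]:
    for y in [0, 1]:
        assert np.isclose(bce_loss(y, z), bce_loss_naive(y, z))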

Now we can implement a numerically stable version of the BCE loss for our metagrad project.

Wait, first we need to implement two functions: abs and clip.

Implementing the Clip Operation

clip() acts like a clamp: it restricts the values in a Tensor to lie between a minimum and a maximum.

class Clip(_Function):
    def forward(ctx, x: ndarray, x_min=None, x_max=None) -> ndarray:
        # Default to the array's own min/max when a bound is not given
        if x_min is None:
            x_min = np.min(x)
        if x_max is None:
            x_max = np.max(x)

        ctx.save_for_backward(x, x_min, x_max)
        return np.clip(x, x_min, x_max)

    def backward(ctx, grad: ndarray) -> ndarray:
        x, x_min, x_max = ctx.saved_tensors
        # Gradient only flows through positions that were not clipped
        mask = (x >= x_min) * (x <= x_max)
        return grad * mask

Only elements that lie within [x_min, x_max] receive a gradient; clipped elements get zero gradient.
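To make the mask concrete, here is a tiny standalone NumPy illustration (independent of the metagrad classes above):

import numpy as np

x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0])
x_min, x_max = -1.0, 1.0

# The same mask the backward pass uses: 1 where x stayed untouched, 0 where it was clipped
mask = (x >= x_min) * (x <= x_max)
print(np.clip(x, x_min, x_max))  # [-1.  -0.5  0.   0.5  1. ]
print(mask.astype(float))        # [0. 1. 1. 1. 0.]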

Implementing the Abs Operation

abs() computes the absolute value; its graph looks like this:

(Figure: graph of the absolute value function)

We know that, by definition, the absolute value function is not differentiable at $0$:

$$\frac{d}{dx} |x| = \frac{x}{|x|}$$

because division by $0$ is undefined. Following PyTorch's convention, however, we simply define the derivative at $x = 0$ to be $0$ as well.
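As a quick check of that convention (a small PyTorch snippet of my own, not from the article):

import torch

x = torch.tensor([-2.0, 0.0, 3.0], requires_grad=True)
x.abs().sum().backward()
print(x.grad)  # tensor([-1., 0., 1.]) -- the gradient at 0 is 0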

class Abs(_Function):
    def forward(ctx, x: ndarray) -> ndarray:
        ctx.save_for_backward(x)
        return np.abs(x)

    def backward(ctx, grad: ndarray) -> ndarray:
        x, = ctx.saved_tensors
        # Where an element of x is 0, return 0;
        # otherwise return +1 or -1
        return grad * np.where(x == 0, 0, x / np.abs(x))

Implementing the Stable BCE Loss

With those in place, the implementation is straightforward:

def binary_cross_entropy(input: Tensor, target: Tensor, reduction: str = "mean") -> Tensor:
    '''
    :param input: logits
    :param target: ground-truth labels, 0 or 1
    :param reduction: "mean", "sum" or "none"
    :return: binary cross entropy loss
    '''
    neg_abs = - abs(input)
    errors = input.clip(x_min=0) - input * target + (1 + neg_abs.exp()).log()

    N = len(target)

    if reduction == "mean":
        loss = errors.sum() / N
    elif reduction == "sum":
        loss = errors.sum()
    else:
        loss = errors
    return loss

Of course, to be safe, let's write a test case:

def test_binary_cross_entropy():
    N = 10
    x = np.random.randn(N)
    # randint's upper bound is exclusive, so use 2 to get labels of both 0 and 1
    y = np.random.randint(0, 2, (N,))

    mx = Tensor(x, requires_grad=True)
    my = Tensor(y)

    tx = torch.tensor(x, dtype=torch.float32, requires_grad=True)
    ty = torch.tensor(y, dtype=torch.float32)

    mo = F.binary_cross_entropy(mx, my)                   # metagrad
    to = torch.binary_cross_entropy_with_logits(tx, ty)   # PyTorch reference

    assert np.allclose(mo.data, to.detach().numpy())

    mo.backward()
    to.backward()

    assert np.allclose(mx.grad.data, tx.grad.numpy())

And make sure it passes:

============================= test session starts =============================
collecting ... collected 1 item

test_cross_entropy.py::test_binary_cross_entropy PASSED                  [100%]

======================== 1 passed, 1 warning in 0.48s =========================

We don't actually need to change the BCELoss class itself:

class BCELoss(_Loss):
    def __init__(self, reduction: str = "mean") -> None:
        super().__init__(reduction)

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        '''
        :param input: logits
        :param target: ground-truth labels, 0 or 1
        :return: binary cross entropy loss
        '''
        return F.binary_cross_entropy(input, target, self.reduction)
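For completeness, here is a minimal usage sketch. The names model, inputs and targets are hypothetical placeholders, and I'm assuming _Loss follows the usual Module convention of being callable:

# Hypothetical usage: the model outputs raw logits and BCELoss applies
# the numerically stable formula internally (no sigmoid on the model side).
criterion = BCELoss()

logits = model(inputs)             # placeholder model that returns logits
loss = criterion(logits, targets)  # targets are 0/1 labels wrapped in a Tensor
loss.backward()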

In fact, the earlier logistic regression implementation occasionally ran into invalid value errors, and that was precisely this numerical stability problem.

To reproduce the error, we print the model's initial weights:

    # print(model.linear.weight) Tensor([[-0.29942604  0.78735491]], requires_grad=True)  <- the problematic weights
    # Reuse the previously problematic weights
    model.linear.weight.assign([[-0.29942604, 0.78735491]])
    print(f"using weight: {model.linear.weight}")

Then train again:

  using weight: Tensor([[-0.29942605  0.7873549 ]], requires_grad=True)
  5%|▌         | 10026/200000 [00:05<01:45, 1796.27it/s]Train -  Loss: 0.6216351389884949. Accuracy: 69.6969696969697

 10%|█         | 20023/200000 [00:11<01:39, 1810.46it/s]Train -  Loss: 0.6169218420982361. Accuracy: 71.71717171717172

 15%|█▍        | 29933/200000 [00:16<01:33, 1812.38it/s]Train -  Loss: 0.6122889518737793. Accuracy: 74.74747474747475

 20%|██        | 40009/200000 [00:22<01:27, 1823.05it/s]Train -  Loss: 0.6077350378036499. Accuracy: 79.79797979797979

 25%|██▌       | 50028/200000 [00:27<01:23, 1804.55it/s]Train -  Loss: 0.6032587885856628. Accuracy: 81.81818181818181

 30%|██▉       | 59893/200000 [00:33<01:17, 1804.93it/s]Train -  Loss: 0.5988588929176331. Accuracy: 81.81818181818181

 35%|███▍      | 69928/200000 [00:38<01:11, 1828.28it/s]Train -  Loss: 0.5945340394973755. Accuracy: 82.82828282828282

 40%|████      | 80000/200000 [00:44<01:06, 1811.00it/s]Train -  Loss: 0.5902827978134155. Accuracy: 83.83838383838383

 45%|████▌     | 90037/200000 [00:50<01:01, 1782.85it/s]Train -  Loss: 0.5861039757728577. Accuracy: 85.85858585858585

 50%|████▉     | 99996/200000 [00:55<00:55, 1789.86it/s]Train -  Loss: 0.5819962620735168. Accuracy: 86.86868686868686

 55%|█████▍    | 109880/200000 [01:01<00:50, 1772.11it/s]Train -  Loss: 0.577958345413208. Accuracy: 86.86868686868686

 60%|█████▉    | 119899/200000 [01:06<00:45, 1760.05it/s]Train -  Loss: 0.5739889144897461. Accuracy: 87.87878787878788

 65%|██████▍   | 129886/200000 [01:12<00:40, 1752.00it/s]Train -  Loss: 0.5700867176055908. Accuracy: 87.87878787878788

 70%|██████▉   | 139954/200000 [01:18<00:33, 1791.69it/s]Train -  Loss: 0.5662506222724915. Accuracy: 88.88888888888889

 75%|███████▍  | 149921/200000 [01:24<00:31, 1572.65it/s]Train -  Loss: 0.5624792575836182. Accuracy: 89.8989898989899

 80%|████████  | 160025/200000 [01:30<00:22, 1805.44it/s]Train -  Loss: 0.5587714314460754. Accuracy: 91.91919191919192

 85%|████████▍ | 169954/200000 [01:35<00:16, 1778.81it/s]Train -  Loss: 0.5551260113716125. Accuracy: 90.9090909090909

 90%|████████▉ | 179929/200000 [01:41<00:11, 1777.69it/s]Train -  Loss: 0.551541805267334. Accuracy: 91.91919191919192

 95%|█████████▍| 189974/200000 [01:47<00:05, 1825.03it/s]Train -  Loss: 0.5480176210403442. Accuracy: 90.9090909090909

100%|██████████| 200000/200000 [01:52<00:00, 1774.94it/s]
Train -  Loss: 0.544552206993103. Accuracy: 90.9090909090909

This time there were no problems.

Summary

In this article we implemented a numerically stable loss for logistic regression. In the next article, we will implement a numerically stable version of the more commonly used softmax regression loss.

References

  1. How do Tensorflow and Keras implement Binary Classification and the Binary Cross-Entropy function?