[手写系列] 2.loss(BCELoss&CrossEntropyLoss)

589 阅读2分钟

本文已参与[新人创作礼]活动,一起开启掘金创作之路。

1. BCE_Loss

import torch
from torch import nn
import torch.nn.functional as F
# BCELosswithLogits: binary cross-entropy loss computed from raw logits
class BCELosswithLogits(nn.Module):
    """Binary cross-entropy loss that applies sigmoid to raw logits.

    Args:
        pos_weight: multiplicative weight on the positive-class term
            (default 1, which reduces to plain BCE).
        reduction: 'mean', 'sum', or anything else for element-wise loss.
    """

    def __init__(self, pos_weight=1, reduction='mean'):
        super(BCELosswithLogits, self).__init__()
        # Bug fix: the original commented this assignment out, so the
        # pos_weight argument was silently ignored.
        self.pos_weight = pos_weight
        self.reduction = reduction

    def forward(self, logits, target):
        # logits: [N, *] raw scores, target: [N, *] with values in {0, 1}
        probs = torch.sigmoid(logits)  # torch.sigmoid: F.sigmoid is deprecated
        loss = (- self.pos_weight * target * torch.log(probs)
                - (1 - target) * torch.log(1 - probs))
        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()
        return loss
# Demo: compute the summed and mean BCE loss on random data.
bce = BCELosswithLogits(1, 'sum')
logits = torch.rand(100, 1)
# Binarize random values into hard 0/1 targets; vectorized comparison replaces
# the per-element list comprehension. Note BCE needs float targets.
target = (torch.rand(100, 1) > 0.5).float()
print(bce(logits, target))  # call the module, not .forward(), so hooks run
# Same data, mean reduction.
bce2 = BCELosswithLogits(1, 'mean')
print(bce2(logits, target))

2. Cross_Entropy_Loss

class CrossEntropyLoss(torch.nn.Module):
    """Multi-class cross entropy computed from raw logits.

    Accepts logits shaped [N, C] or [N, C, d1, ..., dk] with integer
    class-index targets of matching flattened size; extra spatial
    dimensions are flattened so each element contributes
    -log softmax(logits)[target].
    """

    def __init__(self, reduction='mean'):
        super(CrossEntropyLoss, self).__init__()
        self.reduction = reduction

    def forward(self, logits, target):
        if logits.dim() > 2:
            # [N, C, *] -> [N*prod(*), C]: move the class axis last, flatten.
            n, c = logits.size(0), logits.size(1)
            logits = logits.view(n, c, -1).transpose(1, 2).contiguous().view(-1, c)
        flat_target = target.view(-1, 1)
        # Negative log-likelihood of the target class, per element.
        log_probs = F.log_softmax(logits, 1)
        loss = -log_probs.gather(1, flat_target)
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss
# Demo: non-image case, logits: [N, C, 1], target: [N, 1].
logits_ce = torch.rand(100, 3, 1)
# Fix wrong comment: randint(3, ...) draws class indices in [0, 2], not [0, 3].
target_ce = torch.randint(3, (100, 1))
# Mean and sum reductions; call the module rather than .forward() directly.
s = CrossEntropyLoss('mean')
print(s(logits_ce, target_ce))
s = CrossEntropyLoss('sum')
print(s(logits_ce, target_ce))

3. BCE_FocalLoss

class BCEFocalLosswithLogits(nn.Module):
    """Binary focal loss on raw logits (Lin et al., "Focal Loss for Dense
    Object Detection").

        loss = -alpha * (1-p)^gamma * y * log(p)
               - (1-alpha) * p^gamma * (1-y) * log(1-p),  p = sigmoid(logits)

    Args:
        gamma: focusing exponent; gamma=0 recovers alpha-weighted BCE.
        alpha: positive-class weight in [0, 1].
        reduction: 'mean', 'sum', or anything else for element-wise loss.
    """

    def __init__(self, gamma=0.2, alpha=0.6, reduction='mean'):
        super(BCEFocalLosswithLogits, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, logits, target):
        # logits: [N, H, W], target: [N, H, W] with values in {0, 1}
        probs = torch.sigmoid(logits)  # torch.sigmoid: F.sigmoid is deprecated
        alpha, gamma = self.alpha, self.gamma
        loss = - alpha * (1 - probs) ** gamma * target * torch.log(probs) - \
               (1 - alpha) * probs ** gamma * (1 - target) * torch.log(1 - probs)
        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()
        return loss

4. CrossEntropyFocalLoss

class CrossEntropyFocalLoss(nn.Module):
    """Multi-class focal loss: -alpha_t * (1 - p_t)^gamma * log(p_t).

    Args:
        alpha: optional per-class weight tensor of shape [C], or None.
        gamma: focusing exponent; gamma=0 recovers plain cross entropy.
        reduction: 'mean', 'sum', or anything else for element-wise loss.
    """

    def __init__(self, alpha=None, gamma=0.2, reduction='mean'):
        super(CrossEntropyFocalLoss, self).__init__()
        self.reduction = reduction
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, logits, target):
        # logits: [N, C] or [N, C, d1..dk]; target: matching class indices.
        if logits.dim() > 2:
            # [N, C, *] -> [N*prod(*), C]: move the class axis last, flatten.
            n, c = logits.size(0), logits.size(1)
            logits = logits.view(n, c, -1).transpose(1, 2).contiguous().view(-1, c)
        flat_target = target.view(-1, 1)

        # Probability the model assigns to the true class of each element.
        pt = F.softmax(logits, 1).gather(1, flat_target).view(-1)   # [NHW]
        log_pt = torch.log(pt)

        if self.alpha is not None:
            # Weight each element by alpha[its target class].
            log_pt = log_pt * self.alpha.gather(0, flat_target.view(-1))

        loss = -(1 - pt) ** self.gamma * log_pt

        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss