# fast.ai 深度学习笔记(七下)

65 阅读38分钟

模型[1:38:30]

class Empty(nn.Module): 
    def forward(self,x): 
        return x

models = ConvnetBuilder(resnet34, 0, 0, 0, custom_head=Empty())
learn = ConvLearner(md, models)
learn.summary()
class StdUpsample(nn.Module):
    def __init__(self, nin, nout):
        super().__init__()
        self.conv = nn.ConvTranspose2d(nin, nout, 2, stride=2)
        self.bn = nn.BatchNorm2d(nout)

    def forward(self, x): 
        return self.bn(F.relu(self.conv(x)))
flatten_channel = Lambda(lambda x: x[:,0])
simple_up = nn.Sequential(
    nn.ReLU(),
    StdUpsample(512,256),
    StdUpsample(256,256),
    StdUpsample(256,256),
    StdUpsample(256,256),
    nn.ConvTranspose2d(256, 1, 2, stride=2),
    flatten_channel
)

考虑到我们想要一个知道汽车长什么样的东西,我们可能想要从一个预训练的 ImageNet 网络开始。所以我们将从 ResNet34 开始。使用ConvnetBuilder,我们可以获取我们的 ResNet34 并添加一个自定义头部。自定义头部将是一些上采样的东西,现在我们将做一些非常愚蠢的事情,就是我们只是做一个 ConvTranspose2d,批量规范化,ReLU。

这就是我说的 - 任何人都可以在不看任何笔记本的情况下构建这个,或者至少你有来自以前课程的信息。这里没有任何新东西。所以最后,我们有一个单一的过滤器。现在这将给我们一个批量大小为 1 乘以 128 乘以 128。但我们想要的是批量大小为 128 乘以 128。所以我们必须去掉那个单元轴,所以我在这里有一个 lambda 层。Lambda 层非常有帮助,因为没有这个 lambda 层,它只是通过索引 0 来删除那个单元轴,没有 lambda 层,我将不得不创建一个自定义类,具有自定义的前向方法等等。但通过创建一个 lambda 层来执行一个自定义操作,我现在可以将其放入 Sequential 中,这样就更容易了。

PyTorch 的人们对这种方法有点傲慢。Lambda 层实际上是 fastai 库的一部分,而不是 PyTorch 库的一部分。而且 PyTorch 讨论板上的人们说“是的,我们可以给人们这个”,“是的,这只是一行代码”,但他们从不鼓励他们过于频繁地使用 Sequential。所以你看。

这是我们的自定义头部[1:40:36]。所以我们将有一个 ResNet 34 进行下采样,然后一个非常简单的自定义头部,非常快速地上采样,希望这样做一些事情。我们将使用阈值为 0.5 的准确度并打印出指标。

models = ConvnetBuilder(resnet34, 0, 0, 0, custom_head=simple_up)
learn = ConvLearner(md, models)
learn.opt_fn=optim.Adam
learn.crit=nn.BCEWithLogitsLoss()
learn.metrics=[accuracy_thresh(0.5)
learn.lr_find()
learn.sched.plot()
'''
94%|█████████▍| 30/32 [00:05<00:00,  5.48it/s, loss=10.6]
'''

lr=4e-2
learn.fit(lr,1,cycle_len=5,use_clr=(20,5))
'''
epoch      trn_loss   val_loss   <lambda>                  
    0      0.124078   0.133566   0.945951  
    1      0.111241   0.112318   0.954912                  
    2      0.099743   0.09817    0.957507                   
    3      0.090651   0.092375   0.958117                   
    4      0.084031   0.086026   0.963243
[0.086025625, 0.96324310824275017]
'''

经过几个时代,我们得到了 96%的准确率。这好吗[1:40:56]?96%的准确率好吗?希望对这个问题的答案是取决于。这是为了什么?答案是 Carvana 想要这个,因为他们想要能够拍摄他们的汽车图像并将它们剪切并粘贴到异国情调的蒙特卡洛背景或其他地方(这是蒙特卡洛的地方,而不是模拟)。为了做到这一点,你需要一个非常好的蒙版。你不想留下后视镜,缺少一个车轮,或者包括一点背景之类的东西。那看起来很愚蠢。所以你需要一些非常好的东西。所以只有 96%的像素正确并不听起来很好。但我们真的不知道直到我们看到它。所以让我们看看。

learn.save('tmp')
learn.load('tmp')
py,ay = learn.predict_with_targs()
ay.shape
'''
(1008, 128, 128)
'''

所以这是我们想要剪切的正确版本[1:41:54]

show_img(ay[0]);

这是 96%准确的版本。所以当你看到它时,你会意识到“哦,是的,准确地获取 96%的像素实际上很容易,因为所有外部部分都不是汽车,所有内部部分都是汽车,而真正有趣的部分是边缘。所以我们需要做得更好。

show_img(py[0]>0);

让我们解冻,因为到目前为止我们只训练了自定义头部。让我们做更多。

learn.unfreeze()
learn.bn_freeze(True)
lrs = np.array([lr/100,lr/10,lr])/4
learn.fit(lrs,1,cycle_len=20,use_clr=(20,10))
'''
epoch      trn_loss   val_loss   <lambda>                   
    0      0.06577    0.053292   0.972977  
    1      0.049475   0.043025   0.982559                   
    2      0.039146   0.035927   0.98337                    
    3      0.03405    0.031903   0.986982                   
    4      0.029788   0.029065   0.987944                   
    5      0.027374   0.027752   0.988029                   
    6      0.026041   0.026718   0.988226                   
    7      0.024302   0.025927   0.989512                   
    8      0.022921   0.026102   0.988276                   
    9      0.021944   0.024714   0.989537                   
    10     0.021135   0.0241     0.990628                   
    11     0.020494   0.023367   0.990652                   
    12     0.01988    0.022961   0.990989                   
    13     0.019241   0.022498   0.991014                   
    14     0.018697   0.022492   0.990571                   
    15     0.01812    0.021771   0.99105                    
    16     0.017597   0.02183    0.991365                   
    17     0.017192   0.021434   0.991364                   
    18     0.016768   0.021383   0.991643                   
    19     0.016418   0.021114   0.99173
[0.021113895, 0.99172959849238396]
'''

再经过一段时间,我们得到了 99.1%。这好吗?我不知道。让我们看看。

learn.save('0')
x,y = next(iter(md.val_dl))
py = to_np(learn.model(V(x)))

实际上不是。它完全错过了左侧的后视镜,右侧也错过了很多。底部的边缘明显错了。当我们尝试剪裁时,这些事情完全会影响到,所以还不够好。

ax = show_img(denorm(x)[0])
show_img(py[0]>0, ax=ax, alpha=0.5);

ax = show_img(denorm(x)[0])
show_img(y[0], ax=ax, alpha=0.5);

512x512

让我们尝试放大。很好的一点是,当我们将其放大到 512x512 时(确保减少批量大小,因为你会耗尽内存),有更多的信息供其使用,因此我们的准确性提高到 99.4%,事情一直在变得更好。

TRAIN_DN = 'train'
MASKS_DN = 'train_masks_png'
sz = 512
bs = 16
x_names = np.array([Path(TRAIN_DN)/o for o in masks_csv['img']])
y_names = np.array([
    Path(MASKS_DN)/f'**{o[:-4]}**_mask.png' 
    for o in masks_csv['img']
])
((val_x,trn_x),(val_y,trn_y)) = split_by_idx(val_idxs, x_names, y_names)
len(val_x),len(trn_x)
'''
(1008, 4080)
'''
tfms = tfms_from_model(
    resnet34, sz, 
    crop_type=CropType.NO,
    tfm_y=TfmType.CLASS, 
    aug_tfms=aug_tfms
)
datasets = ImageData.get_ds(
    MatchedFilesDataset, 
    (trn_x,trn_y),
    (val_x,val_y), 
    tfms, 
    path=PATH
)
md = ImageData(
    PATH, datasets, bs, 
    num_workers=8, 
    classes=None
)
denorm = md.trn_ds.denorm
x,y = next(iter(md.aug_dl))
x = denorm(x)

这是真实的。

fig, axes = plt.subplots(4, 4, figsize=(10, 10))
for i,ax in enumerate(axes.flat):
    ax=show_img(x[i], ax=ax)
    show_img(y[i], ax=ax, alpha=0.5)
plt.tight_layout(pad=0.1)

转存失败,建议直接上传图片文件

simple_up = nn.Sequential(
    nn.ReLU(),
    StdUpsample(512,256),
    StdUpsample(256,256),
    StdUpsample(256,256),
    StdUpsample(256,256),
    nn.ConvTranspose2d(256, 1, 2, stride=2),
    flatten_channel
)
models = ConvnetBuilder(resnet34, 0, 0, 0, custom_head=simple_up)
learn = ConvLearner(md, models)
learn.opt_fn=optim.Adam
learn.crit=nn.BCEWithLogitsLoss()
learn.metrics=[accuracy_thresh(0.5)]
learn.load('0')
learn.lr_find()
learn.sched.plot()
'''
85%|████████▌ | 218/255 [02:12<00:22,  1.64it/s, loss=8.91]
'''

lr=4e-2
learn.fit(lr,1,cycle_len=5,use_clr=(20,5))
'''
epoch      trn_loss   val_loss   <lambda>                     
    0      0.02178    0.020653   0.991708  
    1      0.017927   0.020653   0.990241                     
    2      0.015958   0.016115   0.993394                     
    3      0.015172   0.015143   0.993696                     
    4      0.014315   0.014679   0.99388
[0.014679321, 0.99388032489352751]
'''
learn.save('tmp')
learn.load('tmp')
learn.unfreeze()
learn.bn_freeze(True)
lrs = np.array([lr/100,lr/10,lr])/4
learn.fit(lrs,1,cycle_len=8,use_clr=(20,8))
'''
epoch      trn_loss   val_loss   mask_acc                     
    0      0.038687   0.018685   0.992782  
    1      0.024906   0.014355   0.994933                     
    2      0.025055   0.014737   0.995526                     
    3      0.024155   0.014083   0.995708                     
    4      0.013446   0.010564   0.996166                     
    5      0.01607    0.010555   0.996096                     
    6      0.019197   0.010883   0.99621                      
    7      0.016157   0.00998    0.996393
[0.0099797687, 0.99639255659920833]
'''
learn.save('512')
x,y = next(iter(md.val_dl))
py = to_np(learn.model(V(x)))
ax = show_img(denorm(x)[0])
show_img(py[0]>0, ax=ax, alpha=0.5);

ax = show_img(denorm(x)[0])
show_img(y[0], ax=ax, alpha=0.5);

事情一直在变得更好,但我们仍然有一些小黑色块状物。所以让我们调整到 1024x1024。

1024x1024

所以让我们调整到 1024x1024,批量大小减少到 4。现在这是相当高分辨率的了,再训练一段时间,99.6%,99.8%!

sz = 1024
bs = 4
tfms = tfms_from_model(
    resnet34, sz, 
    crop_type=CropType.NO,
    tfm_y=TfmType.CLASS, 
    aug_tfms=aug_tfms
)
datasets = ImageData.get_ds(
    MatchedFilesDataset, 
    (trn_x,trn_y), 
    (val_x,val_y), 
    tfms, 
    path=PATH
)
md = ImageData(
    PATH, datasets, bs, 
    num_workers=8, 
    classes=None
)
denorm = md.trn_ds.denorm
x,y = next(iter(md.aug_dl))
x = denorm(x)
y = to_np(y)
fig, axes = plt.subplots(2, 2, figsize=(8, 8))
for i,ax in enumerate(axes.flat):
    show_img(x[i], ax=ax)
    show_img(y[i], ax=ax, alpha=0.5)
plt.tight_layout(pad=0.1)

simple_up = nn.Sequential(
    nn.ReLU(),
    StdUpsample(512,256),
    StdUpsample(256,256),
    StdUpsample(256,256),
    StdUpsample(256,256),
    nn.ConvTranspose2d(256, 1, 2, stride=2),
    flatten_channel,
)
models = ConvnetBuilder(resnet34, 0, 0, 0, custom_head=simple_up)
learn = ConvLearner(md, models)
learn.opt_fn=optim.Adam
learn.crit=nn.BCEWithLogitsLoss()
learn.metrics=[accuracy_thresh(0.5)]
learn.load('512')learn.lr_find()
learn.sched.plot()
'''
85%|████████▌ | 218/255 [02:12<00:22,  1.64it/s, loss=8.91]
'''

lr=4e-2
learn.fit(lr,1,cycle_len=2,use_clr=(20,4))
'''
epoch      trn_loss   val_loss   <lambda>                       
    0      0.01066    0.011119   0.996227  
    1      0.009357   0.009696   0.996553
[0.0096957013, 0.99655332546385511]
'''
learn.save('tmp')
learn.load('tmp')
learn.unfreeze()
learn.bn_freeze(True)
lrs = np.array([lr/100,lr/10,lr])/8
learn.fit(lrs,1,cycle_len=40,use_clr=(20,10))
'''
epoch      trn_loss   val_loss   mask_acc                       
    0      0.015565   0.007449   0.997661  
    1      0.01979    0.008376   0.997542                       
    2      0.014874   0.007826   0.997736                       
    3      0.016104   0.007854   0.997347                       
    4      0.023386   0.009745   0.997218                       
    5      0.018972   0.008453   0.997588                       
    6      0.013184   0.007612   0.997588                       
    7      0.010686   0.006775   0.997688                       
    8      0.0293     0.015299   0.995782                       
    9      0.018713   0.00763    0.997638                       
    10     0.015432   0.006575   0.9978                         
    11     0.110205   0.060062   0.979043                      
    12     0.014374   0.007753   0.997451                       
    13     0.022286   0.010282   0.997587                       
    14     0.015645   0.00739    0.997776                       
    15     0.013821   0.00692    0.997869                       
    16     0.022389   0.008632   0.997696                       
    17     0.014607   0.00677    0.997837                       
    18     0.018748   0.008194   0.997657                       
    19     0.016447   0.007237   0.997899                       
    20     0.023596   0.008211   0.997918                       
    21     0.015721   0.00674    0.997848                       
    22     0.01572    0.006415   0.998006                       
    23     0.019519   0.007591   0.997876                       
    24     0.011159   0.005998   0.998053                       
    25     0.010291   0.005806   0.998012                       
    26     0.010893   0.005755   0.998046                       
    27     0.014534   0.006313   0.997901                       
    28     0.020971   0.006855   0.998018                       
    29     0.014074   0.006107   0.998053                       
    30     0.01782    0.006561   0.998114                       
    31     0.01742    0.006414   0.997942                       
    32     0.016829   0.006514   0.9981                         
    33     0.013148   0.005819   0.998033                       
    34     0.023495   0.006261   0.997856                       
    35     0.010931   0.005516   0.99812                        
    36     0.015798   0.006176   0.998126                       
    37     0.021636   0.005931   0.998067                       
    38     0.012133   0.005496   0.998158                       
    39     0.012562   0.005678   0.998172
[0.0056782686, 0.99817223208291195]
'''
learn.save('1024')
x,y = next(iter(md.val_dl))
py = to_np(learn.model(V(x)))
ax = show_img(denorm(x)[0])
show_img(py[0][0]>0, ax=ax, alpha=0.5);

ax = show_img(denorm(x)[0])
show_img(y[0,...,-1], ax=ax, alpha=0.5);

show_img(py[0][0]>0);

show_img(y[0,...,-1]);

现在如果我们看一下掩模,它们实际上看起来不错。这看起来相当不错。那么我们能做得更好吗?答案是肯定的。

U-Net

笔记本 / 论文

U-Net 网络非常了不起。使用之前的方法,我们的预训练 ImageNet 网络被压缩到 7x7,然后再扩展到 224x224(1024 被压缩到比 7x7 大得多)。然后再次扩展出来,这意味着它必须以某种方式在小版本中存储关于更大版本的所有信息。实际上,关于更大版本的大部分信息实际上已经在原始图片中。因此,这种压缩和解压似乎不是一个很好的方法。

因此,U-Net 的想法来自于这篇出色的论文,在这篇论文中,它实际上是在生物医学图像分割这个非常特定的领域中发明的。但事实上,基本上每一个与分割有关的 Kaggle 获胜者最终都使用了 U-Net。这是每个 Kaggle 参与者都知道的最佳实践之一,但在更多的学术圈中,这已经存在至少几年了,很多人仍然没有意识到这是迄今为止最好的方法。

这里是基本的想法。在左侧是向下路径,我们从 572x572 开始,然后将网格大小减半 4 次,然后在右侧是向上路径,我们将网格大小扩大 4 次。但我们还做的一件事是,在每个减半网格大小的点,我们实际上将这些激活复制到向上路径,并将它们连接在一起。

在右下角可以看到,这些红色箭头是最大池化操作,这些绿色箭头是向上采样,然后这些灰色箭头是复制。所以我们复制并连接。换句话说,经过几次卷积后的输入图像被复制到输出中,连接在一起,现在我们可以使用所有经过所有向下和向上的信息,还有输入像素的略微修改版本。以及输入像素的略微修改版本,因为它们是通过这里上来的。所以我们拥有所有向下和向上的丰富性,但也有一个略微不那么粗糙的版本,然后是一个略微不那么粗糙的版本,然后是一个真正简单的版本,它们都可以组合在一起。这就是 U-Net。这是一个很酷的想法。

我们在 carvana-unet 笔记本中。所有这些与之前的代码相同。

%matplotlib inline
%reload_ext autoreload
%autoreload 2
from fastai.conv_learner import *
from fastai.dataset import *
from fastai.models.resnet import vgg_resnet50

import jsontorch.backends.cudnn.benchmark=True

数据

PATH = Path('data/carvana')
MASKS_FN = 'train_masks.csv'
META_FN = 'metadata.csv'
masks_csv = pd.read_csv(PATH/MASKS_FN)
meta_csv = pd.read_csv(PATH/META_FN)
def show_img(im, figsize=None, ax=None, alpha=None):
    if not ax: 
        fig,ax = plt.subplots(figsize=figsize)
    ax.imshow(im, alpha=alpha)
    ax.set_axis_off()
    return axTRAIN_DN = 'train-128'
MASKS_DN = 'train_masks-128'
sz = 128
bs = 64
nw = 16
TRAIN_DN = 'train'
MASKS_DN = 'train_masks_png'
sz = 128
bs = 64
nw = 16
class MatchedFilesDataset(FilesDataset):
    def __init__(self, fnames, y, transform, path):
        self.y=y
        assert(len(fnames)==len(y))
        super().__init__(fnames, transform, path)
    def get_y(self, i): 
        return open_image(os.path.join(self.path, self.y[i]))
    def get_c(self): 
        return 0
x_names = np.array([Path(TRAIN_DN)/o for o in masks_csv['img']])
y_names = np.array([
    Path(MASKS_DN)/f'{o[:-4]}_mask.png' 
    for o in masks_csv['img']
])
val_idxs = list(range(1008))
((val_x,trn_x),(val_y,trn_y)) = split_by_idx(val_idxs, x_names, y_names)
aug_tfms = [
    RandomRotate(4, tfm_y=TfmType.CLASS),
    RandomFlip(tfm_y=TfmType.CLASS),
    RandomLighting(0.05, 0.05, tfm_y=TfmType.CLASS)
]
tfms = tfms_from_model(
    esnet34, sz, 
    crop_type=CropType.NO, 
    tfm_y=TfmType.CLASS, 
    aug_tfms=aug_tfms
)
datasets = ImageData.get_ds(
    MatchedFilesDataset, 
    (trn_x,trn_y), 
    (val_x,val_y), 
    tfms, 
    ath=PATH
)
md = ImageData(PATH, datasets, bs, num_workers=16, classes=None)
denorm = md.trn_ds.denorm
x,y = next(iter(md.trn_dl))
x.shape,y.shape
'''
(torch.Size([64, 3, 128, 128]), torch.Size([64, 128, 128]))
'''

简单的上采样

一开始,我有一个简单的上采样版本,只是为了再次向你展示非 U-net 版本。这次,我将加入一个称为 dice 指标的东西。Dice 非常类似,如你所见,与 Jaccard 或 I over U 非常相似。只是有一点小差别。基本上是交集除以并集,稍微调整了一下。我们要使用 dice 的原因是 Kaggle 竞赛使用了这个指标,而且要获得高 dice 分数比获得高准确度要困难一些,因为它真的在看正确像素与你的像素的重叠部分。但它非常相似。

在 Kaggle 竞赛中,表现良好的人得到了大约 99.6 点,而获胜者得到了大约 99.7 点。

f = resnet34
cut,lr_cut = model_meta[f]def get_base():
    layers = cut_model(f(True), cut)
    return nn.Sequential(*layers)
def dice(pred, targs):
    pred = (pred>0).float()
    return 2. * (pred*targs).sum() / (pred+targs).sum()

这是我们的标准上采样。

class StdUpsample(nn.Module):
    def __init__(self, nin, nout):
        super().__init__()
        self.conv = nn.ConvTranspose2d(nin, nout, 2, stride=2)
        self.bn = nn.BatchNorm2d(nout)

    def forward(self, x): 
        return self.bn(F.relu(self.conv(x)))

这一切和以前一样。

class Upsample34(nn.Module):
    def __init__(self, rn):
        super().__init__()
        self.rn = rn
        self.features = nn.Sequential(
            rn, nn.ReLU(),
            StdUpsample(512,256),
            StdUpsample(256,256),
            StdUpsample(256,256),
            StdUpsample(256,256),
            nn.ConvTranspose2d(256, 1, 2, stride=2)
        )

    def forward(self,x): 
        return self.features(x)[:,0]
class UpsampleModel():
    def __init__(self,model,name='upsample'):
        self.model,self.name = model,name

    def get_layer_groups(self, precompute):
        lgs = list(split_by_idxs(children(self.model.rn), [lr_cut]))
        return lgs + [children(self.model.features)[1:]]
m_base = get_base() 
m = to_gpu(Upsample34(m_base))
models = UpsampleModel(m)
learn = ConvLearner(md, models)
learn.opt_fn=optim.Adam
learn.crit=nn.BCEWithLogitsLoss()
learn.metrics=[accuracy_thresh(0.5),dice]
learn.freeze_to(1)
learn.lr_find()
learn.sched.plot()
'''
86%|█████████████████████████████████████████████████████████████          | 55/64 [00:22<00:03,  2.46it/s, loss=3.21]
'''

lr=4e-2
wd=1e-7
lrs = np.array([lr/100,lr/10,lr])/2
learn.fit(lr,1, wds=wd, cycle_len=4,use_clr=(20,8))
'''
0%|          | 0/64 [00:00<?, ?it/s]
epoch      trn_loss   val_loss   <lambda>   dice           
    0      0.216882   0.133512   0.938017   0.855221  
    1      0.169544   0.115158   0.946518   0.878381       
    2      0.153114   0.099104   0.957748   0.903353       
    3      0.144105   0.093337   0.964404   0.915084
[0.09333742126112893, 0.9644036065964472, 0.9150839788573129]
'''
learn.save('tmp')
learn.load('tmp')
learn.unfreeze()
learn.bn_freeze(True)
learn.fit(lrs,1,cycle_len=4,use_clr=(20,8))
'''
epoch      trn_loss   val_loss   <lambda>   dice           
    0      0.174897   0.061603   0.976321   0.94382   
    1      0.122911   0.053625   0.982206   0.957624       
    2      0.106837   0.046653   0.985577   0.965792       
    3      0.099075   0.042291   0.986519   0.968925
[0.042291240323157536, 0.986519161670927, 0.9689251193924556]
'''

现在我们可以检查我们的 dice 指标[1:48:00]。所以你可以看到在 dice 指标上,我们在 128x128 处得到了大约 96.8。所以这不太好。

learn.save('128')
x,y = next(iter(md.val_dl))
py = to_np(learn.model(V(x)))
show_img(py[0]>0);

show_img(y[0]);

U-net(ish)[1:48:16]

所以让我们尝试 U-Net。我称之为 U-net(ish),因为通常我正在创建自己的有点 hacky 版本——尽量保持与你习惯的东西尽可能相似,并做我认为有意义的事情。所以至少有很多机会让你至少通过查看确切的网格大小来使其更加真实地成为 U-net,看看这里(左上角的卷积)大小有点下降。所以显然他们没有添加任何填充,然后有一些裁剪——有一些差异。但其中一件事是因为我想利用迁移学习——这意味着我不能完全使用 U-Net。

所以另一个重要的机会是,如果你创建了 U-Net 的下行路径,然后在末尾添加一个分类器,然后在 ImageNet 上训练它。现在你有了一个在 ImageNet 上训练过的分类器,专门设计为 U-Net 的良好骨干。然后你应该能够回来并接近赢得这个旧竞赛(实际上并不是很旧——是一个相当新的竞赛)。因为以前不存在这种预训练网络。但是如果你想一下 YOLO v3 是如何做的,基本上就是这样。他们创建了一个 DarkNet,他们在 ImageNet 上预训练了它,然后他们将其用作边界框的基础。所以,再次强调这种不仅为分类而设计而且为其他事物而设计的预训练的想法——这是迄今为止没有人做过的事情。但正如我们所展示的,你现在可以用 25 美元在三小时内训练 ImageNet。如果社区中的人们对此感兴趣,希望我也能提供帮助,如果你愿意,我可以帮助你设置并给我一个脚本,我可能可以为你运行它。但目前我们还没有。所以我们将使用 ResNet。

class SaveFeatures():
    features=None
    def __init__(self, m):
        self.hook = m.register_forward_hook(self.hook_fn)
    def hook_fn(self, module, input, output): 
        self.features = output
    def remove(self): 
        self.hook.remove()

所以我们基本上要从get_base开始[1:50:37]。Base 是我们的基础网络,这在第一部分中已经定义过了。

所以get_base将调用f是什么,fresnet34。所以我们将获取我们的 ResNet34 并且cut_model是我们的卷积网络构建器做的第一件事。它基本上删除了自适应池化之后的所有内容,这样我们就得到了 ResNet34 的骨干。所以get_base将给我们返回 ResNet34 的骨干。

class UnetBlock(nn.Module):
    def __init__(self, up_in, x_in, n_out):
        super().__init__()
        up_out = x_out = n_out//2
        self.x_conv  = nn.Conv2d(x_in,  x_out,  1)
        self.tr_conv = nn.ConvTranspose2d(up_in, up_out, 2, stride=2)
        self.bn = nn.BatchNorm2d(n_out)

    def forward(self, up_p, x_p):
        up_p = self.tr_conv(up_p)
        x_p = self.x_conv(x_p)
        cat_p = torch.cat([up_p,x_p], dim=1)
        return self.bn(F.relu(cat_p))class Unet34(nn.Module):
    def __init__(self, rn):
        super().__init__()
        self.rn = rn
        self.sfs = [SaveFeatures(rn[i]) for i in [2,4,5,6]]
        self.up1 = UnetBlock(512,256,256)
        self.up2 = UnetBlock(256,128,256)
        self.up3 = UnetBlock(256,64,256)
        self.up4 = UnetBlock(256,64,256)
        self.up5 = nn.ConvTranspose2d(256, 1, 2, stride=2)

    def forward(self,x):
        x = F.relu(self.rn(x))
        x = self.up1(x, self.sfs[3].features)
        x = self.up2(x, self.sfs[2].features)
        x = self.up3(x, self.sfs[1].features)
        x = self.up4(x, self.sfs[0].features)
        x = self.up5(x)
        return x[:,0]

    def close(self):
        for sf in self.sfs: 
            sf.remove()
class UnetModel():
    def __init__(self,model,name='unet'):
        self.model,self.name = model,name

    def get_layer_groups(self, precompute):
        lgs = list(split_by_idxs(children(self.model.rn), [lr_cut]))
        return lgs + [children(self.model)[1:]]

然后我们将把那个 ResNet34 主干转换成一个,我称之为 Unet34。因此,它将保存我们传入的 ResNet,然后我们将使用一个前向钩子,就像以前一样,在第 2、4、5 和 6 个块处保存结果,这些块是每个步幅 2 卷积之前的层。然后我们将创建一堆我们称之为UnetBlock的东西。我们需要告诉UnetBlock有多少东西来自我们要上采样的上一层,有多少来自交叉路径,然后我们想要输出多少。来自上一层的数量完全由基础网络定义——无论下行路径是什么,我们都需要那么多层。这有点尴尬。实际上我们这里的一个硕士学生,Kerem,实际上创建了一个叫做 DynamicUnet 的东西,你可以在fastai.model.DynamicUnet中找到,它实际上为你计算这一切,并自动从你的基础模型创建整个 Unet。它仍然有一些小问题,我想要修复。视频发布时,它肯定会正常工作,我至少会有一个展示如何使用它的笔记本,可能还有一个额外的视频。但现在你只能自己去做。一旦你有了一个 ResNet,你可以输入它的名称,它会打印出层。你可以看到每个块中有多少激活。或者你可以让它自动为每个块打印出来。无论如何,我只是手动做了这个。

所以 UnetBlock 的工作原理是这样的:

  • up_in:从上一层传入的数量

  • x_in:从下行路径传入的数量(因此x

  • n_out:我们想要输出的数量

现在我要做的是,然后我说,好的,我们将从上行路径创建一定数量的卷积,从交叉路径创建一定数量的卷积,所以我将它们连接在一起,所以让我们将我们想要的数量除以 2。因此,我们将让我们的交叉卷积从交叉路径中取出并除以 2(n_out//2)。然后上行路径将是ConvTranspose2d,因为我们想要增加/上采样。同样在这里,我们将我们想要的数量除以 2(up_out),然后最后,我只是将它们连接在一起。

所以我有一个上升样本,我有一个交叉卷积,我可以将这两者连接在一起。这就是 UnetBlock 的全部内容。所以这实际上是一个相当容易创建的模块。

然后在我的前向路径中,我需要将上升路径和交叉路径传递给 UnetBlock 的前向方法。上升路径只是到目前为止的任何事情。但是交叉路径是在下降过程中存储的激活。因此,当我上升时,我首先需要的是最后一组保存的特征。随着我逐渐向上走得更远,最终是第一组特征。

有一些更多的技巧可以让这个变得更好一点,但这已经是一个很好的东西了。所以简单的上采样方法看起来很糟糕,dice 值为 0.968。一个 Unet,除了现在我们有了这些 UnetBlocks 之外,其他一切都相同,dice 值为…

m_base = get_base()
m = to_gpu(Unet34(m_base))
models = UnetModel(m)
learn = ConvLearner(md, models)
learn.opt_fn=optim.Adam
learn.crit=nn.BCEWithLogitsLoss()
learn.metrics=[accuracy_thresh(0.5),dice]
learn.summary()
'''
OrderedDict([('Conv2d-1',
              OrderedDict([('input_shape', [-1, 3, 128, 128]),
                           ('output_shape', [-1, 64, 64, 64]),
                           ('trainable', False),
                           ('nb_params', 9408)])),
             ('BatchNorm2d-2',
              OrderedDict([('input_shape', [-1, 64, 64, 64]),
                           ('output_shape', [-1, 64, 64, 64]),
                           ('trainable', False),
                           ('nb_params', 128)])),
             ('ReLU-3',
              OrderedDict([('input_shape', [-1, 64, 64, 64]),
                           ('output_shape', [-1, 64, 64, 64]),
                           ('nb_params', 0)])),
             ('MaxPool2d-4',
              OrderedDict([('input_shape', [-1, 64, 64, 64]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('nb_params', 0)])),
             ('Conv2d-5',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', 36864)])),
             ('BatchNorm2d-6',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', 128)])),
             ('ReLU-7',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('nb_params', 0)])),
             ('Conv2d-8',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', 36864)])),
             ('BatchNorm2d-9',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', 128)])),
             ('ReLU-10',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('nb_params', 0)])),
             ('BasicBlock-11',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('nb_params', 0)])),
             ('Conv2d-12',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', 36864)])),
             ('BatchNorm2d-13',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', 128)])),
             ('ReLU-14',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('nb_params', 0)])),
             ('Conv2d-15',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', 36864)])),
             ('BatchNorm2d-16',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', 128)])),
             ('ReLU-17',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('nb_params', 0)])),
             ('BasicBlock-18',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('nb_params', 0)])),
             ('Conv2d-19',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', 36864)])),
             ('BatchNorm2d-20',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', 128)])),
             ('ReLU-21',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('nb_params', 0)])),
             ('Conv2d-22',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', 36864)])),
             ('BatchNorm2d-23',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', 128)])),
             ('ReLU-24',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('nb_params', 0)])),
             ('BasicBlock-25',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('nb_params', 0)])),
             ('Conv2d-26',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', 73728)])),
             ('BatchNorm2d-27',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', 256)])),
             ('ReLU-28',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('Conv2d-29',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', 147456)])),
             ('BatchNorm2d-30',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', 256)])),
             ('Conv2d-31',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', 8192)])),
             ('BatchNorm2d-32',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', 256)])),
             ('ReLU-33',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('BasicBlock-34',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('Conv2d-35',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', 147456)])),
             ('BatchNorm2d-36',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', 256)])),
             ('ReLU-37',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('Conv2d-38',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', 147456)])),
             ('BatchNorm2d-39',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', 256)])),
             ('ReLU-40',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('BasicBlock-41',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('Conv2d-42',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', 147456)])),
             ('BatchNorm2d-43',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', 256)])),
             ('ReLU-44',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('Conv2d-45',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', 147456)])),
             ('BatchNorm2d-46',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', 256)])),
             ('ReLU-47',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('BasicBlock-48',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('Conv2d-49',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', 147456)])),
             ('BatchNorm2d-50',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', 256)])),
             ('ReLU-51',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('Conv2d-52',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', 147456)])),
             ('BatchNorm2d-53',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', 256)])),
             ('ReLU-54',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('BasicBlock-55',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('Conv2d-56',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 294912)])),
             ('BatchNorm2d-57',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 512)])),
             ('ReLU-58',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('Conv2d-59',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 589824)])),
             ('BatchNorm2d-60',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 512)])),
             ('Conv2d-61',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 32768)])),
             ('BatchNorm2d-62',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 512)])),
             ('ReLU-63',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('BasicBlock-64',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('Conv2d-65',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 589824)])),
             ('BatchNorm2d-66',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 512)])),
             ('ReLU-67',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('Conv2d-68',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 589824)])),
             ('BatchNorm2d-69',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 512)])),
             ('ReLU-70',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('BasicBlock-71',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('Conv2d-72',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 589824)])),
             ('BatchNorm2d-73',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 512)])),
             ('ReLU-74',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('Conv2d-75',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 589824)])),
             ('BatchNorm2d-76',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 512)])),
             ('ReLU-77',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('BasicBlock-78',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('Conv2d-79',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 589824)])),
             ('BatchNorm2d-80',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 512)])),
             ('ReLU-81',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('Conv2d-82',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 589824)])),
             ('BatchNorm2d-83',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 512)])),
             ('ReLU-84',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('BasicBlock-85',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('Conv2d-86',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 589824)])),
             ('BatchNorm2d-87',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 512)])),
             ('ReLU-88',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('Conv2d-89',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 589824)])),
             ('BatchNorm2d-90',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 512)])),
             ('ReLU-91',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('BasicBlock-92',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('Conv2d-93',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 589824)])),
             ('BatchNorm2d-94',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 512)])),
             ('ReLU-95',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('Conv2d-96',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 589824)])),
             ('BatchNorm2d-97',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 512)])),
             ('ReLU-98',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('BasicBlock-99',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('Conv2d-100',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', 1179648)])),
             ('BatchNorm2d-101',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', 1024)])),
             ('ReLU-102',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('nb_params', 0)])),
             ('Conv2d-103',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', 2359296)])),
             ('BatchNorm2d-104',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', 1024)])),
             ('Conv2d-105',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', 131072)])),
             ('BatchNorm2d-106',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', 1024)])),
             ('ReLU-107',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('nb_params', 0)])),
             ('BasicBlock-108',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('nb_params', 0)])),
             ('Conv2d-109',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', 2359296)])),
             ('BatchNorm2d-110',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', 1024)])),
             ('ReLU-111',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('nb_params', 0)])),
             ('Conv2d-112',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', 2359296)])),
             ('BatchNorm2d-113',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', 1024)])),
             ('ReLU-114',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('nb_params', 0)])),
             ('BasicBlock-115',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('nb_params', 0)])),
             ('Conv2d-116',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', 2359296)])),
             ('BatchNorm2d-117',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', 1024)])),
             ('ReLU-118',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('nb_params', 0)])),
             ('Conv2d-119',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', 2359296)])),
             ('BatchNorm2d-120',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', 1024)])),
             ('ReLU-121',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('nb_params', 0)])),
             ('BasicBlock-122',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('nb_params', 0)])),
             ('ConvTranspose2d-123',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 128, 8, 8]),
                           ('trainable', True),
                           ('nb_params', 262272)])),
             ('Conv2d-124',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 128, 8, 8]),
                           ('trainable', True),
                           ('nb_params', 32896)])),
             ('BatchNorm2d-125',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', True),
                           ('nb_params', 512)])),
             ('UnetBlock-126',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('ConvTranspose2d-127',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', True),
                           ('nb_params', 131200)])),
             ('Conv2d-128',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', True),
                           ('nb_params', 16512)])),
             ('BatchNorm2d-129',
              OrderedDict([('input_shape', [-1, 256, 16, 16]),
                           ('output_shape', [-1, 256, 16, 16]),
                           ('trainable', True),
                           ('nb_params', 512)])),
             ('UnetBlock-130',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 16, 16]),
                           ('nb_params', 0)])),
             ('ConvTranspose2d-131',
              OrderedDict([('input_shape', [-1, 256, 16, 16]),
                           ('output_shape', [-1, 128, 32, 32]),
                           ('trainable', True),
                           ('nb_params', 131200)])),
             ('Conv2d-132',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 128, 32, 32]),
                           ('trainable', True),
                           ('nb_params', 8320)])),
             ('BatchNorm2d-133',
              OrderedDict([('input_shape', [-1, 256, 32, 32]),
                           ('output_shape', [-1, 256, 32, 32]),
                           ('trainable', True),
                           ('nb_params', 512)])),
             ('UnetBlock-134',
              OrderedDict([('input_shape', [-1, 256, 16, 16]),
                           ('output_shape', [-1, 256, 32, 32]),
                           ('nb_params', 0)])),
             ('ConvTranspose2d-135',
              OrderedDict([('input_shape', [-1, 256, 32, 32]),
                           ('output_shape', [-1, 128, 64, 64]),
                           ('trainable', True),
                           ('nb_params', 131200)])),
             ('Conv2d-136',
              OrderedDict([('input_shape', [-1, 64, 64, 64]),
                           ('output_shape', [-1, 128, 64, 64]),
                           ('trainable', True),
                           ('nb_params', 8320)])),
             ('BatchNorm2d-137',
              OrderedDict([('input_shape', [-1, 256, 64, 64]),
                           ('output_shape', [-1, 256, 64, 64]),
                           ('trainable', True),
                           ('nb_params', 512)])),
             ('UnetBlock-138',
              OrderedDict([('input_shape', [-1, 256, 32, 32]),
                           ('output_shape', [-1, 256, 64, 64]),
                           ('nb_params', 0)])),
             ('ConvTranspose2d-139',
              OrderedDict([('input_shape', [-1, 256, 64, 64]),
                           ('output_shape', [-1, 1, 128, 128]),
                           ('trainable', True),
                           ('nb_params', 1025)]))])
'''
[o.features.size() for o in m.sfs]
'''
[torch.Size([3, 64, 64, 64]),
 torch.Size([3, 64, 32, 32]),
 torch.Size([3, 128, 16, 16]),
 torch.Size([3, 256, 8, 8])]
'''
learn.freeze_to(1)learn.lr_find()
learn.sched.plot()
''' 0%|                                                                                           | 0/64 [00:00<?, ?it/s]92%|█████████████████████████████████████████████████████████████████▍     | 59/64 [00:22<00:01,  2.68it/s, loss=2.45]
'''

lr=4e-2
wd=1e-7

lrs = np.array([lr/100,lr/10,lr])
learn.fit(lr,1,wds=wd,cycle_len=8,use_clr=(5,8))
'''
epoch      trn_loss   val_loss   <lambda>   dice           
    0      0.12936    0.03934    0.988571   0.971385  
    1      0.098401   0.039252   0.990438   0.974921        
    2      0.087789   0.02539    0.990961   0.978927        
    3      0.082625   0.027984   0.988483   0.975948        
    4      0.079509   0.025003   0.99171    0.981221        
    5      0.076984   0.022514   0.992462   0.981881        
    6      0.076822   0.023203   0.992484   0.982321        
    7      0.075488   0.021956   0.992327   0.982704
[0.021955982234979434, 0.9923273126284281, 0.9827044502137199]
'''
learn.save('128urn-tmp')
learn.load('128urn-tmp')
learn.unfreeze()
learn.bn_freeze(True)
learn.fit(lrs/4, 1, wds=wd, cycle_len=20,use_clr=(20,10))
'''
0%|          | 0/64 [00:00<?, ?it/s]
epoch      trn_loss   val_loss   <lambda>   dice            
    0      0.073786   0.023418   0.99297    0.98283   
    1      0.073561   0.020853   0.992142   0.982725        
    2      0.075227   0.023357   0.991076   0.980879        
    3      0.074245   0.02352    0.993108   0.983659        
    4      0.073434   0.021508   0.993024   0.983609        
    5      0.073092   0.020956   0.993188   0.983333        
    6      0.073617   0.019666   0.993035   0.984102        
    7      0.072786   0.019844   0.993196   0.98435         
    8      0.072256   0.018479   0.993282   0.984277        
    9      0.072052   0.019479   0.993164   0.984147        
    10     0.071361   0.019402   0.993344   0.984541        
    11     0.070969   0.018904   0.993139   0.984499        
    12     0.071588   0.018027   0.9935     0.984543        
    13     0.070709   0.018345   0.993491   0.98489         
    14     0.072238   0.019096   0.993594   0.984825        
    15     0.071407   0.018967   0.993446   0.984919        
    16     0.071047   0.01966    0.993366   0.984952        
    17     0.072024   0.018133   0.993505   0.98497         
    18     0.071517   0.018464   0.993602   0.985192        
    19     0.070109   0.018337   0.993614   0.9852
[0.018336569653853538, 0.9936137114252362, 0.9852004420189631]
'''

0.985!这就像我们将错误减半,其他一切完全相同。而且更重要的是,你可以看一下。

learn.save('128urn-0')
learn.load('128urn-0')
x,y = next(iter(md.val_dl))
py = to_np(learn.model(V(x)))

与我们的非 Unet 等效物相比,这实际上看起来有点像汽车,后者只是一个斑点。因为试图通过下行和上行路径来做这个——这只是要求太多了。而当我们实际上在每个点提供下行路径像素时,它实际上可以开始创建一些类似汽车的东西。

show_img(py[0]>0);

show_img(y[0]);

最后,我们将执行 m.close 以删除占用 GPU 内存的sfs.features

m.close()

512x512 [1:56:26]

转到较小的批量大小,更高的大小

sz=512
bs=16
tfms = tfms_from_model(
    resnet34, sz, 
    crop_type=CropType.NO, 
    tfm_y=TfmType.CLASS, 
    aug_tfms=aug_tfms
)
datasets = ImageData.get_ds(
    MatchedFilesDataset, 
    (trn_x,trn_y), 
    (val_x,val_y), 
    tfms, 
    path=PATH
)
md = ImageData(PATH, datasets, bs, num_workers=4, classes=None)
denorm = md.trn_ds.denormm_base = get_base()
m = to_gpu(Unet34(m_base))
models = UnetModel(m)
learn = ConvLearner(md, models)
learn.opt_fn=optim.Adam
learn.crit=nn.BCEWithLogitsLoss()
learn.metrics=[accuracy_thresh(0.5),dice]
learn.freeze_to(1)
learn.load('128urn-0')
learn.fit(lr,1,wds=wd, cycle_len=5,use_clr=(5,5))
'''
epoch      trn_loss   val_loss   <lambda>   dice              
    0      0.071421   0.02362    0.996459   0.991772  
    1      0.070373   0.014013   0.996558   0.992602          
    2      0.067895   0.011482   0.996705   0.992883          
    3      0.070653   0.014256   0.996695   0.992771          
    4      0.068621   0.013195   0.996993   0.993359
[0.013194938530288046, 0.996993034604996, 0.993358936574724]
'''

你可以看到 Dice 系数真的在上升[1:56:30]。所以请注意,我正在加载网络的 128x128 版本。我们再次使用渐进式调整大小的技巧,这样我们得到了 0.993。

learn.save('512urn-tmp')
learn.unfreeze()
learn.bn_freeze(True)
learn.load('512urn-tmp')
learn.fit(lrs/4,1,wds=wd, cycle_len=8,use_clr=(20,8))
'''
epoch      trn_loss   val_loss   <lambda>   dice              
    0      0.06605    0.013602   0.997      0.993014  
    1      0.066885   0.011252   0.997248   0.993563          
    2      0.065796   0.009802   0.997223   0.993817          
    3      0.065089   0.009668   0.997296   0.993744          
    4      0.064552   0.011683   0.997269   0.993835          
    5      0.065089   0.010553   0.997415   0.993827          
    6      0.064303   0.009472   0.997431   0.994046          
    7      0.062506   0.009623   0.997441   0.994118
[0.009623114736602894, 0.9974409020136273, 0.9941179137381296]
'''

然后解冻以达到 0.994。

learn.save('512urn')
learn.load('512urn')
x,y = next(iter(md.val_dl))
py = to_np(learn.model(V(x)))

你可以看到,现在看起来很不错。

show_img(py[0]>0);

show_img(y[0]);

m.close()

1024x1024 [1:56:53]

将批量大小降至 4,大小为 1024。

sz=1024
bs=4
tfms = tfms_from_model(
    resnet34, sz, 
    crop_type=CropType.NO, 
    tfm_y=TfmType.CLASS
)
datasets = ImageData.get_ds(
    MatchedFilesDataset, 
    (trn_x,trn_y), 
    (val_x,val_y), 
    tfms, 
    path=PATH
)
md = ImageData(PATH, datasets, bs, num_workers=16, classes=None)
denorm = md.trn_ds.denormm_base = get_base()
m = to_gpu(Unet34(m_base))
models = UnetModel(m)
learn = ConvLearner(md, models)
learn.opt_fn=optim.Adam
learn.crit=nn.BCEWithLogitsLoss()
learn.metrics=[accuracy_thresh(0.5),dice]

加载我们刚刚保存的 512。

learn.load('512urn')
learn.freeze_to(1)
learn.fit(lr,1, wds=wd, cycle_len=2,use_clr=(5,4))
'''
epoch      trn_loss   val_loss   <lambda>   dice                 
    0      0.007656   0.008155   0.997247   0.99353   
    1      0.004706   0.00509    0.998039   0.995437
[0.005090427414942828, 0.9980387706605215, 0.995437301104031]
'''

这让我们达到了 0.995。

learn.save('1024urn-tmp')
learn.load('1024urn-tmp')
learn.unfreeze()
learn.bn_freeze(True)
lrs = np.array([lr/200,lr/30,lr])
learn.fit(lrs/10,1, wds=wd,cycle_len=4,use_clr=(20,8))
'''
epoch      trn_loss   val_loss   <lambda>   dice                 
    0      0.005688   0.006135   0.997616   0.994616  
    1      0.004412   0.005223   0.997983   0.995349             
    2      0.004186   0.004975   0.99806    0.99554              
    3      0.004016   0.004899   0.99812    0.995627
[0.004898778487196458, 0.9981196409180051, 0.9956271404784823]
'''
learn.fit(lrs/10,1, wds=wd,cycle_len=4,use_clr=(20,8))
'''
epoch      trn_loss   val_loss   <lambda>   dice                 
    0      0.004169   0.004962   0.998049   0.995517  
    1      0.004022   0.004595   0.99823    0.995818             
    2      0.003772   0.004497   0.998215   0.995916             
    3      0.003618   0.004435   0.998291   0.995991
[0.004434524739663753, 0.9982911745707194, 0.9959913929776539]
'''

解冻将我们带到...我们将称之为 0.996。

learn.sched.plot_loss()

learn.save('1024urn')
learn.load('1024urn')
x,y = next(iter(md.val_dl))
py = to_np(learn.model(V(x)))

正如你所看到的,实际上看起来很不错[1:57:17]。在准确性方面,99.82%。你可以看到这看起来像是你可以用来裁剪的东西。我认为,在这一点上,我们可以做一些微小的调整来达到 0.997,但真正的关键是,我认为,也许只需要做一些平滑处理或一点后处理。你可以去看看 Carvana 获奖者的博客,看看其中的一些技巧,但正如我所说,我们目前的 0.996 和获奖者得到的 0.997 之间的差距并不大。所以实际上,U-Net 基本上解决了这个问题。

show_img(py[0]>0);

show_img(y[0]);

回到边界框[1:58:15]

好的,就是这样。我想要提到的最后一件事是现在回到边界框,因为你可能还记得,我说我们的边界框模型在小物体上仍然表现不佳。所以希望你能猜到我接下来要做什么,那就是对于边界框模型,记得我们在不同的网格单元中输出了模型的输出。那些较早的具有较小网格大小的输出并不好。我们该如何修复呢?用 U-Net!让我们有一个带有交叉连接的向上路径。然后我们将使用 U-Net,然后从中输出。因为现在那些更精细的网格单元具有该路径的所有信息,以及该路径、该路径和该路径的信息。当然,这是深度学习,这意味着你不能写一篇论文说我们只是用 U-Net 来处理边界框。你必须发明一个新词,所以这被称为特征金字塔网络或 FPNs。这在 RetinaNet 论文中使用过,它是在早期关于 FPNs 的论文中创建的。如果我记得正确的话,他们确实简要引用了 U-Net 论文,但他们似乎让它听起来像是这个模糊地稍微相关的东西,也许有些人可能认为稍微有用。但实际上,FPNs 就是 U-Nets。

我没有实现它来展示给你,但这将是一件有趣的事情,也许对于我们中的一些人来尝试,我知道一些学生一直在尝试在论坛上使其良好运行。所以是的,尝试一下是有趣的事情。所以我认为在这堂课之后要看的一些事情,以及我提到的其他事情,可能是玩玩 FPNs,也可能尝试一下 Kerem 的 DynamicUnet。它们都是值得一看的有趣的东西。

所以你们现在已经经历了我对你们讲解的 14 堂课。对此我感到抱歉。谢谢你们忍受我。我认为你们会发现很难找到其他人对神经网络训练和实践了解得像你们这样多。你们很容易高估其他人的能力,低估自己的能力。所以我想说的是,请继续练习。因为现在没有每个星期一晚上都有我在这里让你们回来了。很容易失去动力。所以找到方法保持下去。组织一个学习小组,一个读书小组,或者和朋友们一起做项目,或者做一些不仅仅是决定我要继续做 X 的事情。除非你是那种超级有动力的人,每当你决定做某事,它就会发生。那不是我。我知道,要让事情发生,我必须说“是的,大卫。十月份,我绝对会教那门课程”,然后我就得开始写一些材料。这是我让事情发生的唯一方法。所以我们在论坛上有一个很棒的社区。如果有人有想法让它变得更好,请告诉我。如果你认为你可以帮忙,如果你想创建一些新的论坛或以某种不同的方式进行管理,或者其他什么的,只要告诉我。你可以随时私信我,GitHub 上也有很多项目正在进行中——很多东西。所以我希望能在其他地方再见到你们,非常感谢你们加入我的旅程。