深度视觉中有关图像projection的代码改写cv2.remap() → F.grid_sample() | Numpy+cv2格式改为PyTorch格式

351 阅读1分钟

Cv2 remap in pytorch?

Numpy+cv2实现的代码迁移到PyTorch上往往不怎么需要改动,直接把np换成torch即可,但cv2.remap()函数是个特殊例子,该函数通过xy两个数组重新采样图像,可以用来实现投影变换(warp,projection),在torch中与之对应的是torch.nn.functional.grid_sample() 函数,但用法上有着一些不同。

以下以我的一个实际代码片段例子来直观介绍torch版本的代码重写。

我的任务是将ref视点的ref_img and ref_depth 投影到另一个src视点。

Numpycv2 风格的代码:

def reproject_with_depth(img_ref, depth_ref, intrinsics_ref, extrinsics_ref, intrinsics_src, extrinsics_src):
    width, height = depth_ref.shape[1], depth_ref.shape[0]

    x_ref, y_ref = np.meshgrid(np.arange(0, width), np.arange(0, height))
    x_ref, y_ref = x_ref.reshape([-1]), y_ref.reshape([-1])

    xyz_ref = np.matmul(np.linalg.inv(intrinsics_src),
                        np.vstack((x_ref, y_ref, np.ones_like(x_ref))) * depth_ref.reshape([-1]))

    xyz_src = np.matmul(np.matmul(extrinsics_ref, np.linalg.inv(extrinsics_src)),
                        np.vstack((xyz_ref, np.ones_like(x_ref))))[:3]

    K_xyz_src = np.matmul(intrinsics_ref, xyz_src)
    xy_src = K_xyz_src[:2] / K_xyz_src[2:3]

    x_src = xy_src[0].reshape([height, width]).astype(np.float32)
    y_src = xy_src[1].reshape([height, width]).astype(np.float32)

    sampled_depth_src = cv2.remap(depth_ref, x_src, y_src, interpolation=cv2.INTER_LINEAR)
    sampled_img_src = cv2.remap(img_ref, x_src, y_src, interpolation=cv2.INTER_LINEAR)

    return sampled_depth_src, sampled_img_src

翻译成 torch 风格后的代码:

def reproject_with_depth(img_ref, depth_ref, intrinsics_ref, extrinsics_ref, intrinsics_src, extrinsics_src):
    B, width, height = depth_ref.shape[0], depth_ref.shape[2], depth_ref.shape[1]

    y_ref, x_ref = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=depth_ref.device), torch.arange(0, width, dtype=torch.float32, device=depth_ref.device)])
    y_ref, x_ref = y_ref.contiguous(), x_ref.contiguous()
    x_ref, y_ref = x_ref.reshape([-1]), y_ref.reshape([-1])
    # reference 3D space
    xyz_ref = torch.matmul(torch.inverse(intrinsics_src),
                        torch.stack((x_ref, y_ref, torch.ones_like(x_ref))).unsqueeze(0).repeat(B, 1, 1) * depth_ref.reshape([B, 1, -1]))

    xyz_src = torch.matmul(torch.matmul(extrinsics_ref, torch.inverse(extrinsics_src)),
                        torch.cat([xyz_ref, torch.ones_like(x_ref).unsqueeze(0).repeat(B,1,1)], dim=1))[:,:3]

    K_xyz_src = torch.matmul(intrinsics_ref, xyz_src)
    xy_src = K_xyz_src[:, :2] / K_xyz_src[:, 2:3]

    x_src = xy_src[:, 0].reshape([B, height, width]).float()
    y_src = xy_src[:, 1].reshape([B, height, width]).float()

    grid = torch.stack((x_src/((width-1)/2)-1, y_src/((height-1)/2)-1), dim=3)
    sampled_depth_src = F.grid_sample(depth_ref.unsqueeze(1), grid.view(B, height, width, 2), mode='bilinear', padding_mode='zeros').squeeze(1)
    sampled_img_src = F.grid_sample(img_ref, grid.view(B, height, width, 2), mode='bilinear', padding_mode='zeros')

    return sampled_depth_src, sampled_img_src

一些核心翻译原则需要遵守的(可能会给你带来困惑的)是:

  • torch版的代码要考虑 batch B这个维度,因此诸如切片等操作记得先把0-dim考虑进去,例如 [:, idx]
  • torch版代码需要考虑device ,一般可以另新创建的tensor的dtype为输入参数的dtype
  • F.grid_sample()的坐标采样范围是[-1, 1],而cv2.remap() 直接使用的是像素坐标尺度,因此需要在x轴/((width-1)/2)-1,在y轴/((height-1)/2)-1) 来缩放坐标系