PyTorch实现DepthToSpace算子

582 阅读1分钟

PyTorch实现DepthToSpace算子

DCR模式

import torch
def depth_to_space_dcr(input_tensor, block_size):
    # 输入张量的形状应为 (N, C, H, W)
    N, C, H, W = input_tensor.shape

    # 计算输出的通道数
    output_channels = C // (block_size ** 2)

    # 重塑输入张量
    input_reshaped = input_tensor.view(N, block_size, block_size, output_channels, H, W)

    # 转置维度
    transposed = input_reshaped.permute(0, 3, 4, 1, 5, 2)

    # 合并维度
    output_tensor = transposed.contiguous().view(N, output_channels, H * block_size, W * block_size)

    return output_tensor

CRD模式

import torch
def depth_to_space_crd(input_tensor, block_size):
    # 输入张量的形状应为 (N, C, H, W)
    N, C, H, W = input_tensor.shape

    # 计算输出的通道数
    output_channels = C // (block_size ** 2)

    # 重塑输入张量
    input_reshaped = input_tensor.view(N, output_channels, block_size, block_size, H, W)

    # 转置维度,将列和行放在前面
    transposed = input_reshaped.permute(0, 1, 4, 2, 5, 3)

    # 合并维度
    output_tensor = transposed.contiguous().view(N, output_channels, H * block_size, W * block_size)

    return output_tensor

二者区别主要在重塑部分不一致