从DFT到离散余弦变换DCT,以及DCT的PyTorch实现

282 阅读2分钟

总览

离散余弦变换,DCT。DCT 在图像视频音频压缩领域用得比 FFT 更多。由于其能量更集中的特性,适合拿来裁剪不重要的信息。

本文主要通过 DFT 到 DCT 的推导加深对 DCT 的理解。

最后给出了 DCT-II 的公式,以及一个 PyTorch 代码实现。

离散傅里叶变换 DFT

傅里叶变换。

X(ω)=x(t)ejωtdtX(\omega)= \int^\infin_{-\infin}x(t)·e^{-j\omega t}\text{d}t

离散傅里叶变换。

X[k]=n=0N1x[n]ej(2πk/N)nX[k]=\sum^{N-1}_{n=0}x[n]e^{-j(2\pi k/N)·n}

离散傅里叶变换借助欧拉公式,展开为两项。

X[k]=n=0N1x[n](cos2πknNjsin2πknN)X[k]=\sum^{N-1}_{n=0}x[n] \left( \cos\frac{2\pi kn}{N}- j\sin\frac{2\pi kn}{N} \right)

记实部和虚部:

Re[k]=n=0N1x[n]cos2πknNIm[k]=n=0N1x[n]sin2πknN\begin{aligned} &\text{Re}[k]=\sum^{N-1}_{n=0}x[n]\cos\frac{2\pi kn}{N}\\ &\text{Im}[k]=\sum^{N-1}_{n=0}x[n]\sin\frac{2\pi kn}{N} \end{aligned}

从 DFT 到 DCT

离散傅里叶变换中有两个特点很明了:实部 Re[k]\text{Re}[k] 项是偶函数,虚部 Im[k]\text{Im}[k] 项是奇函数。所以,若原信号 x[n]x[n]全实数的偶函数信号,那么虚部 Im[k]\text{Im}[k] 是可以省略的,如此便能抛弃虚部,能简化很多问题。

离散余弦变换就是利用了这个特性。虽然不能不负责任地假设一切实信号都是偶函数的,但可以通过延拓来创造一个偶函数。

如何构造?

设现有一长度为 NN 的实数离散信号 x[n]x[n],其中 0nN10\le n\le N-1。现构造一个 x[m]x'[m],满足

x[m]={x[m]0mN1x[m1]Nm1x'[m]= \begin{cases} x[m]& \quad 0\le m\le N-1\\ x[-m-1]& -N\le m\le -1 \end{cases}

这个新信号关于 m=0.5m=-0.5 对称。只要往右平移 0.5,就是一个标准的实偶信号。借此写出该信号的傅里叶变换式,并使用对称性质和 n=m0.5n=m-0.5 进行变换和化简。

X[k]=m=N+0.5N0.5x[m0.5]ej(2πk/2N)m=2n=0N0.5x[m0.5]ej(πk/N)m=2n=0N1x[n]ej(πk/N)(n+0.5)\begin{aligned} X[k] &=\sum^{N-0.5}_{m=-N+0.5}x'[m-0.5]e^{-j(2\pi k/2N)·m}\\ &=2·\sum^{N-0.5}_{n=0}x'[m-0.5]e^{-j(\pi k/N)·m}\\ &=2·\sum^{N-1}_{n=0}x'[n]e^{-j(\pi k/N)(n+0.5)}\\ \end{aligned}

没有了虚数项 Im[k]\text{Im}[k] 的负担,式子的变换非常轻松。现在只保留实部,写出最终的式子:

X[k]=2n=0N1x[n]cosπk(n+0.5)NX[k]=2·\sum^{N-1}_{n=0}x[n]\cos\frac{\pi k(n+0.5)}{N}

离散余弦变换 DCT

最常用的 DCT-II 形式如下。

Xk=s(k)n=0N1xncosπk(n+0.5)NX_k=s(k)\sum^{N-1}_{n=0}x_n\cos\frac{\pi k(n+0.5)}{N}

其中 s(0)=1Ns(0)=\sqrt{\frac{1}{N}},其他情况 s(k)=2Ns(k)=\sqrt{\frac{2}{N}}。这样做主要是因为在工程应用上让矩阵正交。

2d-DCT 与代码实现

PyTorch 有原生的 FFT 实现但没有 DCT 的。在看论文 (2025) Mesoscopic Insights: Orchestrating Multi-scale & Hybrid Architecture for Image Manipulation Localization 的对应代码时看到的一段 2d-DCT 实现。大致如下。

class DCT(nn.Module):
    def __init__(self):
        super(self).__init__()
        self.dct_matrix_h = None
        self.dct_matrix_w = None

    def create_dct_matrix(self, N):
        n = torch.arange(N, dtype=torch.float32).reshape((1, N))
        k = torch.arange(N, dtype=torch.float32).reshape((N, 1))
        dct_matrix = torch.sqrt(torch.tensor(2.0 / N)) * torch.cos(math.pi * k * (2 * n + 1) / (2 * N))
        dct_matrix[0, :] = 1 / math.sqrt(N)
        return dct_matrix

    def dct_2d(self, x):
        H, W = x.size(-2), x.size(-1)
        if self.dct_matrix_h is None or self.dct_matrix_h.size(0) != H:
            self.dct_matrix_h = self.create_dct_matrix(H).to(x.device)
        if self.dct_matrix_w is None or self.dct_matrix_w.size(0) != W:
            self.dct_matrix_w = self.create_dct_matrix(W).to(x.device)
        
        return torch.matmul(self.dct_matrix_h, torch.matmul(x, self.dct_matrix_w.t()))

    def idct_2d(self, x):
        H, W = x.size(-2), x.size(-1)
        if self.dct_matrix_h is None or self.dct_matrix_h.size(0) != H:
            self.dct_matrix_h = self.create_dct_matrix(H).to(x.device)
        if self.dct_matrix_w is None or self.dct_matrix_w.size(0) != W:
            self.dct_matrix_w = self.create_dct_matrix(W).to(x.device)
        
        return torch.matmul(self.dct_matrix_h.t(), torch.matmul(x, self.dct_matrix_w))

挺直接的两次矩阵乘法。

参考来源