FFT实现带步长的卷积

156 阅读3分钟

FFT实现卷积的流程

对于不带步长的卷积操作,其实相对也算简单

  1. 对卷积核和卷积输入进行零填充, 填充到2的整数次幂长度
  2. 对卷积核和卷积数据分别进行 FFT
  3. 将卷积核和卷积数据的 FFT 结果逐点相乘
  4. 对乘法的结果进行 IFFT

总的来说,在写好 FFT 函数之后这些都算比较简单的

问题来了。 如果带步长呢?

实际上最简单的方法是: 对于带步长的卷积, 先按步长为1计算完整结果, 再按步长间隔选取即可。但是这种做法的时间复杂度并不会随步长增大而降低。如果步长足够大, 传统卷积很容易比 FFT 卷积更快; FFT 本身更适合步长为1的卷积。

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <stdexcept>
#include <vector>

// The constant pi, used to derive the FFT twiddle-factor angles.
constexpr double Pi = 3.1415926535897932;

// Minimal complex-number value type used by the FFT routines.
// `x` is the real part and `y` the imaginary part; a default-constructed
// value is 0 + 0i.
struct complex{
public:
    double x = 0.0;  // real part
    double y = 0.0;  // imaginary part
public:
    complex() = default;
    complex(double xx, double yy) : x(xx), y(yy) {}
    // Copy/move construction and assignment are intentionally left to the
    // compiler: the type owns no resources, so the defaults are correct.

    // Streams the value as "(re, imi)".
    friend std::ostream& operator<<(std::ostream& os, const complex& c)
    {return os << "(" << c.x << ", " << c.y << "i)";}

    // Arithmetic is const-qualified so it can be used on const operands.
    complex operator +(complex elm) const
    { return complex(x + elm.x, y + elm.y);}
    complex operator -(complex elm) const
    { return complex(x - elm.x, y - elm.y);}
    // (a + bi)(c + di) = (ac - bd) + (ad + bc)i
    complex operator *(complex elm) const
    {return complex(x * elm.x - y * elm.y, x * elm.y + y * elm.x);}
};

// In-place iterative radix-2 Cooley-Tukey FFT over x[0 .. len).
//
//   x   - array of `len` complex samples; overwritten with the transform.
//   len - number of samples. Must be a power of two; len <= 1 is a no-op
//         because a transform of at most one sample is the sample itself.
//   sgn - sign applied to the imaginary part of the twiddle factors, i.e. the
//         butterflies rotate by exp(sgn * i * pi / half_block_size). Passing
//         +1 and -1 gives the two opposite transform directions.
//
// No 1/len scaling is applied in either direction; callers that need a
// normalised round trip must divide by `len` themselves.
//
// Throws std::invalid_argument when len > 1 is not a power of two (the
// butterfly network below is only defined for power-of-two lengths).
inline void FFT(complex* x, int len, int sgn) {
    if (len <= 1) return;
    if ((len & (len - 1)) != 0)
        throw std::invalid_argument("FFT: len must be a power of two");

    // n = log2(len), the number of butterfly stages.
    int n = 0;
    while ((1 << n) < len) n++;

    // Bit-reversal permutation. `rev` is maintained as the bit-reversed value
    // of `i` by incrementing it "from the top bit down", which avoids the
    // temporary index table. Each pair is swapped once (only when i < rev).
    for (int i = 1, rev = 0; i < len; i++) {
        int bit = len >> 1;
        for (; rev & bit; bit >>= 1) rev ^= bit;
        rev ^= bit;
        if (i < rev) std::swap(x[i], x[rev]);
    }

    // Butterfly stages: stage i merges blocks of size 2^i into blocks of 2^(i+1).
    for (int i = 0; i < n; i++) {
        const int half_block_size = 1 << i;
        const int delta_step = half_block_size << 1;
        const double theta = Pi / half_block_size;
        complex Wn(std::cos(theta), sgn * std::sin(theta));  // principal root for this stage
        for (int block = 0; block < len; block += delta_step) {
            complex w(1, 0);  // running power of Wn within the block
            for (int step = 0; step < half_block_size; step++) {
                const int i0 = block + step;
                const int i1 = i0 + half_block_size;
                complex c0 = x[i0];
                complex c1 = w * x[i1];
                x[i0] = c0 + c1;
                x[i1] = c0 - c1;
                w = w * Wn;
            }
        }
    }
}

// Direct (time-domain) strided 1-D convolution.
//
// Writes output[i] = sum over k of input[i*step - k] * kernel[k], for
// i in [0, (input_size - ker_size) / step + 1). Terms whose input index
// falls outside [0, input_size) are treated as zero.
//
//   input      - `input_size` samples (read only)
//   kernel     - `ker_size` taps (read only)
//   output     - receives (input_size - ker_size) / step + 1 values; the
//                caller must provide at least that much room
//   step       - stride between consecutive outputs; must be positive
//
// Throws std::invalid_argument if step <= 0 (the output-size formula would
// otherwise divide by zero or describe a meaningless stride).
inline void conv_1d(const double* input, const double* kernel, double* output, int input_size, int ker_size, int step=1) {
    if (step <= 0)
        throw std::invalid_argument("conv_1d: step must be positive");

    const int output_size = (input_size - ker_size) / step + 1;
    for (int i = 0; i < output_size; i++) {
        const int pos = i * step;
        // Restrict k to the taps that hit a valid input sample, i.e.
        // 0 <= pos - k < input_size, instead of testing inside the loop.
        const int k_begin = std::max(0, pos - input_size + 1);
        const int k_end = std::min(ker_size, pos + 1);
        double sum = 0.0;
        for (int k = k_begin; k < k_end; k++)
            sum += input[pos - k] * kernel[k];
        output[i] = sum;
    }
}

inline void fft_conv_1d(double* input, double* kernel, double* output, int input_size, int ker_size, int step=1) {
    int lim_size = 1;
    int tgt_size = (input_size - ker_size) / step +  1;
    while (input_size + ker_size > lim_size) lim_size <<= 1;
    complex* cp_dat = new complex[lim_size];
    complex* cp_ker = new complex[lim_size];

    for(int i=0;i<input_size;i++)cp_dat[i].x = input[i];
    for(int i=0;i<ker_size;i++)cp_ker[i].x = kernel[i];
    
    FFT(cp_dat, lim_size, 1);
    FFT(cp_ker, lim_size, 1);
    for (int i = 0; i < lim_size; i++) 
        cp_dat[i] = cp_dat[i] * cp_ker[i];
    FFT(cp_dat, lim_size, -1);

    for (int i = 0; i < tgt_size; i++) 
        output[i] = cp_dat[i * step].x / lim_size;
    
    delete[] cp_dat;
    delete[] cp_ker;
}




// Benchmark / self-check: runs fft_conv_1d and conv_1d on the same random
// data, prints the CPU time each one took in milliseconds, and prints how
// many output samples differ by more than `eps`.
void conv_test(){
    const int dat_size = 64000;
    const int ker_size = 8000;
    const int step = 2;
    const int tgt_size = (dat_size - ker_size) / step + 1;
    const double eps = 1e-4;

    // Vectors own the buffers, so nothing leaks if a later step throws.
    std::vector<double> dat(dat_size);
    std::vector<double> ker(ker_size);
    std::vector<double> tgt0(tgt_size);
    std::vector<double> tgt1(tgt_size);

    // Random values in {0.0, 0.1, ..., 9.9}.
    std::srand(static_cast<unsigned int>(std::time(nullptr)));
    for (double& v : dat) v = (std::rand() % 100) / 10.0;
    for (double& v : ker) v = (std::rand() % 100) / 10.0;

    // clock() counts implementation-defined ticks; convert through
    // CLOCKS_PER_SEC so the printed figure really is milliseconds.
    const auto elapsed_ms = [](std::clock_t begin, std::clock_t end) {
        return static_cast<int>((end - begin) * 1000.0 / CLOCKS_PER_SEC);
    };

    std::clock_t start_time = std::clock();
    fft_conv_1d(dat.data(), ker.data(), tgt0.data(), dat_size, ker_size, step);
    std::clock_t end_time = std::clock();
    std::printf("fft_conv_1d using %d ms\n", elapsed_ms(start_time, end_time));

    start_time = std::clock();
    conv_1d(dat.data(), ker.data(), tgt1.data(), dat_size, ker_size, step);
    end_time = std::clock();
    std::printf("conv_1d using %d ms\n", elapsed_ms(start_time, end_time));

    // Count outputs where the two implementations disagree beyond eps.
    int cnt = 0;
    for (int i = 0; i < tgt_size; i++) {
        if (std::fabs(tgt0[i] - tgt1[i]) > eps)
            cnt++;
    }
    std::printf("over eps count: %d\n", cnt);
}
// Program entry point: runs the convolution benchmark/self-check once.
int main()
{
    conv_test();
    return 0;
}

使用上述参数在本机测试的结果如下:

fft_conv_1d using 67 ms
conv_1d using 363 ms
over eps count: 0

对于二维的卷积,可以拆分卷积核然后循环使用一维的卷积