C++ 实现FFT算法

365 阅读4分钟

1. 数学原理

欧拉公式: ωnk=cos2kπn+isin2kπn\omega_{n}^{k}=\cos\frac{2k\pi}{n}+i\sin\frac{2k\pi}{n}

ωnn=ωn0=1\omega^n_n = \omega^0_n =1, wnn2=1w^{\frac{n}{2}}_n = -1

如果从多项式乘法的角度理解,由多项式分解规律可得

A(x)=A1(x2)+xA2(x2)A(x) = A_1(x^2) + xA_2(x^2)

A(ωnk)=A1(wn2k)+ωnkA2(wn2k)A(\omega^k_n) = A_1(w^{2k}_n) + \omega^k_nA_2(w^{2k}_n)

A(ωnk+n2)=A1(wn2k+n)+ωnk+n2A2(wn2k+n)A(\omega^{k+\frac{n}{2}}_n) = A_1(w^{2k+n}_n) + \omega^{k+\frac{n}{2}}_nA_2(w^{2k+n}_n)

=A1(wn2k+n)ωnkA2(wn2k+n)= A_1(w^{2k+n}_n) - \omega^k_n A_2(w^{2k+n}_n)

=A1(wn2k)ωnkA2(wn2k)= A_1(w^{2k}_n) - \omega^k_n A_2(w^{2k}_n)

2. 代码实现

1. 递归版本的fft

可以得到递归版本的fft

void fft(complex*x, int size, int sgn){
    if(size <= 1){
        return;
    }
    int half_size = size / 2;
    complex* a0 = new complex[half_size];
    complex* a1 = new complex[half_size];
    for(int i=0;i<half_size;i++){
        a0[i] = x[2*i];
        a1[i] = x[2*i+1];
    }
    fft(a0, half_size, sgn);
    fft(a1, half_size, sgn);
    double theta = 2*Pi/size;
    complex Wn = complex(cos(theta),  sgn * sin(theta));
    complex w = complex(1.0, 0);
    for(int i=0;i<half_size;i++){
        x[i] = a0[i] + w * a1[i];
        x[i+half_size] = a0[i] - w * a1[i];
        w = w * Wn;
    }
    delete[] a0;
    delete[] a1;
}
2. 非递归版本的fft

如果我们要得到非递归版本的,其实有一定程度上的问题

由于每次都是递归选取偶数位置,这里的位置是有一定规律的, 比如一个八个元素的fft顺序变换如下

01234567
0246 | 1357
04 | 26 | 15 | 37
0 | 4 | 2 | 6 | 1 | 5 | 3 | 7

000  000 0
001  100 4
010  010 2
011  110 6
100  001 1
101  101 5
110  011 3
111  111 7

如何求得这个序列呢,这个东西其实就是二进制编码的翻转形式 , 通过下面的代码可以实现

bit[0] = 0; 
int n = log2(size);
for (int i = 0; i < size; i++) {
    bit[i] = (bit[i>>1]>>1) | ((i & 1) << (n - 1));
}

(i & 1) << (n - 1)) 其实是把i最后一位移动到首位

在求得这个序列之后,可以知道二进制编码相反的数字在一个交换环上,就可以交换得到打乱的顺序。通过这些操作之后可以得到一个非递归版本的FFTIFFTFFT-IFFT 函数

inline void FFT(complex* x, int size, int sgn) {
    int n = 0;
    while ((1 << n) < size)n++;
    int* bit = new int[size];
    bit[0] = 0; 
    for (int i = 0; i < size; i++) {
        bit[i] = (bit[i>>1]>>1) | ((i & 1) << (n - 1));
        if(i < bit[i])std::swap(x[i], x[bit[i]]);
    }
    for (int i = 0; i < n; i++) {
        int half_block_size = 1 << i;
        int delta_step = half_block_size << 1;
        double theta = Pi /  half_block_size;
        complex Wn(cos(theta), sgn * sin(theta));
        for (int block = 0; block < size; block += delta_step) {
            complex w(1, 0);
            for (int step = 0; step < half_block_size; step++) {
                int i0 = block + step;
                int i1 = block + step + half_block_size;
                complex c0 = x[i0];
                complex c1 = w * x[i1];
                x[i0] = c0 + c1;
                x[i1] = c0 - c1;
                w = w * Wn;
            }
        }
    }
    delete[] bit;
}
3. 对比校验

二者对比的代码

#include<iostream>
#include<algorithm>
#include<cmath>

const double Pi=3.1415926535897932;

// 实现一个复数类
struct complex{
public:
    double x;
    double y;
public:
    complex(){x = y = 0;}
    complex(double xx, double yy){x = xx, y = yy;}
    complex(const complex& elm){x = elm.x, y = elm.y;};
    friend std::ostream& operator<<(std::ostream& os, const complex& c) 
    {return os << "(" << c.x << ", " << c.y << "i)";}
    complex& operator=(const complex& elm)
    {x = elm.x, y = elm.y; return *this;}
    complex& operator=(const complex&& elm)
    {x = elm.x, y = elm.y; return *this;}
    complex operator +(complex elm)
    { return complex(x + elm.x, y + elm.y);}
    complex operator -(complex elm)
    { return complex(x - elm.x, y - elm.y);}
    complex operator *(complex elm)
    {return complex(x * elm.x - y * elm.y, x * elm.y + y * elm.x);}
    complex operator /(double elm)
    {return complex(x/elm, y/elm);}
};

inline void FFT(complex* x, int size, int sgn) {
    int n = 0;
    while ((1 << n) < size)n++;
    int* bit = new int[size];
    bit[0] = 0; 
    for (int i = 0; i < size; i++) {
        bit[i] = (bit[i>>1]>>1) | ((i & 1) << (n - 1));
        if(i < bit[i])std::swap(x[i], x[bit[i]]);
    }
    for (int i = 0; i < n; i++) {
        int half_block_size = 1 << i;
        int delta_step = half_block_size << 1;
        double theta = Pi /  half_block_size;
        complex Wn(cos(theta), sgn * sin(theta));
        for (int block = 0; block < size; block += delta_step) {
            complex w(1, 0);
            for (int step = 0; step < half_block_size; step++) {
                int i0 = block + step;
                int i1 = block + step + half_block_size;
                complex c0 = x[i0];
                complex c1 = w * x[i1];
                x[i0] = c0 + c1;
                x[i1] = c0 - c1;
                w = w * Wn;
            }
        }
    }
    delete[] bit;
}

void fft(complex*x, int size, int sgn){
    if(size <= 1){
        return;
    }
    int half_size = size / 2;
    complex* a0 = new complex[half_size];
    complex* a1 = new complex[half_size];
    for(int i=0;i<half_size;i++){
        a0[i] = x[2*i];
        a1[i] = x[2*i+1];
    }
    fft(a0, half_size, sgn);
    fft(a1, half_size, sgn);
    double theta = 2*Pi/size;
    complex Wn = complex(cos(theta),  sgn * sin(theta));
    complex w = complex(1.0, 0);
    for(int i=0;i<half_size;i++){
        x[i] = a0[i] + w * a1[i];
        x[i+half_size] = a0[i] - w * a1[i];
        w = w * Wn;
    }
    delete[] a0;
    delete[] a1;
}

int main(){ 
    const int size = 1<<4;
    // 限定为 2^n 次方大小
    complex a[size], b[size];
    for(int i=0;i<size;i++) a[i] = complex(i , 0), b[i] = complex(i, 0);
    FFT(a, size, 1);
    fft(b, size, 1);
    std::cout<<"fft result:\n";
    for(int i=0;i<size;i++)
        std::cout<<a[i]<<"\t"<<b[i]<<std::endl;
    FFT(a, size, -1);
    fft(b, size, -1);
    std::cout<<"\n\nifft result:\n";
    for(int i=0;i<size;i++)
        std::cout<<a[i] / size <<"\t"<<b[i] / size <<std::endl;
}