1. 数学原理
欧拉公式:
则 ,
如果从多项式乘法的角度理解,由多项式分解规律可得
2. 代码实现
1. 递归版本的fft
可以得到递归版本的fft
void fft(complex*x, int size, int sgn){
if(size <= 1){
return;
}
int half_size = size / 2;
complex* a0 = new complex[half_size];
complex* a1 = new complex[half_size];
for(int i=0;i<half_size;i++){
a0[i] = x[2*i];
a1[i] = x[2*i+1];
}
fft(a0, half_size, sgn);
fft(a1, half_size, sgn);
double theta = 2*Pi/size;
complex Wn = complex(cos(theta), sgn * sin(theta));
complex w = complex(1.0, 0);
for(int i=0;i<half_size;i++){
x[i] = a0[i] + w * a1[i];
x[i+half_size] = a0[i] - w * a1[i];
w = w * Wn;
}
delete[] a0;
delete[] a1;
}
2. 非递归版本的fft
如果我们要得到非递归版本的,其实有一定程度上的问题
由于每次都是递归选取偶数位置,这里的位置是有一定规律的, 比如一个八个元素的fft顺序变换如下
01234567
0246 | 1357
04 | 26 | 15 | 37
0 | 4 | 2 | 6 | 1 | 5 | 3 | 7
000 000 0
001 100 4
010 010 2
011 110 6
100 001 1
101 101 5
110 011 3
111 111 7
如何求得这个序列呢,这个东西其实就是二进制编码的翻转形式 , 通过下面的代码可以实现
bit[0] = 0;
int n = log2(size);
for (int i = 0; i < size; i++) {
bit[i] = (bit[i>>1]>>1) | ((i & 1) << (n - 1));
}
(i & 1) << (n - 1)) 其实是把i最后一位移动到首位
在求得这个序列之后,可以知道二进制编码相反的数字在一个交换环上,就可以交换得到打乱的顺序。通过这些操作之后可以得到一个非递归版本的 函数
inline void FFT(complex* x, int size, int sgn) {
int n = 0;
while ((1 << n) < size)n++;
int* bit = new int[size];
bit[0] = 0;
for (int i = 0; i < size; i++) {
bit[i] = (bit[i>>1]>>1) | ((i & 1) << (n - 1));
if(i < bit[i])std::swap(x[i], x[bit[i]]);
}
for (int i = 0; i < n; i++) {
int half_block_size = 1 << i;
int delta_step = half_block_size << 1;
double theta = Pi / half_block_size;
complex Wn(cos(theta), sgn * sin(theta));
for (int block = 0; block < size; block += delta_step) {
complex w(1, 0);
for (int step = 0; step < half_block_size; step++) {
int i0 = block + step;
int i1 = block + step + half_block_size;
complex c0 = x[i0];
complex c1 = w * x[i1];
x[i0] = c0 + c1;
x[i1] = c0 - c1;
w = w * Wn;
}
}
}
delete[] bit;
}
3. 对比校验
二者对比的代码
#include<iostream>
#include<algorithm>
#include<cmath>
const double Pi=3.1415926535897932;
// 实现一个复数类
struct complex{
public:
double x;
double y;
public:
complex(){x = y = 0;}
complex(double xx, double yy){x = xx, y = yy;}
complex(const complex& elm){x = elm.x, y = elm.y;};
friend std::ostream& operator<<(std::ostream& os, const complex& c)
{return os << "(" << c.x << ", " << c.y << "i)";}
complex& operator=(const complex& elm)
{x = elm.x, y = elm.y; return *this;}
complex& operator=(const complex&& elm)
{x = elm.x, y = elm.y; return *this;}
complex operator +(complex elm)
{ return complex(x + elm.x, y + elm.y);}
complex operator -(complex elm)
{ return complex(x - elm.x, y - elm.y);}
complex operator *(complex elm)
{return complex(x * elm.x - y * elm.y, x * elm.y + y * elm.x);}
complex operator /(double elm)
{return complex(x/elm, y/elm);}
};
inline void FFT(complex* x, int size, int sgn) {
int n = 0;
while ((1 << n) < size)n++;
int* bit = new int[size];
bit[0] = 0;
for (int i = 0; i < size; i++) {
bit[i] = (bit[i>>1]>>1) | ((i & 1) << (n - 1));
if(i < bit[i])std::swap(x[i], x[bit[i]]);
}
for (int i = 0; i < n; i++) {
int half_block_size = 1 << i;
int delta_step = half_block_size << 1;
double theta = Pi / half_block_size;
complex Wn(cos(theta), sgn * sin(theta));
for (int block = 0; block < size; block += delta_step) {
complex w(1, 0);
for (int step = 0; step < half_block_size; step++) {
int i0 = block + step;
int i1 = block + step + half_block_size;
complex c0 = x[i0];
complex c1 = w * x[i1];
x[i0] = c0 + c1;
x[i1] = c0 - c1;
w = w * Wn;
}
}
}
delete[] bit;
}
void fft(complex*x, int size, int sgn){
if(size <= 1){
return;
}
int half_size = size / 2;
complex* a0 = new complex[half_size];
complex* a1 = new complex[half_size];
for(int i=0;i<half_size;i++){
a0[i] = x[2*i];
a1[i] = x[2*i+1];
}
fft(a0, half_size, sgn);
fft(a1, half_size, sgn);
double theta = 2*Pi/size;
complex Wn = complex(cos(theta), sgn * sin(theta));
complex w = complex(1.0, 0);
for(int i=0;i<half_size;i++){
x[i] = a0[i] + w * a1[i];
x[i+half_size] = a0[i] - w * a1[i];
w = w * Wn;
}
delete[] a0;
delete[] a1;
}
int main(){
const int size = 1<<4;
// 限定为 2^n 次方大小
complex a[size], b[size];
for(int i=0;i<size;i++) a[i] = complex(i , 0), b[i] = complex(i, 0);
FFT(a, size, 1);
fft(b, size, 1);
std::cout<<"fft result:\n";
for(int i=0;i<size;i++)
std::cout<<a[i]<<"\t"<<b[i]<<std::endl;
FFT(a, size, -1);
fft(b, size, -1);
std::cout<<"\n\nifft result:\n";
for(int i=0;i<size;i++)
std::cout<<a[i] / size <<"\t"<<b[i] / size <<std::endl;
}