FFT 笔记
参考:yyc dalao 的讲课 ppt,wzr dalao 的博客。
一、前置知识
1. 多项式
就是初中课本上说的多项式,但是只有“一元”,如
那么对于一个
下文无特殊说明,均用小写字母表示一个对应大写字母多项式的系数
2. 卷积
设
则称为“
如我们熟知的两个多项式相乘,就是一种“加卷积”。
如
,设前面是 ,后面是 ,则 。 求得
,即 ,上述式子也恰好是 。
3. 点值 / 插值
点值就是多项式取特值时的值。
由初中函数知识,可以知道
4. 复数
(1). 定义
定义
称
虚数的一般式:
(2). 复平面
将
复平面上一个点对应一个虚数。
连接原点
线段与
有欧拉公式
证明:
(3). 四则运算
和实数四则运算一致,将
若有
对于乘法,有棣莫弗定理
- 两复数相乘,模长相乘,幅角相加。
证明不高深,就是暴力拆式子再合回去,不写了。
5. 弧度制
圆心角(可以大于
易得
弧度制可以直接上
6. 单位根
(1). 定义
若
不要看串了,这里都是指的
(2). 构造
由于虚数相乘幅角相加,所以一个简单的想法是在复平面上以原点为圆心画个半径为
这个圆的周长是
我也不知道哪里画复平面啊,直接偷 yyc 的图了:

(3). 性质
结合上图可知:
。 。
这样的特殊性将有利于我们做 FFT。
二、FFT
1. 简述
就是求
由点值和插值的知识,可以从
所以现在问题变成了:快速的求点值(DFT),快速的插值(IDFT)。
2. DFT
对于
你可以发现
接着考虑把我们的单位根知识喂进去,根据学长的笔记,如果将
代入
有个小 bug,
对于填进
里的,就相当于 ,后面的除数就是 (前面讲过),故不影响。 而对于
的系数,此时就不能是 了,而应该是 ,即 ,那显然就是符号从 。
注意到
所以就可以愉快的分治了!!!
但是这不是线段树,不能瞎分,所以得把
3. IDFT
我这个水平推 ** 呢,感兴趣的移步 P3803 题解区,出结论:
对点积做一次 DFT,然后把
4. 优化
首先经典优化:用 STL 小心被卡常,手写复数类。
递归的时候定义太多数组,时空肯定不优。
直接拿出观察力:
要求
求
。 - 求
:求 ,再求 。 - 求
:求 ,再求 。
- 求
求
。 - 求
,求 ,再求 。 - 求
,求 ,再求 。
- 求
顺序为:
对应二进制:
而原序列的:
所以就是二进制颠倒一下就能递推了!
二进制颠倒的式子递推求:假设我们知道当前二进制数,去掉最后一位的颠倒数,我们将这个颠倒数右移一位,把开头设成当前二进制数的最后一位就行了。说的有点抽象,上代码!
#include <bits/stdc++.h>
#define int long long
#define IOS ios::sync_with_stdio(0), cin.tie(0), cout.tie(0)
using namespace std;
const int N = 4e6 + 10;
const double pi = acos(-1);
int n, m, r[N];
struct Complex {
double x, y;
const Complex operator+(const Complex &tmp) const {
return {x + tmp.x, y + tmp.y};
}
const Complex operator-(const Complex &tmp) const {
return {x - tmp.x, y - tmp.y};
}
const Complex operator*(const Complex &tmp) const {
return {x * tmp.x - y * tmp.y, x * tmp.y + tmp.x * y};
}
} a[N], b[N], c[N], p[N];
void fft(int mx, Complex *a) {
for (int i = 0; i < 1 << mx; i++)
if (i < r[i])
swap(a[i], a[r[i]]);
for (int i = 0; i < mx; i++) {
Complex w = {cos(pi / (1 << i)), sin(pi / (1 << i))};
p[0] = {1, 0};
for (int j = 1; j < 1 << i; j++)
p[j] = p[j - 1] * w;
for (int j = 0; j < 1 << mx; j++) {
if (!(j & (1 << i))) {
Complex a1 = a[j], a2 = a[j + (1 << i)], tmp = a2 * p[j & ((1 << i) - 1)];
a[j] = a1 + tmp;
a[j + (1 << i)] = a1 - tmp;
}
}
}
}
signed main() {
IOS;
cin >> n >> m;
for (int i = 0; i <= n; i++)
cin >> a[i].x;
for (int i = 0; i <= m; i++)
cin >> b[i].x;
int mx = 1;
while ((1 << mx) <= n + m)
mx++;
for (int i = 1; i < 1 << mx; i++)
r[i] = (r[i >> 1] >> 1) | ((i & 1) << (mx - 1)); // 颠倒
fft(mx, a);
fft(mx, b);
for (int i = 0; i < 1 << mx; i++)
c[i] = a[i] * b[i];
fft(mx, c);
reverse(c + 1, c + (1 << mx));
for (int i = 0; i <= n + m; i++)
cout << (int)(c[i].x / (1 << mx) + 0.5) << " "; // 向上取整防止精度问题
return 0;
}