FFT & NTT

  1. 1 UOJ #34. 多项式乘法
  2. 2 bzoj 2179: FFT快速傅立叶
  3. 3 bzoj 2194: 快速傅立叶之二
  4. 4 bzoj 3527: [Zjoi2014]力
  5. 5 SPOJ TSUM
  6. 6 bzoj 3771

pkusc第一天就有FFT可是我不会做呢
Let's Orz Menci
从多项式乘法到快速傅里叶变换 by Miskcoo(里面有NTT)

1 UOJ #34. 多项式乘法

第一行两个整数 nnmm ,分别表示两个多项式的次数。 第二行 n+1n+1 个整数,分别表示第一个多项式的 00nn 次项前的系数。 第三行 m+1m+1 个整数,分别表示第一个多项式的 00mm 次项前的系数。 数据范围:0n,m105 0 \leq n,m \leq 10^5 .

FFT:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
#include <complex>
#include <cstdio>
const double pi=acos(-1);
int len1,len2,n,m,rev[262200];
std::complex<double> a[262200],b[262200];
inline void FFT(std::complex<double> *a,int f)
{
for(int i=0;i<n;i++) if(i<rev[i]) std::swap(a[i],a[rev[i]]);
for(int i=1;i<n;i<<=1)
{
std::complex<double> wn(cos(pi/i),f*sin(pi/i));
for(int j=0;j<n;j+=(i<<1))
{
std::complex<double> w(1,0);
for(int k=0;k<i;k++,w*=wn)
{
std::complex<double> x=a[j+k],y=a[j+k+i]*w;
a[j+k]=x+y;a[j+k+i]=x-y;
}
}
}
}
int main()
{
scanf("%d%d",&len1,&len2);
for(int i=0;i<=len1;i++) scanf("%lf",&a[i].real());
for(int i=0;i<=len2;i++) scanf("%lf",&b[i].real());
for(n=1,m=0;n<=len1+len2;n<<=1,m++);
// m=ceil(log(len1+len2+1)/log(2)),n=1<<m;
for(int i=0;i<n;i++) rev[i]=rev[i>>1]>>1|(i&1)<<(m-1);
FFT(a,1);FFT(b,1);
for(int i=0;i<n;i++) a[i]*=b[i];
FFT(a,-1);
for(int i=0;i<=len1+len2;i++) printf("%d ",(int)(a[i].real()/n+0.1));
}

自己写了个结构体表示复数,比complex快

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#include <iostream>
#include <cstdio>
#include <cmath>
const double pi=acos(-1);
int len1,len2,n,m,rev[262200];
struct node {
double x,y;
inline node (double a=0,double b=0) : x(a),y(b) {}
inline node operator + (node a) { return node(x+a.x,y+a.y); }
inline node operator - (node a) { return node(x-a.x,y-a.y); }
inline node operator * (node a) { return node(x*a.x-y*a.y,x*a.y+y*a.x); }
}a[262200],b[262200];
inline void FFT(node *a,int f)
{
for(int i=0;i<n;i++) if(i<rev[i]) std::swap(a[i],a[rev[i]]);
for(int i=1;i<n;i<<=1)
{
node wn(cos(pi/i),f*sin(pi/i));
for(int j=0;j<n;j+=(i<<1))
{
node w(1,0);
for(int k=0;k<i;k++,w=w*wn)
{
node x=a[j+k],y=a[j+k+i]*w;
a[j+k]=x+y;a[j+k+i]=x-y;
}
}
}
}
int main()
{
scanf("%d%d",&len1,&len2);
for(int i=0;i<=len1;i++) scanf("%lf",&a[i].x);
for(int i=0;i<=len2;i++) scanf("%lf",&b[i].x);
for(n=1,m=0;n<=len1+len2;n<<=1,m++);
for(int i=0;i<n;i++) rev[i]=rev[i>>1]>>1|(i&1)<<(m-1);
FFT(a,1);FFT(b,1);
for(int i=0;i<n;i++) a[i]=a[i]*b[i];
FFT(a,-1);
for(int i=0;i<=len1+len2;i++) printf("%d ",(int)(a[i].x/n+0.1));
}

NTT:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#include <iostream>
#include <cstdio>
#define Inv(x) (power(x,P-2))
const int P=(479<<21)+1;
const int G=3;
int len1,len2,n,m,rev[262200],a[262200],b[262200];
inline int power(int x,int y)
{
int z=1;
for(;y;y>>=1,x=(long long)x*x%P)
if(y&1) z=(long long)z*x%P;
return z;
}
inline void NTT(int *a,int f)
{
for(int i=0;i<n;i++) if(i<rev[i]) std::swap(a[i],a[rev[i]]);
for(int i=1,t=1;i<n;i<<=1,t++)
{
int wn=power(G,(P-1)/(1<<t));
if(f==-1) wn=Inv(wn);
for(int j=0;j<n;j+=(i<<1))
for(int k=0,w=1;k<i;k++,w=(long long)w*wn%P)
{
int x=a[j+k],y=(long long)a[j+k+i]*w%P;
a[j+k]=(x+y)%P;a[j+k+i]=(x-y+P)%P;
}
}
}
int main()
{
scanf("%d%d",&len1,&len2);
for(int i=0;i<=len1;i++) scanf("%d",&a[i]);
for(int i=0;i<=len2;i++) scanf("%d",&b[i]);
for(n=1,m=0;n<=len1+len2;n<<=1,m++);
for(int i=0;i<n;i++) rev[i]=(rev[i>>1]>>1)|(i&1)<<(m-1);
NTT(a,1);NTT(b,1);
for(int i=0;i<n;i++) a[i]=(long long)a[i]*b[i]%P;
NTT(a,-1);n=Inv(n);
for(int i=0;i<=len1+len2;i++) printf("%lld ",(long long)a[i]*n%P);
}

2 bzoj 2179: FFT快速傅立叶

给出两个nn1010进制整数xxyy,你需要计算x×yx\times y
数据范围:n<=60000n<=60000

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
#include <complex>
#include <cstdio>
const int N=140000;
const double pi=acos(-1);
int n,m,len,rev[N],ans[N];
std::complex<double> a[N],b[N];
inline void FFT(std::complex<double> *a,int f)
{
for(int i=0;i<n;i++) if(i<rev[i]) swap(a[i],a[rev[i]]);
for(int i=1;i<n;i<<=1)
{
std::complex<double> wn(cos(pi/i),f*sin(pi/i));
for(int j=0;j<n;j+=(i<<1))
{
std::complex<double> w(1,0);
for(int k=0;k<i;k++,w*=wn)
{
std::complex<double> x=a[j+k],y=w*a[j+k+i];
a[j+k]=x+y;a[j+k+i]=x-y;
}
}
}
if(f==-1) for(int i=0;i<n;i++) a[i]/=n;
}
int main()
{
scanf("%d",&n);
for(int i=n-1;i>=0;i--) scanf("%1lf",&a[i].real());
for(int i=n-1;i>=0;i--) scanf("%1lf",&b[i].real());
len=(n<<1)-1;m=ceil(log(n)/log(2))+1;n=1<<m;
for(int i=0;i<n;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(m-1));
FFT(a,1);FFT(b,1);
for(int i=0;i<n;i++) a[i]*=b[i];
FFT(a,-1);
for(int i=0;i<len;i++)
{
ans[i]+=a[i].real()+0.1;
if(ans[i]>=10)
ans[i+1]+=ans[i]/10,ans[i]%=10;
}
if(a[len]) len++;
for(int i=len-1;i>=0;i--) printf("%d",ans[i]);
}

3 bzoj 2194: 快速傅立叶之二

请计算 C[k]=k<=i<n(a[i]×b[ik])C[k]=\sum\limits _{k<=i<n} (a[i]\times b[i-k]),并且有 n<=105n < = 10 ^ 5a,ba,b中的元素均为小于等于100100的非负整数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
#include <complex>
#include <cstdio>
const int N=1<<18;
const double pi=acos(-1);
int n,m,k,rev[N];
std::complex<double> a[N],b[N];
inline void FFT(std::complex<double> *a,int f)
{
for(int i=0;i<n;i++) if(i<rev[i]) swap(a[i],a[rev[i]]);
for(int i=1;i<n;i<<=1)
{
std::complex<double> wn(cos(pi/i),f*sin(pi/i));
for(int j=0;j<n;j+=(i<<1))
{
std::complex<double> w(1,0);
for(int t=0;t<i;t++,w*=wn)
{
std::complex<double> x=a[j+t],y=w*a[i+j+t];
a[j+t]=x+y;a[i+j+t]=x-y;
}
}
}
if(f==-1) for(int i=0;i<n;i++) a[i]/=n;
}
int main()
{
scanf("%d",&n);
m=n;k=ceil(log(n)/log(2))+1;n=1<<k;
for(int i=0;i<m;i++)
scanf("%lf%lf",&a[i].real(),&b[m-i-1].real());
for(int i=0;i<n;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(k-1));
FFT(a,1);FFT(b,1);
for(int i=0;i<n;i++) a[i]*=b[i];
FFT(a,-1);
for(int i=m-1;i<=m*2-2;i++)
printf("%d\n",(int)(a[i].real()+0.1));
}

4 bzoj 3527: [Zjoi2014]力

【Description】
给出nn个数qiqi,给出FjFj的定义:Fj=i<jqiqj(ij)2i>jqiqj(ij)2F_j=\sum\limits_{i<j} \frac{q_i q_j}{(i-j)^2}-\sum\limits_{i>j} \frac{q_i q_j}{(i-j)^2}
Ei=FiqiEi=\frac{Fi}{qi},求EiEi.
【Input】
第一行一个整数nn
接下来nn行每行输入一个数,第ii行表示qiqi
【Output】
nn行,第ii行输出EiEi
与标准答案误差不超过1e21e-2即可。
【Sample Input】
5
4006373.885184
15375036.435759
1717456.469144
8514941.004912
1410681.345880
【Sample Output】
-16838672.693
3439.793
7509018.566
4595686.886
10903040.872
【Hint】
对于 的数据,
对于 的数据,
对于 的数据, 0<qi<10000000000 < qi < 1000000000

Ei=j<iqj(ji)2j>iqj(ji)2E_i=\sum\limits_{j<i} \frac{q_j}{(j-i)^2}-\sum\limits_{j>i} \frac{q_j}{(j-i)^2}
ai=qia_i=q_i,bi=1i2b_i=\frac{1}{i^2},
Ei=j=1i1aj×bijj=i+1naj×bijE_i=\sum\limits_{j=1}^{i-1} a_j \times b_{i-j} - \sum\limits_{j=i+1}^{n} a_j \times b_{i-j}

然后前面和后面拆开算,令Ei=AiBi E_i = A_i - B_i
Ai=j=1i1aj×bij A_i = \sum\limits_{j=1}^{i-1} a_j \times b_{i-j} 可以FFT一下
BiB_i 可以把数组反过来,然后FFT

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
#include <complex>
#include <cstdio>
const double pi=acos(-1);
int n,nn,m,rev[265000];
std::complex<double> a[265000],b[265000],c[265000];
inline void FFT(std::complex<double> *a,int f)
{
for(int i=0;i<n;i++) if(i<rev[i]) std::swap(a[i],a[rev[i]]);
for(int i=1;i<n;i<<=1)
{
std::complex<double> wn(cos(pi/i),sin(pi/i)*f);
for(int j=0;j<n;j+=(i<<1))
{
std::complex<double> w(1,0);
for(int k=0;k<i;k++,w*=wn)
{
std::complex<double> x=a[j+k],y=a[j+k+i]*w;
a[j+k]=x+y;a[j+k+i]=x-y;
}
}
}
}
int main()
{
scanf("%d",&n);
for(nn=n,n=1,m=0;(n>>1)<nn;n<<=1,m++);
for(int i=0;i<n;i++) rev[i]=(rev[i>>1]>>1)|(i&1)<<(m-1);
for(int i=0;i<nn;i++) scanf("%lf",&a[i].real()),b[nn-i-1].real()=a[i].real();
for(int i=1;i<nn;i++) c[i].real()=1.0/((long long)i*i);
FFT(a,1);FFT(b,1);FFT(c,1);
for(int i=0;i<n;i++) a[i]*=c[i],b[i]*=c[i];
FFT(a,-1);FFT(b,-1);
for(int i=0;i<nn;i++) printf("%.3lf\n",(a[i].real()-b[nn-i-1].real())/n);
}

5 SPOJ TSUM

n个不同的数,每次选三个不同的数加起来,问每种结果的个数

,其中a[i]表示值为i的数是否出现。
则想要表示一个数出现两次和三次分别为:

由容斥原理,ans=A3(x)3A(x)B(x)+2C(x)6ans= \frac{A^3(x) - 3A(x)\cdot B(x) + 2C(x)}{6}

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#include <iostream>
#include <cstdio>
#define Inv(x) (power(x,P-2))
const int P=(479<<21)+1;
const int G=3;
int n,rev[1<<17],a[1<<17],b[1<<17],c[1<<17];
inline int power(int x,int y)
{
int z=1;
for(;y;y>>=1,x=(long long)x*x%P)
if(y&1) z=(long long)z*x%P;
return z;
}
inline void NTT(int *a,int f)
{
for(int i=0;i<n;i++) if(i<rev[i]) std::swap(a[i],a[rev[i]]);
for(int i=1,t=1;i<n;i<<=1,t++)
{
int wn=power(G,(P-1)/(1<<t));
if(f==-1) wn=Inv(wn);
for(int j=0;j<n;j+=(i<<1))
for(int k=0,w=1;k<i;k++,w=(long long)w*wn%P)
{
int x=a[j+k],y=(long long)a[j+k+i]*w%P;
a[j+k]=(x+y)%P;a[j+k+i]=(x-y+P)%P;
}
}
}
int main()
{
scanf("%d",&n);
for(int i=1,x;i<=n;i++)
{
scanf("%d",&x);x+=20000;
a[x]=b[x*2]=c[x*3]=1;
}
n=1<<17;
for(int i=0;i<n;i++) rev[i]=(rev[i>>1]>>1)|(i&1)<<16;
NTT(a,1);NTT(b,1);
for(int i=0;i<n;i++) a[i]=((long long)a[i]*a[i]%P*a[i]%P-3LL*a[i]%P*b[i]%P+P)%P;
NTT(a,-1);
for(int i=0,t=Inv(n),k=Inv(6),x;i<n;i++)
{
x=(((long long)a[i]*t%P+2LL*c[i]%P)%P*k)%P;
if(x) printf("%d : %d\n",i-60000,x);
}
}

6 bzoj 3771

n个不同的数,每次选不超过三个不同的数加起来,问每种结果的个数

solution from Triple

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#include <iostream>
#include <cstdio>
#define Inv(x) (power(x,P-2))
const int P=(479<<21)+1;
const int G=3;
int n,rev[1<<17],a[1<<17],b[1<<17],c[1<<17];
inline int power(int x,int y)
{
int z=1;
for(;y;y>>=1,x=(long long)x*x%P)
if(y&1) z=(long long)z*x%P;
return z;
}
inline void NTT(int *a,int f)
{
for(int i=0;i<n;i++) if(i<rev[i]) std::swap(a[i],a[rev[i]]);
for(int i=1,t=1;i<n;i<<=1,t++)
{
int wn=power(G,(P-1)/(1<<t));
if(f==-1) wn=Inv(wn);
for(int j=0;j<n;j+=(i<<1))
for(int k=0,w=1;k<i;k++,w=(long long)w*wn%P)
{
int x=a[j+k],y=(long long)a[j+k+i]*w%P;
a[j+k]=(x+y)%P;a[j+k+i]=(x-y+P)%P;
}
}
}
int main()
{
scanf("%d",&n);
for(int i=1,x;i<=n;i++)
{
scanf("%d",&x);
a[x]++; b[x*2]++; c[x*3]++;
}
n=1<<17;
for(int i=0;i<n;i++) rev[i]=(rev[i>>1]>>1)|(i&1)<<16;
NTT(a,1);NTT(b,1);
long long x=Inv(2),y=Inv(3),z=Inv(6);
for(int i=0;i<n;i++)
a[i]=(z*a[i]%P*((long long)(a[i]+1)*(a[i]+2)%P+4)%P-x*b[i]%P*(a[i]+1)%P+P)%P;
NTT(a,-1);
for(int i=0,t=Inv(n);i<n;i++)
{
x=((long long)a[i]*t%P+y*c[i]%P)%P;
if(x) printf("%d %lld\n",i,x);
}
}