TSUM - Triple Sums
题目描述
给定一个由 N 个不同整数组成的序列 s
考虑序列中不同索引的三个整数的所有可能的和。
对于每个可能的和,输出生成它的不同索引三元组的数量。
题解
在满足\(s_i+s_j+s_k=v\)的条件下
令\(U=\{(i,j,k) | 1\leq i,j,k \leq N\}\)
\(\:\:\:A=\{(i,j,k) | i=j,1\leq i,j,k \leq N\}\)
\(\:\:\:B=\{(i,j,k) | i=k,1\leq i,j,k \leq N\}\)
\(\:\:\:C=\{(i,j,k) | j=k,1\leq i,j,k \leq N\}\)
\(ans=F(U)-F(A)-F(B)-F(C)+F(A\bigcup B)+F(A \bigcup C)+F(B \bigcup C)-F(A \bigcup B \bigcup C)\)
\(\:\:\:\:\:\:\:\:=F(U)-3F(A)+2F(A \bigcup B)\)
设\(S(x)=\sum_{i=1}^N x^{s_i}, S_2(x)=\sum_{i=1}^N x^{2s_i}, S_3=\sum_{i=1}^N x^{3s_i}\)
\(ans=\frac{S(x)^3-3S_2(x)S(x)+2S_3(x)}{3!}\)
对于分子就可用FFT分别求多项式相乘
注意上述分子的过程可以在FFT时点值相加,具体见代码,注意精度
因为本题有负数,数组往右移就行,但多项式相乘后,数组移动长度要增长,具体见代码
点击查看代码
#include<functional>
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdlib>
#include<complex>
#include<string>
#include<cstdio>
#include<vector>
#include<cmath>
#include<queue>
#include<deque>
#define ll long long
using namespace std;
const int maxn=2e6+10101;
const int MOD=1e9+7;
const int inf=2147483647;
const double pi=acos(-1);
int read(){
int x=0,f=1;char ch=getchar();
for(;!isdigit(ch);ch=getchar())if(ch=='-')f=-1;
for(;isdigit(ch);ch=getchar())x=x*10+ch-'0';
return x*f;
}
typedef complex<double> cd;
int N,rev[maxn],s[maxn],ans[maxn];
cd S1[maxn],s1[maxn],S2[maxn],S3[maxn];
void get(int bit){
for(int i=0;i<(1<<bit);i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
}
void dft(cd *u,int val,int n){
for(int i=0;i<n;i++)if(i<rev[i])swap(u[i],u[rev[i]]);
for(int i=1;i<n;i<<=1){
cd wn(cos(pi/i),val*sin(pi/i));
for(int j=0;j<n;j+=(i<<1)){
cd w(1,0);
for(int k=0;k<i;k++,w*=wn){
cd x=u[j+k],y=w*u[j+k+i];
u[j+k]=x+y;
u[j+k+i]=x-y;
}
}
}
return ;
}
int main(){
N=read();for(int i=1;i<=N;i++)s[i]=read();
for(int i=1;i<=N;i++){
S1[s[i]+20000]+=1;s1[s[i]+20000]+=1;
if(2*s[i]<=60000)S2[s[i]*2+40000]+=1;
if(s[i]*3<=60000 && 3*s[i]+60000>=0)S3[s[i]*3+60000]+=1;
}
int len=0,n;
for(n=1;n<=120000;n<<=1)len++;get(len);
dft(S1,1,n);dft(S2,1,n);for(int i=0;i<=n;i++)S1[i]=S1[i]*S1[i]*S1[i]-(cd)(3)*S2[i]*S1[i];dft(S1,-1,n);
for(int i=-60000;i<=60000;i++){
ll ans=(ll)((ll)(S1[i+60000].real()/n+0.5)+2ll*(ll)S3[i+60000].real())/6;
if(ans)printf("%d : %lld\n",i,ans);
}
return 0;
}