Furik and Rubik and Sub Array题解

Description

给定一个长度为 \(N\) 只含正数的数组 \(a_i\) 求所有的连续的区间内数的总和的种类数,设数组的总和为 \(S\)

\(N\cdot S\leq 4\cdot10^{10}\)

Solution

因为数组里面都是整数,所以 \(N\leq 2\cdot 10^5\)

所以直接暴力 \(N^2\) 找,要么爆时间(指 \(2\cdot 10^4 \leq N \leq 2 \cdot 10^5\) ),要么爆空间(指 \(1 \leq N \leq 2 \cdot 10^3\)

所以,我们就分三个类型:

1.(\(1 \leq N \leq 2 \cdot 10^3\)

直接开什么 \(map\) 或者 \(set\) 乱搞就行,空间什么的 \(STL\) 根本没在怕。。

时间复杂度 \(O(n^2logn)\)

2.(\(2\cdot 10^3 \leq N \leq 2 \cdot 10^4\)

此时的 \(S\) 最大也就 \(2\cdot10^6\) ,所以开一个 \(vis\) 硬上。。

时间复杂度 \(O(N^2)\)

3.(\(2\cdot 10^4 \leq N \leq 2 \cdot 10^5\)

这时候时间这个硬伤不行了,可以用 FTT/NTT 进一步优化计算。

但我们要先想怎么把 \(N^2\) 的遍历改成 \(N^2\) 的多项式乘法。

我们求的是一段连续的区间,所以用一个数当作多项式里面的指数肯定不能达到要求。

所以我们可以考虑把所有前缀和存进去。

大概就是一正一负:

\(f_n=\sum_{i=1}^{n}x^{sum_i}\)

\(g_n=\sum_{i=1}^{n}x^{-sum_i}\)

这两个卷起来就行。

注意:多项式里指数不能是负数,所以可以集体再加一个 \(sum_n\) ,然后找的时候整体右移就行了。

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=5e6+10;
typedef ll arr[N];
const ll mod=998244353;
const ll inv3=332748118;
ll n,m,sum[N],nm,inv,lim=1,fre,id[N],Ans;
arr a,b,ans;
map<ll,bool > mp;
int vis[N];
inline ll inc(ll x,ll y){return x+y>=mod?x+y-mod:x+y;}
inline ll dec(ll x,ll y){return x-y<0?x-y+mod:x-y;}
inline ll read(){
	ll s=0,w=1;
	char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')w=-1;ch=getchar();}
	while(ch>='0'&&ch<='9') s=s*10+ch-'0',ch=getchar();
	return s*w;
}
inline ll ksm(ll a,ll b){
	ll tmp=1;
	while(b){
		if(b&1)tmp=(tmp*a)%mod;
		b>>=1,a=(a*a)%mod;
	}
	return tmp;
}
inline void Never_Tell_TLE(ll* NTT,ll sign){
	for(ll i=0;i<=lim;++i)if(i<id[i]){
		ll NTt=NTT[i];
		NTT[i]=NTT[id[i]];
		NTT[id[i]]=NTt;
	}
	for(ll mid=1;mid<lim;mid<<=1){
		ll Unit_root;
		if(sign==1)Unit_root=ksm(3,(mod-1)/(mid<<1));
		else Unit_root=ksm(inv3,(mod-1)/(mid<<1));
		for(ll R=mid<<1,r=0;r<lim;r+=R){
			ll pw=1;
			for(ll l=0;l<mid;++l,pw=(pw*Unit_root)%mod){
				ll butt=NTT[l+r],rfly=(pw*NTT[l+r+mid])%mod;
				NTT[l+r]=inc(butt,rfly);
				NTT[l+r+mid]=dec(butt,rfly);
			}
		}
	}
	if(sign==-1)for(ll i=0;i<=lim;++i)NTT[i]=(NTT[i]*inv)%mod;
}
int main(){
	n=read();
	for(int i=1;i<=n;++i){
		sum[i]=read()+sum[i-1];
	}
	if(n<=4000){
		for(int i=1;i<=n;++i){
			for(int j=i;j<=n;++j){
				if(!mp[sum[j]-sum[i-1]]){
					++Ans;
					mp[sum[j]-sum[i-1]]=1;
				}
			}
		}
		printf("%lld\n",Ans-1);
		return 0;
	}
	if(n<=20000){
		for(int i=1;i<=n;++i){
			for(int j=i;j<=n;++j){
				++vis[sum[j]-sum[i-1]];
			}
		}
		for(int i=1;i<=2e6;++i){
			if(vis[i])++Ans;
		}
		printf("%lld\n",Ans-1);
		return 0;
	}
	nm=sum[n]<<1;
	for(int i=0;i<=n;++i){
		a[sum[n]+sum[i]]=1;
		b[sum[n]-sum[i]]=1;
	}
	for(;lim<=(nm<<1);lim<<=1)++fre;
	inv=ksm(lim,mod-2);
	for(int i=0;i<lim;++i)id[i]=(id[i>>1]>>1)|((i&1)<<(fre-1));
	Never_Tell_TLE(a,1);
	Never_Tell_TLE(b,1);
	for(int i=0;i<=lim;++i)ans[i]=(a[i]*b[i])%mod;
	Never_Tell_TLE(ans,-1);
	for(int i=nm+1;i<=(nm<<1);++i)if(ans[i])++Ans;
	printf("%lld\n",Ans-1);
	return 0;
}

(难得考场过题呀)

posted @ 2021-07-21 19:40  Illusory_dimes  阅读(94)  评论(0编辑  收藏  举报