题解[P7279 光棱碎片]

原题链接

题意:
对求出无序子串对 \((S[l1\dots r1],S[l2\dots r2])\) 的个数,
满足 \(S[l1\dots r1]=S[l2\dots r2]\)\(L\leq a_{r1} \operatorname{xor} a_{r2}+r1-l1+1\leq R\)
其中 \(L,R,a_{1\dots |S|}\) 是给定的。

首先,\(L,R\) 的限制可以通过差分转为 \(\geq lim\) 的贡献,算两次。

所以只需看 \(a_{r_1} \operatorname{xor} a_{r_1}+r1-l1+1\geq lim\)

扔到 \(\text{SAM}\)\(\text{parent}\) 树上后,对每个子串的 \(\text{endpos}\) 考虑。设点 \(i,j\) 分别对应原序列中 \(r_1,r_2\) 的位置,\(x=\operatorname{lca}(i,j)\) ,则上式为 \(a_{r_1}\operatorname{xor} a_{r_2}+[1,len_x]\geq lim\) 。其中 \([1,len_x]\) 是因为任何比其 \(\text{lcs}\) 短的子串都能相等。(\([1,len_x]\) 指这一段区间,写法可能不规范但应该容易理解)

注意此时每个 \(a_{r_1}\operatorname{xor}a_{r_2}\) 只要满足以上条件都会算入一次,因为会对应不同位置的子串。

也就是说,如果 \(a_{r_1}\operatorname{xor}a_{r_2}\geq lim-1\) ,则 \([1,len_x]\) 区间内都能满足条件,会被算入 \(len_x\) 次,只需求出个数。但如果 \(a_{r_1}\operatorname{xor}a_{r_2}\in[lim-len_x,lim-1)\) ,则会被算入 \(a_{r_1}\operatorname{xor}a_{r_2}-lim+len_x+1\) 次,这就要求出个数,以及这部分的和。

分析到此结束,实现呼之欲出。

对每个 \(x\),统计子树内路径 \(\text{lca}\)\(x\) 的贡献。

通过 \(\text{dsu on tree}\) 把之前子树内的 \(a_{r_1}\) 插到 \(\text{01trie}\) 中,遍历轻儿子子树时先算作为 \(a_{r_2}\) 的贡献,然后再将其插入 \(\text{01trie}\)

而算贡献时要用 \(\text{01trie}\) 求出 \(\operatorname{xor}v\)\(\geq k\) 的个数,以及和。

\(\text{01trie}\) 上找出 \(v\operatorname{xor} s=k\) 的一条链 \(s\) ,如果这一位上 \(k\)\(0\) ,则找到 \(s\) 的另一个儿子,使这一位为 \(1\)。个数即为子树和,而求和可以在 \(\text{01trie}\) 上拆位实现。由于是 \(\geq k\),最终到的点同样要算入。

时间复杂度为 \(O(n\log^3 n)\)\(\text{dsu on tree,01trie}\) 以及拆位各带一个 \(\log\) ,空间复杂度为 \(O(n\log^2 n)\)

代码:

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=1e5+10,K=20,mod=998244353;
int n,m,x,y,xx,last=1,tot=1,ts,rt,t_tot,ti,l,lx,ksb;
int lim_l,lim_r,zt1[2],zt2[2];ll tmp,tmps,tt,tts,res,ans;
int mxlen[N<<1],pos[N<<1],link[N<<1],trans[N<<1][26];
int to[N<<1],nextn[N<<1],h[N<<1],edg;bool b,bb,tb[K+1];
int size[N<<1],son[N<<1],w[N<<1],a[N];char ch,s[N];
inline void read(int &x){
	x=0;ch=getchar();while(ch<48)ch=getchar();
	while(ch>47)x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
}
void write(int x){if(x>9)write(x/10);putchar(48+x%10);}
inline void extend(char a){
	y=last;ts=++tot;mxlen[ts]=mxlen[y]+1;
	for(;y&&!trans[y][a];y=link[y])trans[y][a]=ts;
	if(!y)link[ts]=1;
	else {
		x=trans[y][a];
		if(mxlen[x]==mxlen[y]+1)link[ts]=x;
		else {
			xx=++tot;link[xx]=link[x];mxlen[xx]=mxlen[y]+1;
			memcpy(trans[xx],trans[x],sizeof(trans[xx]));
			for(;trans[y][a]==x;y=link[y])trans[y][a]=xx;
			link[ts]=link[x]=xx;
		}
	}
	last=ts;
}
inline void add(int x,int y){to[++edg]=y;nextn[edg]=h[x],h[x]=edg;}
void init(int x){
	int i,y;size[x]=1;
	for(i=h[x];y=to[i];i=nextn[i]){
		init(y);size[x]+=size[y];
		if(size[y]>size[son[x]])son[x]=y;
	}
}
struct trie{int son[2],sz,t[K+1];}t[N*(K+1)];
void update(int &k,int x,int i,int v){
	if(!k)k=++t_tot;t[k].sz+=v;
	for(ti=K;~ti;--ti)if(tb[ti])t[k].t[ti]+=v;
	if(~i)update(t[k].son[(x>>i)&1],x,i-1,v);
}
void inquiry(int k,int x,int i){
	if(~i){
		b=(l>>i)&1;bb=b^((x>>i)&1);
		if(!b)tmp+=t[t[k].son[!bb]].sz;
		inquiry(t[k].son[bb],x,i-1);
	}
	else tmp+=t[k].sz;
}//求个数
void inquirys(int k,int x,int i){
	if(~i){
		b=(l>>i)&1;bb=b^((x>>i)&1);
		if(!b&&(ksb=t[k].son[!bb])){
			for(ti=K;~ti;--ti){
				if(tb[ti])tmps=(tmps+(1ll<<ti)*(t[ksb].sz-t[ksb].t[ti]+mod)%mod)%mod;
				else tmps=(tmps+(1ll<<ti)*t[ksb].t[ti]%mod)%mod;
			}
		}
		inquirys(t[k].son[bb],x,i-1);
	}
	else if(k)tmps=(tmps+1ll*t[k].sz*l)%mod;
}//求和
inline void update(int x,int v){
	for(ti=K;~ti;--ti)tb[ti]=(x>>ti)&1;
	update(rt,x,K,v);
}
inline void inquiry(int x,int lim,bool b){
	tmp=0;l=zt1[b];inquiry(rt,x,K);
	res=(res+tmp*lx%mod)%mod;tt=tmp;//>=lim-1的个数
	tmp=0;l=zt2[b];inquiry(rt,x,K);tt=tmp-tt;
	for(ti=K;~ti;--ti)tb[ti]=(x>>ti)&1;
	tmps=0;l=zt2[b];inquirys(rt,x,K);tts=tmps;
	tmps=0;l=zt1[b];inquirys(rt,x,K);tts-=tmps;
	res=(res-tt*(lim-lx-1+mod)%mod+tts)%mod;
//求[lim-lx,lem-1)的个数以及和
}
inline void inquiry(int x){
	res=0;inquiry(x,lim_l,false);ans=(ans+res)%mod;
	res=0;inquiry(x,lim_r,true);ans=(ans-res)%mod;
}
void calc(int x){
	if(pos[x])inquiry(w[x]);register int i,y;
	for(i=h[x];y=to[i];i=nextn[i])calc(y);
}
void dfs(int x){
	if(pos[x])update(w[x],1);register int i,y;
	for(i=h[x];y=to[i];i=nextn[i])dfs(y);
}
void clear(int x){
	if(pos[x])update(w[x],-1);register int i,y;
	for(i=h[x];y=to[i];i=nextn[i])clear(y);
}
void solve(int x){
	register int i,y;
	for(i=h[x];y=to[i];i=nextn[i])if(y^son[x])solve(y),clear(y);
	if(son[x])solve(son[x]);
	if(lx=mxlen[x]){
		zt2[0]=max(lim_l-lx,0);zt2[1]=max(lim_r-lx,0);//注意要和0取max
		if(pos[x])inquiry(w[x]),update(w[x],1);
		for(i=h[x];y=to[i];i=nextn[i])if(y^son[x])calc(y),dfs(y);
	}
}
main(){
	while(ch<97)ch=getchar();register int i;
	while(ch>96)s[++n]=ch-'a',ch=getchar();
	for(i=1;i<=n;++i)extend(s[i]),pos[ts]=i;
	for(i=2;i<=tot;++i)add(link[i],i);init(1);
	for(i=1;i<=n;++i)read(a[i]);init(1);
	for(i=1;i<=tot;++i)if(pos[i])w[i]=a[pos[i]];
	read(lim_l);zt1[0]=max(lim_l-1,0);
	read(lim_r);zt1[1]=lim_r;++lim_r;
	solve(1);ans=(ans%mod+mod)%mod;write(ans);
}
posted @ 2021-10-30 11:31  Y_B_X  阅读(32)  评论(0编辑  收藏  举报