题解[P7279 光棱碎片]
题意:
对求出无序子串对 \((S[l1\dots r1],S[l2\dots r2])\) 的个数,
满足 \(S[l1\dots r1]=S[l2\dots r2]\) 且 \(L\leq a_{r1} \operatorname{xor} a_{r2}+r1-l1+1\leq R\)。
其中 \(L,R,a_{1\dots |S|}\) 是给定的。
首先,\(L,R\) 的限制可以通过差分转为 \(\geq lim\) 的贡献,算两次。
所以只需看 \(a_{r_1} \operatorname{xor} a_{r_1}+r1-l1+1\geq lim\)。
扔到 \(\text{SAM}\) 的 \(\text{parent}\) 树上后,对每个子串的 \(\text{endpos}\) 考虑。设点 \(i,j\) 分别对应原序列中 \(r_1,r_2\) 的位置,\(x=\operatorname{lca}(i,j)\) ,则上式为 \(a_{r_1}\operatorname{xor} a_{r_2}+[1,len_x]\geq lim\) 。其中 \([1,len_x]\) 是因为任何比其 \(\text{lcs}\) 短的子串都能相等。(\([1,len_x]\) 指这一段区间,写法可能不规范但应该容易理解)
注意此时每个 \(a_{r_1}\operatorname{xor}a_{r_2}\) 只要满足以上条件都会算入一次,因为会对应不同位置的子串。
也就是说,如果 \(a_{r_1}\operatorname{xor}a_{r_2}\geq lim-1\) ,则 \([1,len_x]\) 区间内都能满足条件,会被算入 \(len_x\) 次,只需求出个数。但如果 \(a_{r_1}\operatorname{xor}a_{r_2}\in[lim-len_x,lim-1)\) ,则会被算入 \(a_{r_1}\operatorname{xor}a_{r_2}-lim+len_x+1\) 次,这就要求出个数,以及这部分的和。
分析到此结束,实现呼之欲出。
对每个 \(x\),统计子树内路径 \(\text{lca}\) 为 \(x\) 的贡献。
通过 \(\text{dsu on tree}\) 把之前子树内的 \(a_{r_1}\) 插到 \(\text{01trie}\) 中,遍历轻儿子子树时先算作为 \(a_{r_2}\) 的贡献,然后再将其插入 \(\text{01trie}\)。
而算贡献时要用 \(\text{01trie}\) 求出 \(\operatorname{xor}v\) 后 \(\geq k\) 的个数,以及和。
在 \(\text{01trie}\) 上找出 \(v\operatorname{xor} s=k\) 的一条链 \(s\) ,如果这一位上 \(k\) 为 \(0\) ,则找到 \(s\) 的另一个儿子,使这一位为 \(1\)。个数即为子树和,而求和可以在 \(\text{01trie}\) 上拆位实现。由于是 \(\geq k\),最终到的点同样要算入。
时间复杂度为 \(O(n\log^3 n)\),\(\text{dsu on tree,01trie}\) 以及拆位各带一个 \(\log\) ,空间复杂度为 \(O(n\log^2 n)\)。
代码:
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=1e5+10,K=20,mod=998244353;
int n,m,x,y,xx,last=1,tot=1,ts,rt,t_tot,ti,l,lx,ksb;
int lim_l,lim_r,zt1[2],zt2[2];ll tmp,tmps,tt,tts,res,ans;
int mxlen[N<<1],pos[N<<1],link[N<<1],trans[N<<1][26];
int to[N<<1],nextn[N<<1],h[N<<1],edg;bool b,bb,tb[K+1];
int size[N<<1],son[N<<1],w[N<<1],a[N];char ch,s[N];
inline void read(int &x){
x=0;ch=getchar();while(ch<48)ch=getchar();
while(ch>47)x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
}
void write(int x){if(x>9)write(x/10);putchar(48+x%10);}
inline void extend(char a){
y=last;ts=++tot;mxlen[ts]=mxlen[y]+1;
for(;y&&!trans[y][a];y=link[y])trans[y][a]=ts;
if(!y)link[ts]=1;
else {
x=trans[y][a];
if(mxlen[x]==mxlen[y]+1)link[ts]=x;
else {
xx=++tot;link[xx]=link[x];mxlen[xx]=mxlen[y]+1;
memcpy(trans[xx],trans[x],sizeof(trans[xx]));
for(;trans[y][a]==x;y=link[y])trans[y][a]=xx;
link[ts]=link[x]=xx;
}
}
last=ts;
}
inline void add(int x,int y){to[++edg]=y;nextn[edg]=h[x],h[x]=edg;}
void init(int x){
int i,y;size[x]=1;
for(i=h[x];y=to[i];i=nextn[i]){
init(y);size[x]+=size[y];
if(size[y]>size[son[x]])son[x]=y;
}
}
struct trie{int son[2],sz,t[K+1];}t[N*(K+1)];
void update(int &k,int x,int i,int v){
if(!k)k=++t_tot;t[k].sz+=v;
for(ti=K;~ti;--ti)if(tb[ti])t[k].t[ti]+=v;
if(~i)update(t[k].son[(x>>i)&1],x,i-1,v);
}
void inquiry(int k,int x,int i){
if(~i){
b=(l>>i)&1;bb=b^((x>>i)&1);
if(!b)tmp+=t[t[k].son[!bb]].sz;
inquiry(t[k].son[bb],x,i-1);
}
else tmp+=t[k].sz;
}//求个数
void inquirys(int k,int x,int i){
if(~i){
b=(l>>i)&1;bb=b^((x>>i)&1);
if(!b&&(ksb=t[k].son[!bb])){
for(ti=K;~ti;--ti){
if(tb[ti])tmps=(tmps+(1ll<<ti)*(t[ksb].sz-t[ksb].t[ti]+mod)%mod)%mod;
else tmps=(tmps+(1ll<<ti)*t[ksb].t[ti]%mod)%mod;
}
}
inquirys(t[k].son[bb],x,i-1);
}
else if(k)tmps=(tmps+1ll*t[k].sz*l)%mod;
}//求和
inline void update(int x,int v){
for(ti=K;~ti;--ti)tb[ti]=(x>>ti)&1;
update(rt,x,K,v);
}
inline void inquiry(int x,int lim,bool b){
tmp=0;l=zt1[b];inquiry(rt,x,K);
res=(res+tmp*lx%mod)%mod;tt=tmp;//>=lim-1的个数
tmp=0;l=zt2[b];inquiry(rt,x,K);tt=tmp-tt;
for(ti=K;~ti;--ti)tb[ti]=(x>>ti)&1;
tmps=0;l=zt2[b];inquirys(rt,x,K);tts=tmps;
tmps=0;l=zt1[b];inquirys(rt,x,K);tts-=tmps;
res=(res-tt*(lim-lx-1+mod)%mod+tts)%mod;
//求[lim-lx,lem-1)的个数以及和
}
inline void inquiry(int x){
res=0;inquiry(x,lim_l,false);ans=(ans+res)%mod;
res=0;inquiry(x,lim_r,true);ans=(ans-res)%mod;
}
void calc(int x){
if(pos[x])inquiry(w[x]);register int i,y;
for(i=h[x];y=to[i];i=nextn[i])calc(y);
}
void dfs(int x){
if(pos[x])update(w[x],1);register int i,y;
for(i=h[x];y=to[i];i=nextn[i])dfs(y);
}
void clear(int x){
if(pos[x])update(w[x],-1);register int i,y;
for(i=h[x];y=to[i];i=nextn[i])clear(y);
}
void solve(int x){
register int i,y;
for(i=h[x];y=to[i];i=nextn[i])if(y^son[x])solve(y),clear(y);
if(son[x])solve(son[x]);
if(lx=mxlen[x]){
zt2[0]=max(lim_l-lx,0);zt2[1]=max(lim_r-lx,0);//注意要和0取max
if(pos[x])inquiry(w[x]),update(w[x],1);
for(i=h[x];y=to[i];i=nextn[i])if(y^son[x])calc(y),dfs(y);
}
}
main(){
while(ch<97)ch=getchar();register int i;
while(ch>96)s[++n]=ch-'a',ch=getchar();
for(i=1;i<=n;++i)extend(s[i]),pos[ts]=i;
for(i=2;i<=tot;++i)add(link[i],i);init(1);
for(i=1;i<=n;++i)read(a[i]);init(1);
for(i=1;i<=tot;++i)if(pos[i])w[i]=a[pos[i]];
read(lim_l);zt1[0]=max(lim_l-1,0);
read(lim_r);zt1[1]=lim_r;++lim_r;
solve(1);ans=(ans%mod+mod)%mod;write(ans);
}