[提高组集训2021] 校游戏
一、题目
有两个字符串 \(A,B\),你需要对于所有 \(k\) 求出:从 \(A\) 中随机选一个长度为 \(k\) 的子串比从 \(B\) 中随机选一个长度为 \(k\) 的子串字典序小的概率、字典序相等的概率、字典序大的概率。
\(|A|,|B|\leq 2\cdot 10^5\)
二、解法
还是不要迷信后缀自动机,后缀数组解决字典序问题有些时候还是简单多了。
考虑计算 \(A<B\) 的概率(第一个子串更大的概率同理)
我们把两个串中间插入一个分隔符一起做一次后缀数组,对于某个属于 \(A\) 的后缀,我们找到排名比它大的 \(B\) 的后缀,考虑它们对 \(A\leq B\) 的贡献,是在 \([1,\min(|A|,|B|)]\) 都加上了 \(1\)
但是这样做要减去 \(A=B\) 的贡献,发现这就是 \(\tt lcp\) 问题,扫描的时候我们对 \(\tt height\) 数组做单调栈,单调栈中的每个元素都有一个负贡献,那么我们在栈顶打一个整体标记即可。
对于 \(A\leq B\) 贡献的计算可以使用权值线段树维护差分标记,也就是对所有已插入的属于 \(B\) 的后缀增加 \(1\) 的标记即可,时间复杂度 \(O(n\log n)\),然后再来一次即可!
三、总结
子串问题转化成后缀问题,用后缀数组来算贡献即可。
//Cause sum gonna arrive
#include <cstdio>
#include <cassert>
#include <cstring>
#include <iostream>
using namespace std;
const int M = 400005;
#define int long long
int read()
{
int x=0,f=1;char c;
while((c=getchar())<'0' || c>'9') {if(c=='-') f=-1;}
while(c>='0' && c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
return x*f;
}
int n,la,lb,p[M],z[M],tmp[M],pre[M],xi[M],da[M];
int suf[M],c[M],x[M],y[M],sa[M],rk[M],h[M];
int tag[2*M],num[2*M];char a[M],b[M],s[M];
//suffix array
void init()
{
int m=256;
for(int i=0;i<=m;i++) c[i]=0;
for(int i=1;i<=n;i++) c[x[i]=s[i]]++;
for(int i=1;i<=m;i++) c[i]+=c[i-1];
for(int i=n;i>=1;i--) sa[c[x[i]]--]=i;
for(int k=1;k<=n;k<<=1)
{
int num=0;
for(int i=n-k+1;i<=n;i++) y[++num]=i;
for(int i=1;i<=n;i++)
if(sa[i]>k) y[++num]=sa[i]-k;
for(int i=0;i<=m;i++) c[i]=0;
for(int i=1;i<=n;i++) c[x[i]]++;
for(int i=1;i<=m;i++) c[i]+=c[i-1];
for(int i=n;i>=1;i--) sa[c[x[y[i]]]--]=y[i];
swap(x,y);
x[sa[1]]=num=1;
for(int i=2;i<=n;i++)
x[sa[i]]=(y[sa[i]]==y[sa[i-1]]
&&y[sa[i]+k]==y[sa[i-1]+k])?num:++num;
if(n==num) break;
m=num;
}
int k=0;
for(int i=1;i<=n;i++) rk[sa[i]]=i;
for(int i=1;i<=n;i++)
{
if(rk[i]==1) continue;
if(k) k--;
int x=i,y=sa[rk[i]-1];
while(x+k<=n && y+k<=n && s[x+k]==s[y+k]) k++;
h[rk[i]]=k;
}
}
//segment tree maintaining tags
void down(int i)
{
if(!tag[i]) return ;
tag[i<<1]+=tag[i];
tag[i<<1|1]+=tag[i];
tag[i]=0;
}
void work(int i,int l,int r,int L,int R)
{
if(l>R || L>r) return ;
if(L<=l && r<=R)
{
tag[i]++;
return ;
}
int mid=(l+r)>>1;down(i);
work(i<<1,l,mid,L,R);
work(i<<1|1,mid+1,r,L,R);
}
int ask(int i,int l,int r,int L,int R)
{
if(L>r || l>R) return 0;
if(L<=l && r<=R) return num[i];
int mid=(l+r)>>1;
return ask(i<<1,l,mid,L,R)
+ask(i<<1|1,mid+1,r,L,R);
}
void ins(int i,int l,int r,int id)
{
if(l==r)
{
tag[i]=0;
num[i]=1;
return ;
}
int mid=(l+r)>>1;down(i);
if(mid>=id) ins(i<<1,l,mid,id);
else ins(i<<1|1,mid+1,r,id);
num[i]=num[i<<1]+num[i<<1|1];
}
void dfs(int i,int l,int r)
{
if(l==r)
{
assert(num[i]);
pre[l+1]-=tag[i];
return ;
}
int mid=(l+r)>>1;down(i);
dfs(i<<1,l,mid);
dfs(i<<1|1,mid+1,r);
}
void work()
{
init();int t=0;
memset(tag,0,sizeof tag);
memset(num,0,sizeof num);
for(int i=1;i<=n;i++)
suf[i]=pre[i]=z[i]=tmp[i]=0;
for(int i=n;i>=1;i--)
{
int x=sa[i];suf[i]=suf[i+1];
if(x<=la)//A
{
int len=la-x+1;
int num=ask(1,1,lb,len+1,lb);//get number
pre[1]+=suf[i];pre[len+1]-=num;
work(1,1,lb,1,len);//give tag(-)
z[t]++;
}
if(x>la+1)//B
{
suf[i]++;
ins(1,1,lb,n-x+1);
}
//upd the stack
while(t && h[i]<=h[p[t]])
{
tmp[1]+=z[t]*(suf[p[t]]-suf[p[t-1]]);
tmp[h[p[t]]+1]-=z[t]*(suf[p[t]]-suf[p[t-1]]);
z[t-1]+=z[t];t--;
}
p[++t]=i,z[t]=0;
}
dfs(1,1,lb);
for(int i=t;i>=1;i--)
{
tmp[1]+=z[i]*(suf[p[i]]-suf[p[i-1]]);
tmp[h[p[i]]+1]-=z[i]*(suf[p[i]]-suf[p[i-1]]);
z[i-1]+=z[i];
}
for(int i=1;i<=lb;i++)
tmp[i]+=tmp[i-1];
for(int i=1;i<=lb;i++)
{
pre[i]+=pre[i-1];
tmp[i]=pre[i]-tmp[i];
}
}
int gcd(int a,int b)
{
return !b?a:gcd(b,a%b);
}
void write(int x)
{
if(x<=9)
{
putchar(x+'0');
return ;
}
write(x/10);
putchar(x%10+'0');
}
void print(int x,int y)
{
int t=gcd(x,y);
x/=t;y/=t;
write(x);
putchar('/');
write(y);
putchar(' ');
}
signed main()
{
freopen("game.in","r",stdin);
freopen("game.out","w",stdout);
scanf("%s",a+1),la=strlen(a+1);
scanf("%s",b+1),lb=strlen(b+1);
for(int i=1;i<=la;i++) s[++n]=a[i];
s[++n]='z'+1;
for(int i=1;i<=lb;i++) s[++n]=b[i];
work();
for(int i=1;i<=lb;i++) xi[i]=tmp[i];
//once again
n=0;
for(int i=1;i<=lb;i++) s[++n]=b[i];
s[++n]='z'+1;
for(int i=1;i<=la;i++) s[++n]=a[i];
swap(la,lb);
work();
for(int i=1;i<=lb;i++) da[i]=tmp[i];
for(int i=1;i<=min(la,lb);i++)
{
int sum=(la-i+1)*(lb-i+1);
print(xi[i],sum);
print(sum-xi[i]-da[i],sum);
print(da[i],sum);
puts("");
}
}