Loj #2479. 「九省联考 2018」制胡窜

Loj #2479. 「九省联考 2018」制胡窜

题目描述

对于一个字符串 \(S\),我们定义 \(|S|\) 表示 \(S\) 的长度。

接着,我们定义 \(S_i\) 表示 \(S\) 中第 \(i\) 个字符,\(S_{L,R}\) 表示由 \(S\) 中从左往右数,第 \(L\) 个字符到第 \(R\) 个字符依次连接形成的字符串。特别的,如果 \(L > R\) ,或者 \(L < [1, |S|]\), 或者 \(R < [1, |S|]\) 我们可以认为 \(S_{L,R}\) 为空串。

给定一个长度为 \(n\) 的仅由数字构成的字符串 \(S\),现在有 \(q\) 次询问,第 \(k\) 次询问会给出 \(S\) 的一个字符串 \(S_{l,r}\) ,请你求出有多少对 \((i, j)\),满足 \(1 \le i < j \le n\)\(i + 1 \lt j\),且 \(S_{l,r}\) 出现在 \(S_{1,i}\) 中或 \(S_{i+1, j−1}\) 中或 \(S_{j,n}\) 中。

输入格式

输入的第一行包含两个整数 \(n, q\)

第二行包含一个长度为 \(n\) 的仅由数字构成的字符串 \(S\)

接下来 \(q\) 行,每行两个正整数 \(l\)\(r\),表示此次询问的子串是 \(S_{l,r}\)

输出格式

对于每个询问,输出一个整数表示合法的数对个数。

数据范围与提示

对于所有测试数据,\(1 \le n \le 10^5\)\(1 \le q \le 3 · 10^5\)\(1 \le l \le r \le n\)

\(\\\)

感觉这道题细节贼烦人,正式考试的话估计可以刚一整场。

首先建后缀自动机,然后在使用线段树合并维护\(endpos\)集合。

询问的时候就先在\(fail\)树上倍增找到给定字符串出现的节点。然后我们将合法的\((i,j)\)二元组分为以下三种情况:

  1. \(S_{1,i}\)中出现
  2. \(S_{1,i}\)中未出现,\(S_{j,n}\)中出现
  3. \(S_{1,i},S_{j,n}\)中为出现,\(S_{i+1,j-1}\)中出现。

前两种情况很好算,找到位置最靠前以及最靠后的\(endpos\)就行了。

下面来考虑第三种情况。假设最靠前的\(endpos\)\(L\),最靠后的是\(R\),字符串长度为\(len\)。显然\(i<L,j>R-len+1\)

我们先考虑一种暴力做法:枚举\(j\in[R-len+2,n]\),然后算对于每个\(j\)有多少个可行的\(i\)。设\(<j\)的最大的\(endpos\)\(mx\),显然可行的\(i\)只与\(mx\)有关,为\(\min\{L,mx-len\}\)

理解了这个暴力做法过后正解就差不多知道了。对于线段树上每个节点,我们令每个位置的权值为其左边第一个\(endpos\)(如果没有则为\(0\)),\(sum\)为这些位置的权值和,\(rmax\)为最右边的\(endpos\)\(lempty\)为左边有多少个位置没有\(endpos\)。注意上述的信息只考虑了线段树所表示的区间,区间外的\(endpos\)不对其产生任何影响。正因为如此,在询问的时候先遍历左儿子,动态更新最右边的\(endpos\),再遍历右儿子计算答案。

道理很简单,就是要注意的边界情况有点多。。。

代码:

#include<bits/stdc++.h>
#define ll long long
#define N 200005

using namespace std;
inline int Get() {int x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9') {if(ch=='-') f=-1;ch=getchar();}while('0'<=ch&&ch<='9') {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}return x*f;}

int n,m;
char s[N];
int fail[N<<1],mxlen[N<<1];
int ch[N<<1][10];
int last=1,cnt=1;
int pos[N<<1],id[N<<1];
ll ss[N];
void Insert(int f,int P) {
	int p=last;
	int v=++cnt;
	pos[v]=P;
	id[P]=v;
	last=v;
	mxlen[v]=mxlen[p]+1;
	while(p&&!ch[p][f]) ch[p][f]=v,p=fail[p];
	if(!p) return fail[v]=1,void();
	int sn=ch[p][f];
	if(mxlen[sn]==mxlen[p]+1) return fail[v]=sn,void();
	int New=++cnt;
	mxlen[New]=mxlen[p]+1;
	memcpy(ch[New],ch[sn],sizeof(ch[sn]));
	fail[New]=fail[sn];
	fail[sn]=fail[v]=New;
	while(p&&ch[p][f]==sn) ch[p][f]=New,p=fail[p];
}

int fa[N<<1][20];
vector<int>e[N<<1];
int rt[N<<1];
int ls[N*50],rs[N*50];
int tag[N*50];
int emp[N*50],rmax[N*50];
ll sum[N*50];
int tot;
int lx,rx;

void update(int v,int lx,int rx) {
	sum[v]=sum[ls[v]]+sum[rs[v]];
	int mid=lx+rx>>1;
	ll R=rs[v]?emp[rs[v]]:rx-mid;
	sum[v]+=1ll*rmax[ls[v]]*R;
	if(!ls[v]||emp[ls[v]]==mid-lx+1) {
		emp[v]=mid-lx+1+R;
	} else {
		emp[v]=emp[ls[v]];
	}
	if(rs[v]) rmax[v]=rmax[rs[v]];
	else rmax[v]=rmax[ls[v]];
}

void Insert(int &v,int lx,int rx,int p) {
	v=++tot;
	tag[v]=1;
	if(lx==rx) {
		sum[v]=p;
		rmax[v]=lx;
		return ;
	}
	int mid=lx+rx>>1;
	if(p<=mid) Insert(ls[v],lx,mid,p);
	else Insert(rs[v],mid+1,rx,p);
	update(v,lx,rx);
}

int Merge(int a,int b,int lx,int rx) {
	if(!a||!b) return a+b;
	int v=++tot;
	int mid=lx+rx>>1;
	ls[v]=Merge(ls[a],ls[b],lx,mid);
	rs[v]=Merge(rs[a],rs[b],mid+1,rx);
	update(v,lx,rx);
	return v;
}

void dfs(int v) {
	for(int i=1;i<=18;i++) fa[v][i]=fa[fa[v][i-1]][i-1];
	if(pos[v]) Insert(rt[v],lx,rx,pos[v]);
	for(int i=0;i<e[v].size();i++) {
		int to=e[v][i];
		dfs(to);
		rt[v]=Merge(rt[v],rt[to],lx,rx);
	}
}

int Find(int l,int r) {
	int v=id[r];
	for(int i=18;i>=0;i--)
		if(fa[v][i]&&mxlen[fa[v][i]]>=r-l+1)
			v=fa[v][i];
	return v;
}

int query_mn(int v,int lx,int rx,int lim) {
	if(!v||rx<lim) return 0;
	if(lx==rx) return lx;
	int mid=lx+rx>>1;
	int x=query_mn(ls[v],lx,mid,lim);
	if(x) return x;
	else return query_mn(rs[v],mid+1,rx,lim);
}

int query_mx(int v,int lx,int rx) {
	if(lx==rx) return lx;
	int mid=lx+rx>>1;
	if(rs[v]) return query_mx(rs[v],mid+1,rx);
	else return query_mx(ls[v],lx,mid);
}

ll query_s(int v,int lx,int rx,int l,int r,int &L) {
	if(lx>r) return 0;
	if(rx<l) {
		L=max(L,rmax[v]);
		return 0;
	}
	if(l<=lx&&rx<=r) {
		ll x=!v?rx-lx+1:emp[v];
		ll ans=sum[v]+1ll*x*L;
		L=max(L,rmax[v]);
		return ans;
	}
	int mid=lx+rx>>1;
	return query_s(ls[v],lx,mid,l,r,L)+query_s(rs[v],mid+1,rx,l,r,L);
}

ll solve(int v,int len) {
	int mn=query_mn(rt[v],lx,rx,1),mx=query_mx(rt[v],lx,rx);
	ll ans=0;
	if(mn==mx) {
		if(mx<n) ans+=ss[n-mx-1];
		if(mn-len+1>1) ans+=ss[mn-len-1];
		ans+=1ll*(n-mx)*(mn-len);
		return ans;
	}
	if(mn<n) ans+=ss[n-mn-1];
	if(mx-len+1>1) ans+=ss[mx-len-1];
	if(mx-len+1>mn+1) ans-=ss[mx-len-mn];
	int ed=query_mn(rt[v],lx,rx,mn+len-1);
	if(ed) {
		ed=max(ed,mx-len+1);
		ans+=1ll*(n-ed)*(mn-1);
		ed--;
	} else ed=n-1;
	int st=max(mn,mx-len+1);
	if(ed>=st) {
		int L=0;
		ans+=query_s(rt[v],lx,rx,st,ed,L);
		ans-=1ll*len*(ed-st+1);
	}
	return ans;
}

int main() {
	n=Get(),m=Get();
	for(int i=1;i<=n;i++) ss[i]=ss[i-1]+i;
	lx=1,rx=n;
	scanf("%s",s+1);
	for(int i=1;i<=n;i++) Insert(s[i]-'0',i);
	for(int i=2;i<=cnt;i++) {
		e[fail[i]].push_back(i);
		fa[i][0]=fail[i];
	}
	dfs(1);
	int l,r;
	while(m--) {
		l=Get(),r=Get();
		cout<<solve(Find(l,r),r-l+1)<<"\n";
	}
	return 0;
}

posted @ 2019-05-07 20:13  hec0411  阅读(322)  评论(0编辑  收藏  举报