[BZOJ5351]Query on a sequence
[BZOJ5351]Query on a sequence
题目大意:
给定一个长度为\(n(n\le10^5)\)的数列\(P\),满足\(|P_i|\le10^9\),求满足下列约束的不同的四元组\((a,b,c,d)\)的个数:
- \(1\le a\le b<c\le d\le n\);
- \(b-a=d-c\);
- \(c-b-1=m\),\(m(m>0)\)为给定的数;
- \(p_{a+i}=P_{c+i}\)对于所有\(i(0\le i\le b-a)\)均成立。
思路:
对于\(a\)和\(c\),若\(c-a>m\),且对于后缀\(S_a\)和\(S_c\),有\(\operatorname{lcp}(S_a,S_c)\ge c-a-m\),则我们可以求出合法的\(b\)和\(d\)的区间。
首先对于数列建立后缀数组。从大到小枚举\(lcp[i]\),合并其对应的两个后缀\(a\)和\(c\),因为若下一条枚举到的\(lcp[j]\)如果是关于\(c\)的,则由于\(lcp[i]\ge lcp[i]\),它一样可以适用于\(a\)。对于每个连通块用线段树维护数列每个区间内的数开头的每个后缀是否出现。统计时枚举连通块内每一个后缀,枚举它是\(a\)还是\(c\),在线段树上查找区间和即可。
源代码:
#include<cstdio>
#include<cctype>
#include<climits>
#include<algorithm>
inline int getint() {
register char ch;
register bool neg=false;
while(!isdigit(ch=getchar())) neg|=ch=='-';
register int x=ch^'0';
while(isdigit(ch=getchar())) x=(((x<<2)+x)<<1)+(ch^'0');
return x;
}
const int N=1e5+1;
int n,m,k,s[N],sa[N],rank[N],tmp[N];
std::pair<int,int> lcp[N];
inline bool cmp(const int &i,const int &j) {
if(rank[i]!=rank[j]) return rank[i]<rank[j];
const int ri=i+k<=n?rank[i+k]:-1;
const int rj=j+k<=n?rank[j+k]:-1;
return ri<rj;
}
inline void suffix_sort() {
for(register int i=0;i<=n;i++) {
sa[i]=i;
rank[i]=s[i];
}
for(k=1;k<=n;k<<=1) {
std::sort(&sa[0],&sa[n+1],cmp);
tmp[sa[0]]=0;
for(register int i=1;i<=n;i++) {
tmp[sa[i]]=tmp[sa[i-1]]+!!cmp(sa[i-1],sa[i]);
}
std::copy(&tmp[0],&tmp[n]+1,rank);
}
};
inline void init_lcp() {
for(register int i=0,h=0;i<n;i++) {
if(h>0) h--;
const int &j=sa[rank[i]-1];
while(i+h<n&&j+h<n&&s[i+h]==s[j+h]) h++;
lcp[rank[i]-1]=std::make_pair(-h,rank[i]);
}
}
const int SIZE=N*20;
class SegmentTree {
private:
struct Node {
int val,left,right;
};
Node node[SIZE];
int sz,new_node() {
return ++sz;
}
public:
int root[N];
void insert(int &p,const int &b,const int &e,const int &x) {
node[p=new_node()].val++;
if(b==e) return;
const int mid=(b+e)>>1;
if(x<=mid) insert(node[p].left,b,mid,x);
if(x>mid) insert(node[p].right,mid+1,e,x);
}
int query(const int &p,const int &b,const int &e,const int &l,const int &r) const {
if(r<l) return 0;
if(b==l&&e==r) return node[p].val;
const int mid=(b+e)>>1;
int ret=0;
if(l<=mid) ret+=query(node[p].left,b,mid,l,std::min(mid,r));
if(r>mid) ret+=query(node[p].right,mid+1,e,std::max(mid+1,l),r);
return ret;
}
int merge(const int &x,const int &y) {
if(!x||!y) return x|y;
node[y].val+=node[x].val;
node[y].left=merge(node[x].left,node[y].left);
node[y].right=merge(node[x].right,node[y].right);
return y;
}
};
SegmentTree t;
struct DisjointSet {
int anc[N],min[N],size[N];
int find(const int &x) {
return x==anc[x]?x:anc[x]=find(anc[x]);
}
void reset() {
for(register int i=1;i<=n;i++) {
size[i]=1;
anc[i]=min[i]=i;
}
}
void merge(const int &x,const int &y) {
anc[x]=y;
min[y]=std::min(min[x],min[y]);
size[y]+=size[x];
}
};
DisjointSet djs;
int lim,ans;
inline void merge(int x,int y) {
x=djs.find(x),y=djs.find(y);
if(djs.size[x]>djs.size[y]) std::swap(x,y);
for(register int i=djs.min[x];i<djs.min[x]+djs.size[x];i++) {
ans+=t.query(t.root[y],1,n,std::max(sa[i]+1+m+1,1),std::min(sa[i]+1+lim,n));
ans+=t.query(t.root[y],1,n,std::max(sa[i]+1-lim,1),std::min(sa[i]+1-m-1,n));
}
djs.merge(x,y);
t.merge(t.root[x],t.root[y]);
}
int main() {
n=getint(),m=getint();
for(register int i=0;i<n;i++) s[i]=getint();
s[n]=INT_MIN;
suffix_sort();
init_lcp();
std::sort(&lcp[0],&lcp[n]);
for(register int i=1;i<=n;i++) {
t.insert(t.root[i],1,n,sa[i]+1);
}
djs.reset();
for(register int i=0;i<n;i++) {
lim=-lcp[i].first+m;
merge(lcp[i].second,lcp[i].second-1);
}
printf("%d\n",ans);
return 0;
}