P4755-Beautiful Pair【笛卡尔树,线段树】
正题
题目链接:https://www.luogu.com.cn/problem/P4755
题目大意
\(n\)个数字的一个序列,求有多少个点对\(i,j\)满足\(a_i\times a_j\leq max\{a_k\}(k\in[l,r])\)
解题思路
如果构建一棵笛卡尔树的话那么两个点之间的\(max\)就在笛卡尔树的\(LCA\)位置。
所以对于每个位置维护一个线段树,然后每次暴力枚举小的那棵子树在大子树的线段树中查询即可。然后线段树合并或者启发式合并上去就好了。
建笛卡尔树的时候用\(\text{RMQ}\)查询区间最大值然后递归下去就好了。
当然因为是乘法所以小的那个值域不会超过\(\sqrt{10^9}\)所以也可以树状数组+启发式合并。
这里写的是线段树的做法,时间复杂度都是\(O(n\log^2 n)\)
注意\(1\)要特判就好了
code
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=1e5+10,L=20;
int n,a[N],lg[N],f[N][L+1],inf;
long long ans;
struct Seg_Tree{
int cnt,w[N<<6],ls[N<<6],rs[N<<6];
void Change(int &x,int L,int R,int pos,int val){
if(!x)x=++cnt;w[x]+=val;
if(L==R)return;
int mid=(L+R)>>1;
if(pos<=mid)Change(ls[x],L,mid,pos,val);
else Change(rs[x],mid+1,R,pos,val);
return;
}
int Ask(int x,int L,int R,int l,int r){
if(!x||l>r)return 0;
if(L==l&&R==r)return w[x];
int mid=(L+R)>>1;
if(r<=mid)return Ask(ls[x],L,mid,l,r);
if(l>mid)return Ask(rs[x],mid+1,R,l,r);
return Ask(ls[x],L,mid,l,mid)+Ask(rs[x],mid+1,R,mid+1,r);
}
int Merge(int x,int y,int L,int R){
if(!x||!y)return x+y;
int mid=(L+R)>>1;w[x]+=w[y];
if(L==R)return x;
ls[x]=Merge(ls[x],ls[y],L,mid);
rs[x]=Merge(rs[x],rs[y],mid+1,R);
return x;
}
}T;
int Ask(int l,int r){
int z=lg[r-l+1];
int x=f[l][z],y=f[r-(1<<z)+1][z];
return (a[x]>=a[y])?x:y;
}
int solve(int l,int r){
if(l>r)return 0;
int x=Ask(l,r),ls,rs;
ls=solve(l,x-1);
rs=solve(x+1,r);
if(ls)ans+=T.Ask(ls,1,inf,1,1);
if(rs)ans+=T.Ask(rs,1,inf,1,1);
if(x-l<=r-x){
for(int i=l;i<x;i++)
ans+=T.Ask(rs,1,inf,1,a[x]/a[i]);
}
else{
for(int i=x+1;i<=r;i++)
ans+=T.Ask(ls,1,inf,1,a[x]/a[i]);
}
ls=T.Merge(ls,rs,1,inf);
T.Change(ls,1,inf,a[x],1);
return ls;
}
int main()
{
// printf("%d\n",sizeof(T)>>20);
scanf("%d",&n);
for(int i=1;i<=n;i++){
scanf("%d",&a[i]);
inf=max(inf,a[i]);
ans+=(a[i]==1);
f[i][0]=i;
}
inf=1e9;
for(int i=2;i<=n;i++)lg[i]=lg[i>>1]+1;
for(int j=1;(1<<j)<=n;j++)
for(int i=1;i+(1<<j)-1<=n;i++){
int x=f[i][j-1],y=f[i+(1<<j-1)][j-1];
if(a[x]>=a[y])f[i][j]=x;
else f[i][j]=y;
}
solve(1,n);
printf("%lld",ans);
return 0;
}