CF1621G Weighted Increasing Subsequences 题解
Statement
给定⼀个⻓度为 \(n\) 的数列 。
⼀个严格上升⼦序列 \(a_{p_1},a_{p_2},\dots,a_{p_k}\) 的权值被定义为
求所有严格上升子序列的权值和模 \(1e9+7\)
多组数据 ,\(\sum n\le 2\times 10^5\)
Solution
发现我们很难整个子序列地算贡献,所以我们考虑每一个 \(a_x\) 的贡献
假装没有那个限制,贡献显然是 \(f[x]\times g[x]\) ,其中 \(f[x]\) 和 \(g[x]\) 分别表示以 \(x\) 为上升子序列末端/开端 的子序列个数,\(f,g\) 容易使用树状数组优化 DP 在 \(O(n\log n)\) 求出
考虑这个限制,设 \(y\) 为最大的 \(p\in(x,n)\) 满足 \(a_p>a_x\) 的位置,即 \(y=\max\{p\}\)
由 \(y\) 的定义,我们知道它是一个后缀最大值,且所有在 \(a_y\) 后面的数都小于 \(x\)
这个限制即是说包含 \(x\) 的上升子序列不能以 \(y\) 为结尾
当然,包含 \(x\) 的上升序列不可能以 \(a_y\) 之后的数结尾
\(y\) 可以考虑二分求出
所以我们现在考虑求出 \(h[x]\) 表示以 \(x\) 开头,\(x\) 对应的 \(y\) 结尾的子序列个数,显然一个 \(x\) 对应一个 \(y\)
加入我们一个个 \(x\) 枚举过去求解的话依旧是要 T 的,这个时候我们考虑对于一个 \(x\) 而言,以他为开端的上升序列要想以 \(y\) 为结尾,可选择的位置 \(p\) 都应该满足 \(a_x<a_p\) 且 \(p\) 所对应的 \(y\) 和 \(x\) 所对应的 \(y\) 相同
所以我们把对应同一个 \(y\) 的数单独提出来一起求 \(h\) ,即枚举 \(y\)
最后答案就是 \(\sum f[x]\times(g[x]-h[x])\)
\(O(n\log n)\)
Code
#include<bits/stdc++.h>
#define lowbit(x) (-x&x)
#define int long long
using namespace std;
typedef long long ll;
const int N = 2e5+5;
const int mod = 1e9+7;
int read(){
int s=0,w=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')w=-1;ch=getchar();}
while(isdigit(ch))s=s*10+(ch^48),ch=getchar();
return s*w;
}
void inc(int&a,int b){a=a+b>=mod?a+b-mod:a+b;}
void dec(int&a,int b){a=a>=b?a-b:a+mod-b;}
struct BIT{
int c[N];
void clean(){memset(c,0,sizeof(c));}
void add(int x,int v){for(;x<N;x+=lowbit(x))inc(c[x],v);}
int ask(int x){int r=0;for(;x;x-=lowbit(x))inc(r,c[x]);return r;}
}bit;
int a[N],f[N],g[N],h[N],ord[N],stk[N];
vector<int>seq[N];
int T,n,top;
signed main(){
T=read();
while(T--){
n=read(),top=0;
for(int i=1;i<=n;++i)a[i]=read(),ord[i]=i;
sort(ord+1,ord+1+n,[](int x,int y){
return a[x]==a[y]?x>y:a[x]<a[y];
});
for(int i=1;i<=n;++i)a[ord[i]]=i;
bit.clean(); for(int i=1;i<=n;++i)f[i]=(bit.ask(a[i])+1)%mod,bit.add(a[i],f[i]);
bit.clean(); for(int i=n;i>=1;--i)g[i]=(bit.ask(n-a[i]+1)+1)%mod,bit.add(n-a[i]+1,g[i]);
for(int i=n;i>=1;--i)if(a[i]>a[stk[top]])stk[++top]=i;
for(int i=1;i<=top;++i)seq[i].clear();
for(int i=n;i>=1;--i){
int l=1,r=top;
while(l<r){
int mid=l+(r-l)/2;
if(a[i]<=a[stk[mid]])r=mid;
else l=mid+1;
}
if(i^stk[r])seq[r].emplace_back(i);
}
bit.clean();//清空!
for(int i=1;i<=top;++i){
bit.add(n-a[stk[i]]+1,h[stk[i]]=1);
for(auto v:seq[i])h[v]=bit.ask(n-a[v]+1),bit.add(n-a[v]+1,h[v]);//
for(auto v:seq[i])bit.add(n-a[v]+1,mod-h[v]);
bit.add(n-a[stk[i]]+1,mod-1);
}
int ans=0;
for(int i=1;i<=n;++i)
inc(ans,(g[i]-h[i]+mod)%mod*f[i]%mod);
printf("%lld\n",ans);
}
return 0;
}