CF1748E Yet Another Array Counting Problem 题解
题目大意
给定一个长度为 \(n\) 的序列 \(a\) 和 \(m\),定义 \([l,r]\) 的 最左边的最大值 为 最小的 \(l\le i\le r\) 满足 \(x_i=\max\{a_l,a_{l+1},\dots,a_r\}\)
求满足以下条件的长度为 \(n\) 的序列 \(b\) 的数量
- 对于所有的 \(1\le i\le n\),有 \(1\le b_i\le m\)
- 对于所有的 \(1\le l\le r\le n\),满足序列 \(a\) 的区间 \([l,r]\) 的 最左边的最大值 等于序列 \(b\) 的区间 \([l,r]\) 的 最左边的最大值
答案对 \(10^9+7\) 取模
\(2\le n,m\le 10^5\),\(1\le a_i\le m\),\(n\cdot m\le 10^6\),\(\sum n\cdot m\le 10^6\)
题目解析
讲个笑话,这个东西叫笛卡尔树,在HHHOJ模拟赛考了不下五次
——@275307894a
原来小丑竟是我自己。(HHHOJ 为某中学内部 OJ)
我们发现,对于序列 \(a\) 的任意一个区间 \([L,R]\),如果 \([L,R]\) 的 最左边的最大值 为 \(t\),那么不难发现:
- 对于 \(L\le l\le t\le r\le R\),\([l,r]\) 的 最左边的最大值 为 \(t\)
- 对于 \(L\le l\le r< t\),\([l,r]\) 的 最左边的最大值 \(p\) 满足 \(a_p<a_t\)
- 对于 \(L\le l\le r< t\),\([l,r]\) 的 最左边的最大值 \(q\) 满足 \(a_q\le a_t\)
根据以上性质,我们发现这样对于大小的限制就递归到了两个子区间内。
我们可以用一些数据结构(我写了个线段树)来快速查找任意子区间的 最左边的最大值 ,这样我们就可以把所有关于大小的约束建立起来。
不难发现最后所有的大小约束关系会变成一棵二叉树,然后在上面跑 DP 就可以了。设 \(f_{i,j}\) 为节点 \(i\) 为 \(j\) 的时候这棵子树的方案数,然后直接 DP 转移即可。\(f_{i,j}\) 计算完之后,需要计算 \(j\) 一维前缀和方便接下来的转移。
时间复杂度 \(O(nm+n\log n)\)
#include<cstdio>
#define db double
#define gc getchar
#define pc putchar
#define U unsigned
#define ll long long
#define ld long double
#define ull unsigned long long
#define Tp template<typename _T>
#define Me(a,b) memset(a,b,sizeof(a))
Tp _T mabs(_T a){ return a>0?a:-a; }
Tp _T mmax(_T a,_T b){ return a>b?a:b; }
Tp _T mmin(_T a,_T b){ return a<b?a:b; }
Tp void mswap(_T &a,_T &b){ _T tmp=a; a=b; b=tmp; return; }
Tp void print(_T x){ if(x<0) pc('-'),x=-x; if(x>9) print(x/10); pc((x%10)+48); return; }
#define EPS (1e-7)
#define INF (0x7fffffff)
#define LL_INF (0x7fffffffffffffff)
#define maxn 200039
#define maxm 1000039
#define MOD 1000000007
#define Type int
#ifndef ONLINE_JUDGE
//#define debug
#endif
using namespace std;
Type read(){
char c=gc(); Type s=0; int flag=0;
while((c<'0'||c>'9')&&c!='-') c=gc(); if(c=='-') c=gc(),flag=1;
while('0'<=c&&c<='9'){ s=(s<<1)+(s<<3)+(c^48); c=gc(); }
if(flag) return -s; return s;
}
class SGT{//Segment Tree
private:
struct JTZ{ int maxx,pos; }sum[maxn<<2];
JTZ merge(JTZ x,JTZ y){
if(x.maxx!=y.maxx) return x.maxx>y.maxx?x:y;
else return x.pos<y.pos?x:y;
}
int *arr,siz,L,R;
void up(int rt){ sum[rt]=merge(sum[rt<<1],sum[rt<<1|1]); return; }
void build(int l,int r,int rt){
if(l==r){ sum[rt].maxx=arr[l]; sum[rt].pos=l; return; } int mid=(l+r)>>1;
build(l,mid,rt<<1); build(mid+1,r,rt<<1|1); up(rt); return;
}
JTZ querymax(int l,int r,int rt){
if(L<=l&&r<=R) return sum[rt]; int mid=(l+r)>>1; JTZ tmp;
if(mid>=L){
tmp=querymax(l,mid,rt<<1);
if(mid<R) tmp=merge(tmp,querymax(mid+1,r,rt<<1|1));
} else if(mid<R) tmp=querymax(mid+1,r,rt<<1|1); return tmp;
}
public:
void init(int _siz,int *_arr){ siz=_siz; arr=_arr; build(1,siz,1); return; }
int query(int l,int r){ L=l,R=r; return querymax(1,siz,1).pos; }
}tr;
struct Node{ int ls,rs; }bt[maxn]; int siz;//Binary Tree
int build(int l,int r){
if(l>r) return 0; int mid=tr.query(l,r),tmp=++siz;
bt[tmp].ls=build(l,mid-1); bt[tmp].rs=build(mid+1,r); return tmp;
}
int n,m,a[maxn]; ll f[maxm];
#define get(i,j) (((i)-1)*m+(j))
void dfs(int x){
int i;
if(!bt[x].ls&&!bt[x].rs) for(i=1;i<=m;i++) f[get(x,i)]=i%MOD;
else if(!bt[x].ls&&bt[x].rs){
dfs(bt[x].rs);
for(i=1;i<=m;i++) f[get(x,i)]=f[get(bt[x].rs,i)];
for(i=2;i<=m;i++) f[get(x,i)]+=f[get(x,i-1)],f[get(x,i)]%=MOD;
} else if(bt[x].ls&&!bt[x].rs){
dfs(bt[x].ls); f[get(x,1)]=0;
for(i=2;i<=m;i++) f[get(x,i)]=f[get(bt[x].ls,i-1)];
for(i=2;i<=m;i++) f[get(x,i)]+=f[get(x,i-1)],f[get(x,i)]%=MOD;
} else{
dfs(bt[x].ls); dfs(bt[x].rs); f[get(x,1)]=0;
for(i=2;i<=m;i++) f[get(x,i)]=f[get(bt[x].ls,i-1)]*f[get(bt[x].rs,i)]%MOD;
for(i=2;i<=m;i++) f[get(x,i)]+=f[get(x,i-1)],f[get(x,i)]%=MOD;
} return;
}
void work(){
n=read(); m=read(); int i; for(i=1;i<=n;i++) a[i]=read();
tr.init(n,a); siz=0; build(1,n); for(i=1;i<=n*m;i++) f[i]=0;
dfs(1); print(f[m]),pc('\n'); return;
}
int main(){
#ifdef LOCAL
freopen("1.in","r",stdin);
#endif
int T=read(); while(T--) work(); return 0;
}