CF1591F 题解

先不管值域,设计状态 \(dp_{i,j}\) 表示考虑前 \(i\) 个数最后一个数为 \(j\) 的方案数,那么有如下转移:

\[dp_{i,j} = dp_{i-1,k} (j \not = k,j \leq a_i) \]

先滚动数组去掉一维状态,然后发现每一次操作对于数组 \(dp\) 而言其实是对于 \(j \leq a_i\)\(dp_{j}\) 变成 \(x - dp_j\) 这里 \(x\) 代表所有 \(dp_i\) 的总和,使得 \(j > a_i\)\(dp_j\) 变为 \(0\)

因此考虑用线段树维护 \(dp\) 数组即可。

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int maxn = 2e5+114;
const int top = 1e9+7;
const int mod = 998244353;
struct Node{
	int sum,tag1,tag2,ls,rs,lt,rt;
}tr[maxn*40];
int tot,rt;
void pushup(int cur){
	tr[cur].sum=(tr[tr[cur].ls].sum+tr[tr[cur].rs].sum)%mod;
}
void pushdown1(int cur){
	if(tr[cur].tag1!=0){
		int mid=(tr[cur].lt+tr[cur].rt)>>1;
		if(tr[cur].ls==0) tr[cur].ls=++tot,tr[tr[cur].ls].tag2=1,tr[tr[cur].ls].lt=tr[cur].lt,tr[tr[cur].ls].rt=mid;
		if(tr[cur].rs==0) tr[cur].rs=++tot,tr[tr[cur].rs].tag2=1,tr[tr[cur].rs].lt=mid+1,tr[tr[cur].rs].rt=tr[cur].rt;
		tr[tr[cur].ls].sum=(tr[tr[cur].ls].sum+tr[cur].tag1*(mid-tr[cur].lt+1))%mod;
		tr[tr[cur].rs].sum=(tr[tr[cur].rs].sum+tr[cur].tag1*(tr[cur].rt-mid))%mod;
		tr[tr[cur].ls].tag1=(tr[tr[cur].ls].tag1+tr[cur].tag1)%mod;
		tr[tr[cur].rs].tag1=(tr[tr[cur].rs].tag1+tr[cur].tag1)%mod;
		tr[cur].tag1=0;
	}
}
void pushdown2(int cur){
	if(tr[cur].tag2!=1){
		int mid=(tr[cur].lt+tr[cur].rt)>>1;
		if(tr[cur].ls==0) tr[cur].ls=++tot,tr[tr[cur].ls].tag2=1,tr[tr[cur].ls].lt=tr[cur].lt,tr[tr[cur].ls].rt=mid;
		if(tr[cur].rs==0) tr[cur].rs=++tot,tr[tr[cur].rs].tag2=1,tr[tr[cur].rs].lt=mid+1,tr[tr[cur].rs].rt=tr[cur].rt;
		tr[tr[cur].ls].sum=(tr[tr[cur].ls].sum*tr[cur].tag2)%mod;
		tr[tr[cur].rs].sum=(tr[tr[cur].rs].sum*tr[cur].tag2)%mod;
		tr[tr[cur].ls].tag1=(tr[tr[cur].ls].tag1*tr[cur].tag2)%mod;
		tr[tr[cur].rs].tag1=(tr[tr[cur].rs].tag1*tr[cur].tag2)%mod;
		tr[tr[cur].ls].tag2=(tr[tr[cur].ls].tag2*tr[cur].tag2)%mod;
		tr[tr[cur].rs].tag2=(tr[tr[cur].rs].tag2*tr[cur].tag2)%mod;
		tr[cur].tag2=1;
	}
}
void pushdown(int cur){
	pushdown2(cur);
	pushdown1(cur);
}
void update1(int &cur,int lt,int rt,int l,int r,int v){
	if(lt>r||rt<l) return ;
	if(cur==0){
		cur=++tot;
		tr[cur].lt=lt,tr[cur].rt=rt,tr[cur].tag2=1,tr[cur].tag1=0;		
	}
	if(l<=lt&&rt<=r){
		tr[cur].sum+=((v%mod)*(rt-lt+1)%mod);
		tr[cur].sum%=mod;
		tr[cur].tag1+=v;
		tr[cur].tag1%=mod;
		return ;
	}	
	pushdown(cur);
	int mid=(lt+rt)>>1;
	update1(tr[cur].ls,lt,mid,l,r,v);
	update1(tr[cur].rs,mid+1,rt,l,r,v);
	pushup(cur);
}
void update2(int &cur,int lt,int rt,int l,int r,int v){
	if(lt>r||rt<l) return ;
	if(cur==0){
		cur=++tot;
		tr[cur].lt=lt,tr[cur].rt=rt,tr[cur].tag2=1,tr[cur].tag1=0;	
	}
	if(l<=lt&&rt<=r){
		tr[cur].sum*=v;
		tr[cur].sum%=mod;
		tr[cur].tag2*=v;
		tr[cur].tag2%=mod;
		tr[cur].tag1*=v;
		tr[cur].tag1%=mod;
		return ;
	}	
	pushdown(cur);
	int mid=(lt+rt)>>1;
	update2(tr[cur].ls,lt,mid,l,r,v);
	update2(tr[cur].rs,mid+1,rt,l,r,v);
	pushup(cur);
}
int a[maxn],n;
signed main(){
cin>>n;
for(int i=1;i<=n;i++) cin>>a[i];
update1(rt,1,top,1,a[1],1);
for(int i=2;i<=n;i++){
	int sum=tr[rt].sum%mod;
	update2(rt,1,top,1,a[i],-1);
	update1(rt,1,top,1,a[i],sum);
	update2(rt,1,top,a[i]+1,top,0);
} 
cout<<(tr[rt].sum+mod)%mod;
return 0;
}
posted @ 2024-01-30 23:57  ChiFAN鸭  阅读(3)  评论(0编辑  收藏  举报