CF1591F 题解
先不管值域,设计状态 \(dp_{i,j}\) 表示考虑前 \(i\) 个数最后一个数为 \(j\) 的方案数,那么有如下转移:
\[dp_{i,j} = dp_{i-1,k} (j \not = k,j \leq a_i)
\]
先滚动数组去掉一维状态,然后发现每一次操作对于数组 \(dp\) 而言其实是对于 \(j \leq a_i\) 的 \(dp_{j}\) 变成 \(x - dp_j\) 这里 \(x\) 代表所有 \(dp_i\) 的总和,使得 \(j > a_i\) 的 \(dp_j\) 变为 \(0\)。
因此考虑用线段树维护 \(dp\) 数组即可。
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int maxn = 2e5+114;
const int top = 1e9+7;
const int mod = 998244353;
struct Node{
int sum,tag1,tag2,ls,rs,lt,rt;
}tr[maxn*40];
int tot,rt;
void pushup(int cur){
tr[cur].sum=(tr[tr[cur].ls].sum+tr[tr[cur].rs].sum)%mod;
}
void pushdown1(int cur){
if(tr[cur].tag1!=0){
int mid=(tr[cur].lt+tr[cur].rt)>>1;
if(tr[cur].ls==0) tr[cur].ls=++tot,tr[tr[cur].ls].tag2=1,tr[tr[cur].ls].lt=tr[cur].lt,tr[tr[cur].ls].rt=mid;
if(tr[cur].rs==0) tr[cur].rs=++tot,tr[tr[cur].rs].tag2=1,tr[tr[cur].rs].lt=mid+1,tr[tr[cur].rs].rt=tr[cur].rt;
tr[tr[cur].ls].sum=(tr[tr[cur].ls].sum+tr[cur].tag1*(mid-tr[cur].lt+1))%mod;
tr[tr[cur].rs].sum=(tr[tr[cur].rs].sum+tr[cur].tag1*(tr[cur].rt-mid))%mod;
tr[tr[cur].ls].tag1=(tr[tr[cur].ls].tag1+tr[cur].tag1)%mod;
tr[tr[cur].rs].tag1=(tr[tr[cur].rs].tag1+tr[cur].tag1)%mod;
tr[cur].tag1=0;
}
}
void pushdown2(int cur){
if(tr[cur].tag2!=1){
int mid=(tr[cur].lt+tr[cur].rt)>>1;
if(tr[cur].ls==0) tr[cur].ls=++tot,tr[tr[cur].ls].tag2=1,tr[tr[cur].ls].lt=tr[cur].lt,tr[tr[cur].ls].rt=mid;
if(tr[cur].rs==0) tr[cur].rs=++tot,tr[tr[cur].rs].tag2=1,tr[tr[cur].rs].lt=mid+1,tr[tr[cur].rs].rt=tr[cur].rt;
tr[tr[cur].ls].sum=(tr[tr[cur].ls].sum*tr[cur].tag2)%mod;
tr[tr[cur].rs].sum=(tr[tr[cur].rs].sum*tr[cur].tag2)%mod;
tr[tr[cur].ls].tag1=(tr[tr[cur].ls].tag1*tr[cur].tag2)%mod;
tr[tr[cur].rs].tag1=(tr[tr[cur].rs].tag1*tr[cur].tag2)%mod;
tr[tr[cur].ls].tag2=(tr[tr[cur].ls].tag2*tr[cur].tag2)%mod;
tr[tr[cur].rs].tag2=(tr[tr[cur].rs].tag2*tr[cur].tag2)%mod;
tr[cur].tag2=1;
}
}
void pushdown(int cur){
pushdown2(cur);
pushdown1(cur);
}
void update1(int &cur,int lt,int rt,int l,int r,int v){
if(lt>r||rt<l) return ;
if(cur==0){
cur=++tot;
tr[cur].lt=lt,tr[cur].rt=rt,tr[cur].tag2=1,tr[cur].tag1=0;
}
if(l<=lt&&rt<=r){
tr[cur].sum+=((v%mod)*(rt-lt+1)%mod);
tr[cur].sum%=mod;
tr[cur].tag1+=v;
tr[cur].tag1%=mod;
return ;
}
pushdown(cur);
int mid=(lt+rt)>>1;
update1(tr[cur].ls,lt,mid,l,r,v);
update1(tr[cur].rs,mid+1,rt,l,r,v);
pushup(cur);
}
void update2(int &cur,int lt,int rt,int l,int r,int v){
if(lt>r||rt<l) return ;
if(cur==0){
cur=++tot;
tr[cur].lt=lt,tr[cur].rt=rt,tr[cur].tag2=1,tr[cur].tag1=0;
}
if(l<=lt&&rt<=r){
tr[cur].sum*=v;
tr[cur].sum%=mod;
tr[cur].tag2*=v;
tr[cur].tag2%=mod;
tr[cur].tag1*=v;
tr[cur].tag1%=mod;
return ;
}
pushdown(cur);
int mid=(lt+rt)>>1;
update2(tr[cur].ls,lt,mid,l,r,v);
update2(tr[cur].rs,mid+1,rt,l,r,v);
pushup(cur);
}
int a[maxn],n;
signed main(){
cin>>n;
for(int i=1;i<=n;i++) cin>>a[i];
update1(rt,1,top,1,a[1],1);
for(int i=2;i<=n;i++){
int sum=tr[rt].sum%mod;
update2(rt,1,top,1,a[i],-1);
update1(rt,1,top,1,a[i],sum);
update2(rt,1,top,a[i]+1,top,0);
}
cout<<(tr[rt].sum+mod)%mod;
return 0;
}