P4229-某位歌姬的故事【dp】
正题
题目链接:https://www.luogu.com.cn/problem/P4229
题目大意
求有多少个长度为\(n\)的序列\(a\),满足\(\forall i\in[1,n],a_i\in[1,A]\),还有\(Q\)个限制形如
\[\max\{a_j\}(j\in[l_i,r_i])=m_i
\]
\(1\leq n,A\leq 9\times 10^8,1\leq m_i\leq A,1\leq Q\leq 500,1\leq T\leq 20\)
解题思路
首先我们第一步肯定是把每个区间的端点提出来离散化,这样我们的区间数就是\(O(Q)\)级别的了。
然后考虑到对于两个有交的区间\([l_1,r_1]\)限制为\(m_1\),\([l_2,r_2]\)限制为\(m_2\),在\(m_1<m_2\)时这两个区间交的那一部分显然不会对第二个区间产生影响,因为这个区间肯定合法并且不能是最大值。
那么我们考虑求出每个区间能够到达的最大值\(lim_i\),然后对一个所有的限制\([l,r,w]\),我们都只需要考虑\(lim_i=w\)的区间。
现在相当于对于每个单独的小区间我们可以选择上到最大值或者没有最大值。然后要求是每个区间至少有一个最大值。
考虑\(dp\),设\(f_{i,j}\)表示现在做到第\(i\)个区间,上一个顶到最大值的区间是\(j\)时的方案,因为我们只处理\(lim_i=w\)的区间,所以一个区间最多被做一次。
时间复杂度:\(O(TQ^2)\)
code
#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
const ll N=2100,P=998244353;
struct node{
ll l,r,w;
}q[N];
ll T,n,m,A,tot,cnt,b[N];
ll wc[N],bc[N],lim[N],len[N];
ll f[N][N],rim[N],loc[N],pos[N];
ll power(ll x,ll b){
ll ans=1;
while(b){
if(b&1)ans=ans*x%P;
x=x*x%P;b>>=1;
}
return ans;
}
bool cmp(node x,node y){
if(x.w!=y.w)return x.w<y.w;
if(x.l!=y.l)return x.l<y.l;
return x.r<y.r;
}
ll calc(ll w,ll L,ll R){
tot=0;
for(ll i=1;i<=cnt;i++){
if(lim[i]==w)
loc[++tot]=i,rim[tot]=0;
pos[i]=tot;
}
for(ll i=L;i<=R;i++){
q[i].r=pos[q[i].r];
if(loc[pos[q[i].l]]!=q[i].l)
q[i].l=pos[q[i].l]+1;
else q[i].l=pos[q[i].l];
rim[q[i].r]=max(rim[q[i].r],q[i].l);
}
ll r=0;f[0][0]=1;
for(ll i=1;i<=tot;i++){
for(ll j=0;j<=i;j++)f[i][j]=0;
for(ll j=rim[i];j<i;j++)
(f[i][j]+=f[i-1][j]*wc[loc[i]]%P)%=P;
for(ll j=0;j<i;j++)
(f[i][i]+=f[i-1][j]*bc[loc[i]]%P)%=P;
}
ll ans=0;
for(ll j=0;j<=tot;j++)
(ans+=f[tot][j])%=P;
return ans;
}
void solve(){
scanf("%lld%lld%lld",&n,&m,&A);
for(ll i=1;i<=m;i++){
scanf("%lld%lld%lld",&q[i].l,&q[i].r,&q[i].w);
q[i].r++;b[++cnt]=q[i].l;b[++cnt]=q[i].r;
}
b[++cnt]=1;b[++cnt]=n+1;
sort(b+1,b+1+cnt);
sort(q+1,q+1+m,cmp);
cnt=unique(b+1,b+1+cnt)-b-1;
for(ll i=1;i<=cnt;i++)lim[i]=A+1;
for(ll i=1;i<=m;i++){
q[i].l=lower_bound(b+1,b+1+cnt,q[i].l)-b;
q[i].r=lower_bound(b+1,b+1+cnt,q[i].r)-b-1;
bool flag=false;
for(ll j=q[i].l;j<=q[i].r;j++){
if(lim[j]>=q[i].w)flag=true;
lim[j]=min(lim[j],q[i].w);
}
if(!flag){puts("0");return;}
}
ll ans=1;
for(ll i=1;i<cnt;i++){
len[i]=b[i+1]-b[i];
if(lim[i]==A+1)ans=ans*power(A,len[i])%P;
wc[i]=power(lim[i]-1,len[i]);
bc[i]=(power(lim[i],len[i])-wc[i]+P)%P;
}
ll L,R=0;cnt--;
while(R<m){
L=R+1;R=L;
while(R<m&&q[R+1].w==q[L].w)R++;
ans=ans*calc(q[L].w,L,R)%P;
}
printf("%lld\n",ans);
return;
}
signed main()
{
scanf("%lld",&T);
while(T--){
cnt=0;solve();
}
return 0;
}