[atcoder utpc2023_p] Priority Queue 3

Priority Queue 3

题意:有一个小根堆和 \(1\) ~ \(n\) 个数,以及一个操作序列,+ 表示 \(push\), - 表示 \(pop\)\(pop\)\(m\) 次,问你有多少种插入顺序使得最后的 pop 集合与给出的的数字集合 \(Y\) 相同。

首先有个浅显的发现:对于不在 \(Y\) 集合中的数,可选范围形如一个阶梯,换句话说,就是可选范围为 \([l_i,n]\)\(\forall_{i < m},l_i > l_{i+1}\)

设集合 \(Y\) 从小到大第 \(i\) 个元素为 \(Y_i\)

所以有 \(dp_{i,j,k,t \in0,1}\) 表示现在在操作序列的第 \(i\) 个符号处,现在元素的可选范围为 \([Y_j,n]\),现在堆里有 \(k\) 个元素是在 \(Y\) 集合中的,\(0/1\) 表示 \(Y_j\) 是否被加入进了堆中。

首先我们考虑 + 操作的 dp 转移。

  • 若加入进去的数不在 \(Y\) 集合中 :\(val(i,j,k) \times dp_{i,j,k,t} \to dp_{i+1,j,k,t}\)\(val(i,j,k)\) 表示的是系数)

  • 若加入进去的数在 \(Y\) 集合中

    • 加入的数为 \(Y_j\)\(dp_{i,j,k,0} \to dp_{i+1,j,k+1,1}\)
    • 加入的数不为 \(Y_j\)\(val^{'}(i,j,k) \times dp_{i,j,k,t} \to dp_{i+1,j,k+1,t}\)

对于 \(val(i,j,k)\),我们发现这是好算的,若前面已经 \(push\)\(x\) 次,\(pop\)\(y\) 次(下面的 \(x\)\(y\) 是相同意义),那么 \(val(i,j,k) = (n - V_j) - (m - j) - (x - y - k)\)

但是我们发现 \(val^{'}(i,j,k)\) 就并不是那么好求了,这时候就有一种很妙的方法:就是在插入的时候不去考虑插入了哪个数,在 \(pop\) 时再去考虑顺序吗,也就是说 \(val^{'}(i,j,k)\) 不需要在 push 时不考虑。

然后是 - 操作的 dp 转移。

  • \(k>1\)\(k=1,t=0\),这时候 \(j\) 不会改变 :\(dp_{i,j,k,t} \to dp{i+1,j,k-1,t}\)

  • \(k=1,t=1\),这是关键转移,我们枚举在把 \(V_j\) 踢掉之后,下一个限制范围为 \(p\),则有转移:

\(val^{''}(i,j,1) \times dp_{i,j,1,1} \to dp_{i+1,p,0,0}\)\(val^{''}(i,j,1) = \binom{y-(m-j)}{j-p-1} \times (j-p-1)!\)

解释一下 \(val^{''}(i,j,1)\) 的含义:因为如果要到 \(p\) 这一个位置上的话,就要保证 \([p+1,j-1]\) 的元素都被删除了,所以是这个形式。

点击查看代码
#include<bits/stdc++.h>
#define fir first
#define sec second
#define int long long
#define mkp(a,b) make_pair(a,b)
using namespace std;
typedef pair<int,int> pir;
inline int read(){
	int x=0,f=1; char c=getchar();
	while(!isdigit(c)){if(c=='-') f=-1; c=getchar();}
	while(isdigit(c)){x=x*10+(c^48); c=getchar();}
	return x*f;
}
const int mod=998244353,inf=1e18,N=305;
int n,m;
char s[N*2];
int a[N],dp[N][N][2],tmp[N][N][2],C[N][N],jie[N];
inline void init(){
	C[0][0]=1;
	for(int i=1;i<=m;i++){
		C[i][0]=1;
		for(int j=1;j<=i;j++)
		C[i][j]=(C[i-1][j]+C[i-1][j-1])%mod;
	}
	jie[0]=1; for(int i=1;i<=m;i++) jie[i]=jie[i-1]*i%mod; 
}
signed main(){
	freopen("heap.in","r",stdin);
	freopen("heap.out","w",stdout); 
	n=read(),m=read();
	scanf("%s",s);
	for(int i=1;i<=m;i++) a[i]=read(); 
	init(); sort(a+1,a+m+1);
	dp[m][0][0]=1;
	int pu=0,po=0;
	for(int i=0;i<n+m;i++){
		memcpy(tmp,dp,sizeof(dp));
		memset(dp,0,sizeof(dp));
		if(s[i]=='+'){
			for(int j=0;j<=m;j++) for(int k=0;k<=j;k++){
				(dp[j][k+1][0]+=tmp[j][k][0])%=mod;
				(dp[j][k+1][1]+=tmp[j][k][0])%=mod;
				(dp[j][k+1][1]+=tmp[j][k][1])%=mod;
				int val=n-a[j]-(m-j)-(pu-po-k);
				(dp[j][k][0]+=val*tmp[j][k][0])%=mod;
				(dp[j][k][1]+=val*tmp[j][k][1])%=mod;
			}
			pu++;
		}
		else{
			for(int j=1;j<=m;j++){
				for(int k=1;k<=j;k++){
					if(k!=1) (dp[j][k-1][1]+=tmp[j][k][1])%=mod;
					(dp[j][k-1][0]+=tmp[j][k][0])%=mod;
				}
				int t=po-(m-j);
				for(int p=j-t-1;p<j;p++)
				(dp[p][0][0]+=C[t][j-p-1]*jie[j-p-1]%mod*tmp[j][1][1])%=mod;
				
			}
			po++;
		}
//		for(int j=0;j<=m;j++) for(int k=0;k<=pu;k++) cout<<j<<' '<<k<<' '<<dp[j][k][0]<<' '<<dp[j][k][1]<<'\n';
//		puts("\n");
	}
	cout<<dp[0][0][0]<<'\n';
}
posted @ 2024-07-24 16:32  ~Cyan~  阅读(2)  评论(0编辑  收藏  举报