题解 張士超你昨天晚上到底把我家鑰匙放在哪了

传送门

不是自己想出来的 DP 写着就是炸就完了

首先 \(\sum a_i\) 较小的时候可以按题意暴力 DP

然后正解大概可以有这样一个引入:
我们先尝试枚举分配钥匙的方案,再枚举钦定被 find 的位置
直接分配钥匙会有不合法的,需要加一步容斥,钦定位置放 \(a_i+1\)
于是有一个 \(O(4^n)\) 的状压
答案为 \(\sum\limits_s(-1)^{|s|}\prod(p_i\ or\ (1-p_i))\) × 分配剩下钥匙的方案数
然后优化一下,发现可以 DP,将 \((-1)^{|s|}\)\(\prod\) 都算到权值里
康康为了算分配剩下钥匙的方案数我们都需要什么
需要钦定放多的且被 find 的钥匙数,钦定放多且未被 find 的钥匙数和钦定被 find 的位置数
于是 DP 定义和转移都出来了
接下来算分配剩下钥匙的方案数
剩下的钥匙中有 \(n_1\) 把是要被 find 的,\(n_2\) 把无所谓
\(m_1\) 个位置是被 find 了的,有 \(m_2\) 个位置没有
问题等价于将 \(n_1+n_2\) 个球放到 \(m_1+m_2\) 个位置使前 \(m_1\) 个位置至少有 \(n_1\) 个球
显然有一个 \(O(n1+n2)\) 的做法,但我们要让复杂度只和 \(m\) 相关
有一个比较神仙的思路是枚举第 \(n_1\) 个数被放在哪了
就有

\[\sum\limits_{k=1}^{m_1}\binom{n_1-1+k-1}{k-1}\binom{n_2+m-k}{m-k} \]

也可以通过代数推导但超出我能力范围了
最终的复杂度是 \(O(n^4)\)

点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define ll long long
#define int long long

char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
inline int read() {
	int ans=0, f=1; char c=getchar();
	while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
	while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
	return ans*f;
}

int m, d, N, n;
int a[110], p[110];
const ll mod=998244353;

namespace task1{
	ll f[110][110][110];
	void solve() {
		f[0][0][0]=1;
		for (int i=1; i<=m; ++i) {
			for (int j=0; j<=N; ++j) {
				for (int k=0; k<=j; ++k) {
					for (int l=0; l<=min(a[i], j); ++l) {
						if (l<=k) f[i][j][k]+=p[i]*f[i-1][j-l][k-l];
						f[i][j][k]+=(1-p[i])*f[i-1][j-l][k];
						f[i][j][k]%=mod;
					}
				}
			}
		}
		ll ans=0;
		for (int i=n; i<=N; ++i) ans=(ans+f[m][N][i])%mod;
		printf("%lld\n", (ans%mod+mod)%mod);
	}
}

namespace task{
	int f[2][110][110][110], inv[110], sta[110], top, ans;
	int C(int n, int k) {
		//cout<<"C: "<<n<<' '<<k<<endl;
		//if (n<k) return 0;
		//if (n==k) return 1;
		int ans=inv[k];
		for (int i=n; i>n-k; --i) ans=ans*i%mod;
		return ans;
	}
	void solve() {
		f[0][0][0][0]=inv[0]=inv[1]=1;
		for (int i=2; i<=m+10; ++i) inv[i]=(mod-mod/i)*inv[mod%i]%mod;
		for (int i=2; i<=m+10; ++i) inv[i]=inv[i-1]*inv[i]%mod;
		int lim1=N/d+1, lim2=N/d+1;
		for (int t=1; t<=m; ++t) {
			for (int i=0; i<=lim1; ++i)
				for (int j=0; j<=lim2; ++j)
					for (int k=0; k<=t; ++k) {
						int now=t&1;
						f[now][i][j][k]=0;
						if (i>=(a[t]+1)/d&&k) f[now][i][j][k]=(f[now][i][j][k]+f[now^1][i-(a[t]+1)/d][j][k-1]*-1*p[t])%mod;
						if (j>=(a[t]+1)/d) f[now][i][j][k]=(f[now][i][j][k]+f[now^1][i][j-(a[t]+1)/d][k]*-1*(1-p[t]))%mod;
						f[now][i][j][k]=(f[now][i][j][k]+f[now^1][i][j][k-1]*p[t])%mod;
						f[now][i][j][k]=(f[now][i][j][k]+f[now^1][i][j][k]*(1-p[t]))%mod;
						#ifdef DEBUG
						printf("f[%lld][%lld][%lld][%lld]=%lld\n", t, i, j, k, f[now][i][j][k]); //, cout<<f[now^1][i-(a[t]+1)/d][j][k-1]<<endl;
						#endif
					}
		}
		int now=m&1;
		for (int i=0; i<=lim1; ++i) {
			for (int j=0; j<=lim2; ++j) {
				for (int k=1; k<=m; ++k) {
					int n1=n-i*d, n2=N-n-j*d, tem=0;
					//if (n1<0||n2<0) continue;
					if (n1<0) {n2+=n1; if (n2<0) continue; tem=C(n2+m-1, m-1);}
					else {
						sta[top=1]=1;
						for (int t=n2+m-k; t>n2; --t) sta[1]=sta[1]*t%mod;
						for (int t=m-k+1; t<=m-1; ++t) ++top, sta[top]=sta[top-1]*(n2+t)%mod;
						ll c1=1;
						for (int t=1; t<=k; ++t) {
							if (t>1) c1=c1*(n1-1+t-1)%mod;
							//cout<<"my: "<<c1*inv[t-1]%mod<<' '<<sta[top]*inv[m-t]%mod<<endl;
				 			//cout<<"std: "<<C(n1-1+t-1, t-1)<<' '<<C(n2+m-t, m-t)<<endl;
							tem=(tem+c1*inv[t-1]%mod*sta[top--]%mod*inv[m-t])%mod;
						}
						//for (int t=1; t<=k; ++t) tem=(tem+C(n1-1+t-1, t-1)*C(n2+m-t, m-t))%mod;
					}
					if (n2<0) continue;
					if (f[now][i][j][k]) {
						#ifdef DEBUG
						printf("n1=%lld, n2=%lld, m1=%lld, tem=%lld\n", n1, n2, k, tem);
						printf("f[%lld][%lld][%lld][%lld]=%lld\n", m, i, j, k, f[now][i][j][k]);
						#endif
					}
					ans=(ans+f[now][i][j][k]*tem)%mod;
				}
			}
		}
		printf("%lld\n", (ans%mod+mod)%mod);
	}
}

signed main()
{
	freopen("key.in", "r", stdin);
	freopen("key.out", "w", stdout);

	m=read(); d=read(); N=read(); n=read();
	for (int i=1; i<=m; ++i) a[i]=read(), p[i]=read();
	// task1::solve();
	task::solve();
	
	return 0;
}
posted @ 2022-02-19 16:34  Administrator-09  阅读(3)  评论(0编辑  收藏  举报