【题解】[CSP-S2019] Emiya 家今天的饭
\(\text{Solution:}\)
又是一个经典题目……一直都不太会 肝了两天算是搞明白了
首先,观察到题目限制应该不难想到一个容斥。因为思考一下发现这两个限制同时满足难以表达在状态里,因为我们不能对每一道食材都记录它的出现次数。同时还有 至少 这个词汇。那么对谁容斥?
那应该就是对 不超过一半 的这个限制容斥了。现在让我们枚举一道食材,让它不满足这个限制,也就是说 强制出现次数大于 \(\frac{n}{2}\) 次。
那我们又可以发现,剩下的食材无论怎么选择都不会再出现第二个不合法的了!
那么最终的方案数得以轻松表示:用总方案数减掉每一道菜不合法的方案数。
现在的问题就转化为一个求不合法方案数的问题。
那么,设 \(f[i][j][k]\) 表示前 \(i\) 种方法,选择了 \(j\) 种,其中有 \(k\) 种是不合法食材做的菜的搭配方案数。
那么就会有一个简单的 \(dp:\)
其中,\(sum[i]\) 表示这一种方法所能做的菜的总数, \(a[i][pos]\) 表示用这种方法与 \(pos\) 食材所能做的菜的种类数。
于是上面的方程就分别对应了:不选,选择一个不合法食材,选择一个合法食材的三种转移。
初始值:\(f[i][0][0]=1,\) 转移的时候不要忘记转移 \(f[i][j][0]!!\)
#include<bits/stdc++.h>
using namespace std;
const int N=101;
int f[N][N][N],a[N][2001];
int sum[N],n,m;
const int mod=998244353;
inline int Add(int x,int y){return (x+y)%mod;}
inline int Mul(int x,int y){return 1ll*x*y%mod;}
inline int Max(int x,int y){return x>y?x:y;}
inline int Min(int x,int y){return x<y?x:y;}
inline int Dec(int x,int y){return (x-y+mod)%mod;}
int main(){
freopen("in.txt","r",stdin);
scanf("%d%d",&n,&m);
for(int i=1;i<=n;++i){
for(int j=1;j<=m;++j){
scanf("%d",&a[i][j]);
}
}
for(int i=1;i<=n;++i){
for(int j=1;j<=m;++j){
sum[i]=Add(sum[i],a[i][j]);
}
}
int All=1;
for(int i=1;i<=n;++i)All=Mul(All,sum[i]+1);
for(int pos=1;pos<=m;++pos){
memset(f,0,sizeof f);
f[0][0][0]=1;
for(int i=1;i<=n;++i){
f[i][0][0]=1;
for(int j=1;j<=n;++j){
f[i][j][0]=Add(f[i-1][j][0],Mul(f[i-1][j-1][0],sum[i]-a[i][pos]));////f[i][j][0]!
for(int k=1;k<=n;++k){
f[i][j][k]=Add(f[i][j][k],f[i-1][j][k]);
f[i][j][k]=Add(f[i][j][k],Mul(f[i-1][j-1][k-1],a[i][pos]));
f[i][j][k]=Add(f[i][j][k],Mul(f[i-1][j-1][k],sum[i]-a[i][pos]));
}
}
}
for(int i=0;i<=n;++i){
for(int j=(i/2)+1;j<=n;++j)
All=Dec(All,f[n][i][j]);
}
}
printf("%d\n",All-1);
return 0;
}
这样就有了 \(84pts\)
继续考虑,把 \(O(n^3m)\) 优化到 \(O(n^2m)\) 就完事了。
观察一下,我们能不能把 选择多少道菜 和 选了多少道不合法的菜 合并到一起?
仔细想想,如果我们知道在一种选择方案下,选择的 \(pos\) 菜品的数量比其他的更多,那么我们实际上无需知道一共选择了多少个,只需要知道这种情况的方案数然后减掉即可。
那么我们就可以考虑换一个状态:设 \(f[i][j]\) 表示前 \(i\) 种食材,选择了 \(pos\) 食材的食品数量与其他的食材食品数量之差为 \(j\) 的方案数。
注意这里的 \(j\) 可以是负数,所以要整体平移。
那么就可以同样地写出方程:
初始值也是一样的,而这里注意一下减的时候要从 \(pos\) 数目严格大于的时候开始算。
如果 \(j=0\) 实际是合法方案数。
这样少掉一层循环,复杂度就是 \(O(n^2m)\) 了。
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int N=501;
int f[N][N],a[N][5001];
int sum[N],n,m;
const int mod=998244353;
inline int Add(int x,int y){return (x+y)%mod;}
inline int Mul(int x,int y){return 1ll*x*y%mod;}
inline int Max(int x,int y){return x>y?x:y;}
inline int Min(int x,int y){return x<y?x:y;}
inline int Dec(int x,int y){return (x-y+mod)%mod;}
inline int getpos(int x){return (x+n+1);}
void print(int P){
printf("%d:\n",P);
for(int i=0;i<=n;++i)
for(int j=-n;j<=n;++j)
printf("%d%c",f[i][getpos(j)],j==n?'\n':' ');
puts("");
}
signed main(){
freopen("meal.in","r",stdin);
freopen("meal.out","w",stdout);
scanf("%d%d",&n,&m);
for(int i=1;i<=n;++i){
for(int j=1;j<=m;++j){
scanf("%d",&a[i][j]);
}
}
for(int i=1;i<=n;++i){
for(int j=1;j<=m;++j){
sum[i]=Add(sum[i],a[i][j]);
}
}
int All=1;
for(int i=1;i<=n;++i)All=Mul(All,sum[i]+1);
for(int pos=1;pos<=m;++pos){
memset(f,0,sizeof f);
f[0][n+1]=1;
for(int i=1;i<=n;++i){
for(int j=-n;j<=n;++j){
int poss=getpos(j);
f[i][poss]=Add(f[i-1][poss],f[i][poss]);
f[i][poss]=Add(f[i][poss],Mul(f[i-1][poss-1],a[i][pos]));
f[i][poss]=Add(f[i][poss],Mul(f[i-1][poss+1],(sum[i]-a[i][pos])));
}
}
for(int i=1;i<=n;++i)All=Dec(All,f[n][getpos(i)]);
}
printf("%lld\n",All-1);
return 0;
}