P5664 Emiya 家今天的饭
题面
前言
去年把我做自闭的一道题,看了一眼题面,发现只有 t1 有点思路,结果写到一半发现自己读错题了,又只能花时间来重构,结果后面的暴力一点都没写(主要是自己当时不会)
然后,这道题还因为某种原因爆玲了,因此我就成了全机房最菜的人。
题解
这道题题面还是很长的,所以我们简化一下题意。
给你一个 n*m 的矩阵,要求你从每一行选一个数,这一行可以选也可以不选,但最后至少选一个,且选的最多的那一列不能超过选的总数的 \(1 \over 2\)
part 1 24 - 32 分
直接爆搜出结果,加上一些剪枝可以拿到一部分了。
part 2 84 分
容斥加 dp。
没有第三个限制我们其实很好求,但带上第三个却有些麻烦。
我们考虑容斥一下,用总的方案数减去不合法的方案数,就是最后答案。
总的方案数就是 \((\displaystyle\prod_{i=1}^{n}\sum_{j=1}^{m} a[i][j]+1)\) -1
解释一下,每一行可以分开来考虑,乘法计数原理,对于这一行可以用加法计数原理,也就是这一行所有的树相加在加一,加一是因为要算上这一行不选的情况。
最后在减一,除去所有行都不选的情况。
对于不合法的方案数,可以考虑是哪一行选多了不合法,枚举每一行不合法的方案数,最后再总和就是不合法的方案数。
我们可以考虑用 dp 来解决这个问题。
设 \(f[i][j][k]\) 表示前 \(i\) 行,选了 \(j\) 列,且现在枚举的这一列选了 \(j\) 个的方案数。
转移就是 f[i][j][k] = f[i-1][j][k] (不选的时候) + f[i-1][j-1][k-1] * a[i][u] (选这一列的时候) + f[i-1][j-1][k] * (sum[i]-a[i][u])(选其他列的时候)
最后在减去不合法的方案数就是最后答案。
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
#define int long long
const int p = 998244353;
int n,m,tot,a[110][2010],sum[110],f[110][110][110];
inline int read()
{
int s = 0, w = 1; char ch = getchar();
while(ch < '0' || ch > '9'){if(ch == '-') w = -1; ch = getchar();}
while(ch >= '0' && ch <= '9'){s =s * 10+ch - '0'; ch = getchar();}
return s * w;
}
void calc(int id)//计算不合法的情况
{
memset(f,0,sizeof(f));
f[0][0][0] = 1;
for(int i = 1; i <= n; i++)
{
f[i][0][0] = 1;
for(int j = 1; j <= i; j++)
{
for(int k = 0; k <= j; k++)
{
f[i][j][k] = (f[i-1][j][k] + f[i-1][j-1][k-1] * a[i][id] % p) % p;
f[i][j][k] = (f[i][j][k] + (f[i-1][j-1][k] * (sum[i]-a[i][id]) % p)) % p;
}
}
}
// for(int i = 1; i <= n; i++) for(int j = 1; j <= i; j++) cout<<f[n][i][j]<<endl;
}
signed main()
{
n = read(); m = read(); tot = 1;
for(int i = 1; i <= n; i++)
{
for(int j = 1; j <= m; j++)
{
a[i][j] = read();
sum[i] = (sum[i] + a[i][j])%p;
}
tot = (tot * (sum[i]+1))%p;
}
tot -= 1;//总方案数
for(int i = 1; i <= m; i++)
{
calc(i);
for(int j = 1; j <= n; j++)
{
for(int k = j/2+1; k <= j; k++)//枚举选了多少行,以及不合法的情况
{
tot = (tot - f[n][j][k])%p;
}
}
}
printf("%lld\n",(tot%p+p)%p);
return 0;
}
part 3 100 分
你会发现上面会跑的很慢,因为他的复杂度是 O(\(nm^3\)) 的,我们只能想办法优化掉一个 \(m\)
然后,这就是本题最关键也是最巧妙的地方,我们不用管心 \(j\) 与 \(k\) 到底具体选了多少个,
而是关心 \(k > {j \over 2}\) 即 \(2 \times k < j\) ,所以我们可以把第二维和第三维合并成一维,用 \(2 \times k -j\) 的差值来表示(你也可以用 \(k - j\) 来表示)。
转移就是 f[i][j] = f[i-1][j](不选的话差值不变) + f[i-1][j-1] * a[i][k] (选这一列的时候) + f[i-1][j+1] * (sim[i]-a[i][k]) (选其他列的情况)
由于他可能会出现负数,所以我们整体平移 \(n\) 表示 \(2 \times k - j + n\) 的差值就避免了 RE 的问题。
Code
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
#define int long long
const int p = 998244353;
int n,m,tot,a[110][2010],sum[110],f[110][220];
inline int read()
{
int s = 0, w = 1; char ch = getchar();
while(ch < '0' || ch > '9'){if(ch == '-') w = -1; ch = getchar();}
while(ch >= '0' && ch <= '9'){s =s * 10+ch - '0'; ch = getchar();}
return s * w;
}
void calc(int id)
{
memset(f,0,sizeof(f));
f[0][n] = 1;
for(int i = 1; i <= n; i++)
{
f[i][n] = 1;
for(int j = 0; j <= 2 * n; j++)
{
f[i][j] = (f[i-1][j] + f[i-1][j-1] * a[i][id] % p) % p;
f[i][j] = (f[i][j] + (f[i-1][j+1] * (sum[i]-a[i][id]) % p)) % p;
}
}
// for(int i = 1; i <= n; i++) for(int j = 1; j <= i; j++) cout<<f[n][i][j]<<endl;
}
signed main()
{
n = read(); m = read(); tot = 1;
for(int i = 1; i <= n; i++)
{
for(int j = 1; j <= m; j++)
{
a[i][j] = read();
sum[i] = (sum[i] + a[i][j])%p;
}
tot = (tot * (sum[i]+1))%p;
}
tot -= 1;
for(int i = 1; i <= m; i++)
{
calc(i);
for(int j = 1; j <= n; j++)
{
tot = (tot - f[n][n+j])%p;
}
}
printf("%lld\n",(tot%p+p)%p);
return 0;
}