题目链接
https://atcoder.jp/contests/agc039/tasks/agc039_f
题解
又是很简单的F题我不会。。。
考虑先给每行每列钦定一个最小值\(a_i,b_j\),并假设每行每列的最小值是这个数,且每行每列只需要放\(\ge\)这个数的数即可,那么这种情况的价值是\(\prod^n_{i=1}\prod^m_{j=1}\min(a_i,b_j)\), 方案数是\(\prod^n_{i=1}\prod^m_{j=1}(n+1-\max(a_i,b_j))\)
然后我们需要把最小值的限制容斥掉,也就是枚举若干行若干列容斥掉(限制\(+1\)同时系数乘以\(-1\))。
这样的话直接暴力DP就可以解决。设\(f[k][i][j]\)表示当前用\([1,k]\)中的数填满了\(i\)行\(j\)列。转移可以直接枚举不被容斥的行数、不被容斥的列数、容斥的行数、容斥的列数,乘上贡献系数,得到了一个多项式时间复杂度的算法。
但是我们发现这样转移显然很浪费,我们可以把四个变量同时枚举改成分四个阶段依次枚举,这样转移时间复杂度降到了\(O(n)\).(注意因为要保证从小到大填数,所以必须先枚举不被容斥再枚举被容斥)
不过这题还挺卡常的……需要\(O(n^3)\)预处理一下转移系数,详见代码
时间复杂度\(O(n^4)\)
orz myh
代码
#include<bits/stdc++.h>
#define llong long long
#define mkpr make_pair
#define riterator reverse_iterator
using namespace std;
inline int read()
{
int x = 0,f = 1; char ch = getchar();
for(;!isdigit(ch);ch=getchar()) {if(ch=='-') f = -1;}
for(; isdigit(ch);ch=getchar()) {x = x*10+ch-48;}
return x*f;
}
const int N = 100;
int P;
llong pw[N+3][N*N+3];
llong comb[N+3][N+3];
llong f[2][N+3][N+3];
llong trans[N+3][N+3];
int n,m,p;
llong quickpow(llong x,llong y)
{
llong cur = x,ret = 1ll;
for(int i=0; y; i++)
{
if(y&(1ll<<i)) {y-=(1ll<<i); ret = ret*cur%P;}
cur = cur*cur%P;
}
return ret;
}
void initmath()
{
for(int i=0; i<=N; i++)
{
pw[i][0] = 1ll; for(int j=1; j<=N*N; j++) pw[i][j] = pw[i][j-1]*i%P;
}
comb[0][0] = 1ll;
for(int i=1; i<=N; i++)
{
comb[i][0] = comb[i][i] = 1ll;
for(int j=1; j<i; j++) comb[i][j] = (comb[i-1][j]+comb[i-1][j-1])%P;
}
}
llong updsum(llong &x,llong y) {x = x+y>=P?x+y-P:x+y;}
int main()
{
scanf("%d%d%d%lld",&n,&m,&p,&P);
initmath();
int curk = 0; f[0][0][0] = 1ll;
for(int k=1; k<=p; k++)
{
curk^=1; memset(f[curk],0,sizeof(f[curk]));
for(int j=0; j<=m; j++) for(int ii=0; ii<=n; ii++) trans[j][ii] = pw[k][ii*(m-j)]%P*pw[p-k+1][ii*j]%P;
for(int i=0; i<=n; i++)
{
for(int j=0; j<=m; j++)
{
llong x = f[curk^1][i][j]; if(!x) continue;
for(int ii=0; ii+i<=n; ii++)
{
updsum(f[curk][i+ii][j],x*comb[i+ii][i]%P*trans[j][ii]%P);
}
}
}
curk^=1; memset(f[curk],0,sizeof(f[curk]));
for(int i=0; i<=n; i++) for(int jj=0; jj<=m; jj++) trans[i][jj] = pw[k][jj*(n-i)]%P*pw[p-k+1][jj*i]%P;
for(int i=0; i<=n; i++)
{
for(int j=0; j<=m; j++)
{
llong x = f[curk^1][i][j]; if(!x) continue;
for(int jj=0; jj+j<=m; jj++)
{
updsum(f[curk][i][j+jj],x*comb[j+jj][j]%P*trans[i][jj]%P);
}
}
}
curk^=1; memset(f[curk],0,sizeof(f[curk]));
for(int j=0; j<=m; j++) for(int ii=0; ii<=n; ii++) trans[j][ii] = pw[k][ii*(m-j)]%P*pw[p-k][ii*j]%P;
for(int i=0; i<=n; i++)
{
for(int j=0; j<=m; j++)
{
llong x = f[curk^1][i][j]; if(!x) continue;
for(int ii=0; ii+i<=n; ii++)
{
llong y = x*comb[i+ii][i]%P*trans[j][ii]%P;
updsum(f[curk][i+ii][j],ii&1?P-y:y);
}
}
}
curk^=1; memset(f[curk],0,sizeof(f[curk]));
for(int i=0; i<=n; i++) for(int jj=0; jj<=m; jj++) trans[i][jj] = pw[k][jj*(n-i)]%P*pw[p-k][i*jj]%P;
for(int i=0; i<=n; i++)
{
for(int j=0; j<=m; j++)
{
llong x = f[curk^1][i][j]; if(!x) continue;
for(int jj=0; jj+j<=m; jj++)
{
llong y = x*comb[j+jj][j]%P*trans[i][jj]%P;
updsum(f[curk][i][j+jj],jj&1?P-y:y);
}
}
}
}
printf("%lld\n",f[curk][n][m]);
return 0;
}