题解 梦批糼
关于 \(O(n^6)\) 跑过了这件事
一个 \(O(n^5)\) 的做法是：枚举前两维的区间，拆成 \(\binom{n+1}{2}\binom{m+1}{2}\) 个一维的情况分别考虑
正解咕了
代码如下（\(O(n^6)\) 做法）：
#include <bits/stdc++.h>
using namespace std;
// Sentinel "infinity" constant; not referenced by the visible code below.
#define INF 0x3f3f3f3f
// Extent of every 3D grid array (valid indices 0..64).
#define N 65
#define ll long long
// From here on every `int` is really `long long` (which is why main is declared `signed`).
#define int long long
// 2 MiB input buffer; [p1, p2) is the not-yet-consumed part of the current chunk.
char buf[1<<21], *p1=buf, *p2=buf;
// Buffered replacement for getchar(): when the buffer is exhausted (p1 == p2) it
// refills it with a single fread from stdin; it then yields EOF if nothing was read,
// otherwise the next byte. `*p1++` is a plain char, so bytes >= 0x80 may come back
// negative where char is signed. Shadows the standard getchar for all code below.
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
// Read one signed integer from the (buffered) input stream.
// Leading non-digit characters are skipped; each '-' seen while skipping flips the sign.
// Returns 0 if EOF is reached before any digit. The previous version looped forever
// in that case, because the buffered getchar keeps returning EOF.
// Digits are tested with a range compare instead of isdigit(): passing a negative
// char (input byte >= 0x80) to isdigit is undefined behaviour.
// NOTE: `int` is macro-expanded to `long long` in this file, so values are 64-bit.
inline int read() {
    int ans = 0, f = 1;
    int c = getchar();  // kept as an integer so EOF is distinguishable from data
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') f = -f;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        ans = ans * 10 + (c - '0');
        c = getchar();
    }
    return ans * f;
}
// Grid extents: cells are indexed 1..n, 1..m, 1..k (read first in main).
int n, m, k;
// Exponent passed to qpow(..., w) by the solvers (presumably the number of random picks).
int w;
// Per-cell value; read from the input after the vis grid.
ll val[N][N][N];
// Per-cell flag; counted boxes may contain only cells where this is set.
bool vis[N][N][N];
// Prime modulus for all arithmetic in this file.
const long long mod = 998244353;

// Binary exponentiation: returns a^b mod `mod`, always in [0, mod).
// The base is first reduced into [0, mod), so any long long base is accepted without
// overflowing a*a and negative bases give a non-negative result.
// Precondition: b >= 0. A non-positive exponent yields 1; previously a negative
// exponent looped forever because b >>= 1 never reaches 0.
inline long long qpow(long long a, long long b) {
    a %= mod;
    if (a < 0) a += mod;
    long long ans = 1;
    for (; b > 0; b >>= 1) {
        if (b & 1) ans = ans * a % mod;
        a = a * a % mod;
    }
    return ans;
}
namespace force{
ll ans;
int cnt[N][N][N];
bool check(int a, int b, int c, int x, int y, int z) {
for (int i=a; i<=x; ++i)
for (int j=b; j<=y; ++j)
for (int k=c; k<=z; ++k)
if (!vis[i][j][k]) return 0;
return 1;
}
void add(int a, int b, int c, int x, int y, int z) {
for (int i=a; i<=x; ++i)
for (int j=b; j<=y; ++j)
for (int k=c; k<=z; ++k)
++cnt[i][j][k];
}
void solve() {
ll able=0;
ll tot=1ll*n*(n+1)*m*(m+1)*k*(k+1)%mod*qpow(8, mod-2)%mod, inv_tot=qpow(tot, mod-2);
for (int a=1; a<=n; ++a)
for (int b=1; b<=m; ++b)
for (int c=1; c<=k; ++c)
for (int x=a; x<=n; ++x)
for (int y=b; y<=m; ++y)
for (int z=c; z<=k; ++z)
if (check(a, b, c, x, y, z))
add(a, b, c, x, y, z), ++able;
for (int a=1; a<=n; ++a)
for (int b=1; b<=m; ++b)
for (int c=1; c<=k; ++c)
if (vis[a][b][c])
// ans=(ans+ val[a][b][c]*(1-qpow((1-cnt[a][b][c]*inv_tot)%mod, w)) )%mod; //, cout<<a<<' '<<b<<' '<<c<<' '<<val[a][b][c]<<endl;
// ans=(ans+ val[a][b][c]*(qpow(able*inv_tot%mod, w)-qpow((tot-cnt[a][b][c])*inv_tot%mod, w)))%mod;
ans=(ans+ val[a][b][c] * (qpow(able*inv_tot%mod, w)-qpow((able-cnt[a][b][c])*inv_tot%mod, w)) )%mod;
cout<<(ans%mod+mod)%mod<<endl;
}
}
// namespace guess_meaning_of_the_problem{
// ll ans;
// int cnt[N][N][N];
// bool check(int a, int b, int c, int x, int y, int z) {
// for (int i=a; i<=x; ++i)
// for (int j=b; j<=y; ++j)
// for (int k=c; k<=z; ++k)
// if (!vis[i][j][k]) return 0;
// return 1;
// }
// void add(int a, int b, int c, int x, int y, int z) {
// for (int i=a; i<=x; ++i)
// for (int j=b; j<=y; ++j)
// for (int k=c; k<=z; ++k)
// ++cnt[i][j][k];
// }
// void solve() {
// ll able=0;
// ll tot=1ll*n*(n+1)*m*(m+1)*k*(k+1)%mod*qpow(8, mod-2)%mod, inv_tot=qpow(tot, mod-2);
// for (int a=1; a<=n; ++a)
// for (int b=1; b<=m; ++b)
// for (int c=1; c<=k; ++c)
// for (int x=a; x<=n; ++x)
// for (int y=b; y<=m; ++y)
// for (int z=c; z<=k; ++z)
// if (check(a, b, c, x, y, z))
// add(a, b, c, x, y, z), ++able;
// // cout<<tot<<' '<<able<<endl;
// for (int a=1; a<=n; ++a)
// for (int b=1; b<=m; ++b)
// for (int c=1; c<=k; ++c)
// if (vis[a][b][c])
// // ans=(ans+ val[a][b][c]*(1-qpow((1-cnt[a][b][c]*inv_tot)%mod, w)) )%mod; //, cout<<a<<' '<<b<<' '<<c<<' '<<val[a][b][c]<<endl;
// // ans=(ans+ val[a][b][c]*(qpow(able*inv_tot%mod, w)*(1-qpow((able-cnt[a][b][c])*qpow(able, mod-2)%mod, w)%mod)) )%mod;
// ans=(ans+ val[a][b][c] * (qpow(able*inv_tot%mod, w)-qpow((able-cnt[a][b][c])*inv_tot%mod, w)) )%mod;
// cout<<(ans%mod+mod)%mod<<endl;
// }
// }
// namespace enumerate_meaning_of_the_problem{
// ll ans, tot, inv_tot;
// int cnt[N][N][N];
// bool check(int a, int b, int c, int x, int y, int z) {
// for (int i=a; i<=x; ++i)
// for (int j=b; j<=y; ++j)
// for (int k=c; k<=z; ++k)
// if (!vis[i][j][k]) return 0;
// return 1;
// }
// void add(int a, int b, int c, int x, int y, int z, int dlt) {
// for (int i=a; i<=x; ++i)
// for (int j=b; j<=y; ++j)
// for (int k=c; k<=z; ++k)
// cnt[i][j][k]+=dlt;
// }
// void dfs(int now, ll pre) {
// if (now>w) {
// ll sum=0;
// for (int a=1; a<=n; ++a)
// for (int b=1; b<=m; ++b)
// for (int c=1; c<=k; ++c)
// if (vis[a][b][c] && cnt[a][b][c])
// sum=(sum+val[a][b][c])%mod;
// ans=(ans+sum*pre)%mod;
// return ;
// }
// for (int a=1; a<=n; ++a)
// for (int b=1; b<=m; ++b)
// for (int c=1; c<=k; ++c)
// for (int x=a; x<=n; ++x)
// for (int y=b; y<=m; ++y)
// for (int z=c; z<=k; ++z) {
// if (check(a, b, c, x, y, z)) add(a, b, c, x, y, z, 1);
// dfs(now+1, pre*inv_tot%mod);
// if (check(a, b, c, x, y, z)) add(a, b, c, x, y, z, -1);
// }
// }
// void solve() {
// tot=1ll*n*(n+1)*m*(m+1)*k*(k+1)%mod*qpow(8, mod-2)%mod, inv_tot=qpow(tot, mod-2);
// dfs(1, 1);
// cout<<(ans%mod+mod)%mod<<endl;
// }
// }
namespace task1{
ll sum[N][N][N], ans;
ll cnt[N][N][N]; int vis_cnt[N][N][N];
// bool check(int a, int b, int c, int x, int y, int z) {
// for (int i=a; i<=x; ++i)
// for (int j=b; j<=y; ++j)
// for (int k=c; k<=z; ++k)
// if (!vis[i][j][k]) return 0;
// return 1;
// }
inline void add(int a, int b, int c, int x, int y, int z) {
// cout<<"add: "<<a<<' '<<b<<' '<<c<<' '<<x<<' '<<y<<' '<<z<<endl;
++cnt[a][b][c];
--cnt[a][y+1][c];
--cnt[a][b][z+1];
++cnt[a][y+1][z+1];
--cnt[x+1][b][c];
++cnt[x+1][b][z+1];
++cnt[x+1][y+1][c];
--cnt[x+1][y+1][z+1];
}
inline bool vis_query(int a, int b, int c, int x, int y, int z) {
int cnt=0;
cnt+=vis_cnt[x][y][z];
cnt-=vis_cnt[x][b-1][z];
cnt-=vis_cnt[x][y][c-1];
cnt+=vis_cnt[x][b-1][c-1];
cnt-=vis_cnt[a-1][y][z];
cnt+=vis_cnt[a-1][y][c-1];
cnt+=vis_cnt[a-1][b-1][z];
cnt-=vis_cnt[a-1][b-1][c-1];
return cnt==0;
}
void solve() {
ll able=0;
ll tot=1ll*n*(n+1)*m*(m+1)*k*(k+1)%mod*qpow(8, mod-2)%mod, inv_tot=qpow(tot, mod-2);
for (int a=1; a<=n; ++a)
for (int b=1; b<=m; ++b)
for (int c=1; c<=k; ++c)
if (!vis[a][b][c]) ++vis_cnt[a][b][c];
for (int a=1; a<=n; ++a)
for (int b=1; b<=m; ++b)
for (int c=1; c<=k; ++c)
vis_cnt[a][b][c]+=
vis_cnt[a-1][b][c]+vis_cnt[a][b-1][c]+vis_cnt[a][b][c-1]
-vis_cnt[a-1][b-1][c]-vis_cnt[a-1][b][c-1]-vis_cnt[a][b-1][c-1]
+vis_cnt[a-1][b-1][c-1];
for (int a=1; a<=n; ++a)
for (int b=1; b<=m; ++b)
for (int c=1; c<=k; ++c) if (vis[a][b][c])
for (int x=a; x<=n; ++x)
if (vis_query(a, b, c, x, b, c)) {
for (int y=b; y<=m; ++y)
if (vis_query(a, b, c, x, y, c)) {
for (int z=c; z<=k; ++z)
if (vis_query(a, b, c, x, y, z))
add(a, b, c, x, y, z), ++able;
else break;
}
else break;
}
else break;
for (int a=1; a<=n; ++a)
for (int b=1; b<=m; ++b)
for (int c=1; c<=k; ++c)
cnt[a][b][c]+=
cnt[a-1][b][c]+cnt[a][b-1][c]+cnt[a][b][c-1]
-cnt[a-1][b-1][c]-cnt[a-1][b][c-1]-cnt[a][b-1][c-1]
+cnt[a-1][b-1][c-1];
// cout<<"---cnt---"<<endl;
// for (int a=1; a<=n; ++a)
// for (int b=1; b<=m; ++b)
// for (int c=1; c<=k; ++c)
// printf("cnt[%lld][%lld][%lld]=%lld\n", a, b, c, cnt[a][b][c]);
for (int a=1; a<=n; ++a)
for (int b=1; b<=m; ++b)
for (int c=1; c<=k; ++c)
if (vis[a][b][c])
ans=(ans+ val[a][b][c] * (qpow(able*inv_tot%mod, w)-qpow((able-cnt[a][b][c])*inv_tot%mod, w)) )%mod;
cout<<(ans%mod+mod)%mod<<endl;
}
}
signed main()
{
freopen("dream.in", "r", stdin);
freopen("dream.out", "w", stdout);
n=read(); m=read(); k=read(); w=read();
for (int i=1; i<=n; ++i)
for (int j=1; j<=m; ++j)
for (int l=1; l<=k; ++l)
vis[i][j][l]=read();
for (int i=1; i<=n; ++i)
for (int j=1; j<=m; ++j)
for (int l=1; l<=k; ++l)
val[i][j][l]=read();
// force::solve();
task1::solve();
return 0;
}