题解 梦批糼

传送门

关于 \(O(n^6)\) 跑过了这件事

一个 \(O(n^5)\) 的做法是拆成 \(\binom{n+1}{2}\binom{m+1}{2}\) 个一维的情况考虑（枚举前两维各自的区间 \([l,r]\)；由于允许 \(l=r\)，每一维的区间数是 \(\binom{n+1}{2}\) 而不是 \(\binom{n}{2}\)）
正解咕了

点击查看代码
#include <bits/stdc++.h>
using namespace std;
// INF is not referenced anywhere in this file.
#define INF 0x3f3f3f3f
// Per-axis extent of the 3D grid arrays. The grids are indexed 1..n (1..m,
// 1..k), so the 1-indexed val/vis grids hold dimensions up to N-1.
#define N 65
#define ll long long
// Turns every `int` below into long long (at least 64-bit). This is why the
// entry point is declared as `signed main()` instead of `int main()`.
#define int long long

// Buffered reader: refills a 2 MiB buffer from stdin with fread and yields one
// byte per getchar(). Bytes come back as unsigned char values, so a 0xFF byte
// can never be confused with EOF and every non-EOF value is a valid argument
// to isdigit() (passing a negative char there is undefined behavior).
char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:(unsigned char)*p1++)
// Reads one decimal integer from the buffered input.
// Non-digit characters before the number are skipped, and every '-' among
// them flips the sign. If the input ends before any digit is found, returns 0
// instead of spinning forever on EOF. The character that terminates the
// number is consumed.
inline int read() {
	int ans=0, f=1, c=getchar();
	while (c!=EOF && !isdigit(c)) {if (c=='-') f=-f; c=getchar();}
	while (isdigit(c)) {ans=ans*10+(c-'0'); c=getchar();}
	return ans*f;
}

// Grid extents along the three axes (the grids are filled 1-indexed, so each
// must be below N), and w, which the solvers use as the exponent in qpow.
int n;
int m;
int k;
int w;
// Per-cell value and per-cell usability flag, both indexed [1..n][1..m][1..k].
ll val[N][N][N];
bool vis[N][N][N];
const long long mod=998244353;
// Returns a^b modulo `mod` by binary exponentiation; b <= 0 yields 1.
// The base is first reduced into [0, mod). Without that, a*a overflows once
// |a| exceeds about 3.04e9, and a negative base gives a negative result.
// The loop runs while b > 0 so a negative exponent cannot loop forever
// (b >>= 1 on -1 stays -1).
inline long long qpow(long long a, long long b) {
	a%=mod;
	if (a<0) a+=mod;
	long long ans=1;
	for (; b>0; b>>=1) {
		if (b&1) ans=ans*a%mod;
		a=a*a%mod;
	}
	return ans;
}

namespace force{
	ll ans;
	int cnt[N][N][N];
	bool check(int a, int b, int c, int x, int y, int z) {
		for (int i=a; i<=x; ++i)
			for (int j=b; j<=y; ++j)
				for (int k=c; k<=z; ++k)
					if (!vis[i][j][k]) return 0;
		return 1;
	}
	void add(int a, int b, int c, int x, int y, int z) {
		for (int i=a; i<=x; ++i)
			for (int j=b; j<=y; ++j)
				for (int k=c; k<=z; ++k)
					++cnt[i][j][k];
	}
	void solve() {
		ll able=0;
		ll tot=1ll*n*(n+1)*m*(m+1)*k*(k+1)%mod*qpow(8, mod-2)%mod, inv_tot=qpow(tot, mod-2);
		for (int a=1; a<=n; ++a)
			for (int b=1; b<=m; ++b)
				for (int c=1; c<=k; ++c)
					for (int x=a; x<=n; ++x)
						for (int y=b; y<=m; ++y)
							for (int z=c; z<=k; ++z)
								if (check(a, b, c, x, y, z))
									add(a, b, c, x, y, z), ++able;
		for (int a=1; a<=n; ++a)
			for (int b=1; b<=m; ++b)
				for (int c=1; c<=k; ++c)
					if (vis[a][b][c])
						// ans=(ans+ val[a][b][c]*(1-qpow((1-cnt[a][b][c]*inv_tot)%mod, w)) )%mod; //, cout<<a<<' '<<b<<' '<<c<<' '<<val[a][b][c]<<endl;
						// ans=(ans+ val[a][b][c]*(qpow(able*inv_tot%mod, w)-qpow((tot-cnt[a][b][c])*inv_tot%mod, w)))%mod;
						ans=(ans+ val[a][b][c] * (qpow(able*inv_tot%mod, w)-qpow((able-cnt[a][b][c])*inv_tot%mod, w)) )%mod;
		cout<<(ans%mod+mod)%mod<<endl;
	}
}

// namespace guess_meaning_of_the_problem{
// 	ll ans;
// 	int cnt[N][N][N];
// 	bool check(int a, int b, int c, int x, int y, int z) {
// 		for (int i=a; i<=x; ++i)
// 			for (int j=b; j<=y; ++j)
// 				for (int k=c; k<=z; ++k)
// 					if (!vis[i][j][k]) return 0;
// 		return 1;
// 	}
// 	void add(int a, int b, int c, int x, int y, int z) {
// 		for (int i=a; i<=x; ++i)
// 			for (int j=b; j<=y; ++j)
// 				for (int k=c; k<=z; ++k)
// 					++cnt[i][j][k];
// 	}
// 	void solve() {
// 		ll able=0;
// 		ll tot=1ll*n*(n+1)*m*(m+1)*k*(k+1)%mod*qpow(8, mod-2)%mod, inv_tot=qpow(tot, mod-2);
// 		for (int a=1; a<=n; ++a)
// 			for (int b=1; b<=m; ++b)
// 				for (int c=1; c<=k; ++c)
// 					for (int x=a; x<=n; ++x)
// 						for (int y=b; y<=m; ++y)
// 							for (int z=c; z<=k; ++z)
// 								if (check(a, b, c, x, y, z))
// 									add(a, b, c, x, y, z), ++able;
// 		// cout<<tot<<' '<<able<<endl;
// 		for (int a=1; a<=n; ++a)
// 			for (int b=1; b<=m; ++b)
// 				for (int c=1; c<=k; ++c)
// 					if (vis[a][b][c])
// 						// ans=(ans+ val[a][b][c]*(1-qpow((1-cnt[a][b][c]*inv_tot)%mod, w)) )%mod; //, cout<<a<<' '<<b<<' '<<c<<' '<<val[a][b][c]<<endl;
// 						// ans=(ans+ val[a][b][c]*(qpow(able*inv_tot%mod, w)*(1-qpow((able-cnt[a][b][c])*qpow(able, mod-2)%mod, w)%mod)) )%mod;
// 						ans=(ans+ val[a][b][c] * (qpow(able*inv_tot%mod, w)-qpow((able-cnt[a][b][c])*inv_tot%mod, w)) )%mod;
// 		cout<<(ans%mod+mod)%mod<<endl;
// 	}
// }

// namespace enumerate_meaning_of_the_problem{
// 	ll ans, tot, inv_tot;
// 	int cnt[N][N][N];
// 	bool check(int a, int b, int c, int x, int y, int z) {
// 		for (int i=a; i<=x; ++i)
// 			for (int j=b; j<=y; ++j)
// 				for (int k=c; k<=z; ++k)
// 					if (!vis[i][j][k]) return 0;
// 		return 1;
// 	}
// 	void add(int a, int b, int c, int x, int y, int z, int dlt) {
// 		for (int i=a; i<=x; ++i)
// 			for (int j=b; j<=y; ++j)
// 				for (int k=c; k<=z; ++k)
// 					cnt[i][j][k]+=dlt;
// 	}
// 	void dfs(int now, ll pre) {
// 		if (now>w) {
// 			ll sum=0;
// 			for (int a=1; a<=n; ++a)
// 				for (int b=1; b<=m; ++b)
// 					for (int c=1; c<=k; ++c)
// 						if (vis[a][b][c] && cnt[a][b][c])
// 							sum=(sum+val[a][b][c])%mod;
// 			ans=(ans+sum*pre)%mod;
// 			return ;
// 		}
// 		for (int a=1; a<=n; ++a)
// 			for (int b=1; b<=m; ++b)
// 				for (int c=1; c<=k; ++c)
// 					for (int x=a; x<=n; ++x)
// 						for (int y=b; y<=m; ++y)
// 							for (int z=c; z<=k; ++z) {
// 								if (check(a, b, c, x, y, z)) add(a, b, c, x, y, z, 1);
// 								dfs(now+1, pre*inv_tot%mod);
// 								if (check(a, b, c, x, y, z)) add(a, b, c, x, y, z, -1);
// 							}
// 	}
// 	void solve() {
// 		tot=1ll*n*(n+1)*m*(m+1)*k*(k+1)%mod*qpow(8, mod-2)%mod, inv_tot=qpow(tot, mod-2);
// 		dfs(1, 1);
// 		cout<<(ans%mod+mod)%mod<<endl;
// 	}
// }

namespace task1{
	ll sum[N][N][N], ans;
	ll cnt[N][N][N]; int vis_cnt[N][N][N];
	// bool check(int a, int b, int c, int x, int y, int z) {
	// 	for (int i=a; i<=x; ++i)
	// 		for (int j=b; j<=y; ++j)
	// 			for (int k=c; k<=z; ++k)
	// 				if (!vis[i][j][k]) return 0;
	// 	return 1;
	// }
	inline void add(int a, int b, int c, int x, int y, int z) {
		// cout<<"add: "<<a<<' '<<b<<' '<<c<<' '<<x<<' '<<y<<' '<<z<<endl;
		++cnt[a][b][c];
		--cnt[a][y+1][c];
		--cnt[a][b][z+1];
		++cnt[a][y+1][z+1];
		--cnt[x+1][b][c];
		++cnt[x+1][b][z+1];
		++cnt[x+1][y+1][c];
		--cnt[x+1][y+1][z+1];
	}
	inline bool vis_query(int a, int b, int c, int x, int y, int z) {
		int cnt=0;
		cnt+=vis_cnt[x][y][z];
		cnt-=vis_cnt[x][b-1][z];
		cnt-=vis_cnt[x][y][c-1];
		cnt+=vis_cnt[x][b-1][c-1];
		cnt-=vis_cnt[a-1][y][z];
		cnt+=vis_cnt[a-1][y][c-1];
		cnt+=vis_cnt[a-1][b-1][z];
		cnt-=vis_cnt[a-1][b-1][c-1];
		return cnt==0;
	}
	void solve() {
		ll able=0;
		ll tot=1ll*n*(n+1)*m*(m+1)*k*(k+1)%mod*qpow(8, mod-2)%mod, inv_tot=qpow(tot, mod-2);
		for (int a=1; a<=n; ++a)
			for (int b=1; b<=m; ++b)
				for (int c=1; c<=k; ++c)
					if (!vis[a][b][c]) ++vis_cnt[a][b][c];
		for (int a=1; a<=n; ++a)
			for (int b=1; b<=m; ++b)
				for (int c=1; c<=k; ++c)
					vis_cnt[a][b][c]+=
						vis_cnt[a-1][b][c]+vis_cnt[a][b-1][c]+vis_cnt[a][b][c-1]
						-vis_cnt[a-1][b-1][c]-vis_cnt[a-1][b][c-1]-vis_cnt[a][b-1][c-1]
							+vis_cnt[a-1][b-1][c-1];
		for (int a=1; a<=n; ++a)
			for (int b=1; b<=m; ++b)
				for (int c=1; c<=k; ++c) if (vis[a][b][c])
					for (int x=a; x<=n; ++x)
						if (vis_query(a, b, c, x, b, c)) {
							for (int y=b; y<=m; ++y)
								if (vis_query(a, b, c, x, y, c)) {
									for (int z=c; z<=k; ++z)
										if (vis_query(a, b, c, x, y, z))
											add(a, b, c, x, y, z), ++able;
										else break;
								}
								else break;
						}
						else break;
		for (int a=1; a<=n; ++a)
			for (int b=1; b<=m; ++b)
				for (int c=1; c<=k; ++c)
					cnt[a][b][c]+=
						cnt[a-1][b][c]+cnt[a][b-1][c]+cnt[a][b][c-1]
						-cnt[a-1][b-1][c]-cnt[a-1][b][c-1]-cnt[a][b-1][c-1]
							+cnt[a-1][b-1][c-1];
		// cout<<"---cnt---"<<endl;
		// for (int a=1; a<=n; ++a)
		// 	for (int b=1; b<=m; ++b)
		// 		for (int c=1; c<=k; ++c)
		// 			printf("cnt[%lld][%lld][%lld]=%lld\n", a, b, c, cnt[a][b][c]);
		for (int a=1; a<=n; ++a)
			for (int b=1; b<=m; ++b)
				for (int c=1; c<=k; ++c)
					if (vis[a][b][c])
						ans=(ans+ val[a][b][c] * (qpow(able*inv_tot%mod, w)-qpow((able-cnt[a][b][c])*inv_tot%mod, w)) )%mod;
		cout<<(ans%mod+mod)%mod<<endl;
	}
}

signed main()
{
	// Input and output go through these fixed file names (presumably required
	// by the judge).
	freopen("dream.in", "r", stdin);
	freopen("dream.out", "w", stdout);

	// Header: the grid extents n, m, k, followed by w.
	n=read(), m=read(), k=read(), w=read();

	// Two n*m*k grids follow, each listed with i outermost and l innermost.
	// Pass 0 fills the vis flags; pass 1 fills the cell values.
	for (int pass=0; pass<2; ++pass)
		for (int i=1; i<=n; ++i)
			for (int j=1; j<=m; ++j)
				for (int l=1; l<=k; ++l) {
					int x=read();
					if (pass==0) vis[i][j][l]=x;
					else val[i][j][l]=x;
				}

	// force::solve() is the slower brute-force alternative.
	task1::solve();

	return 0;
}
posted @ 2022-03-15 19:44  Administrator-09  阅读(1)  评论(0)  编辑  收藏  举报