题解 过山车

传送门

部分分可以插头 DP

考虑 \(w_{i, j}=0\) 就是判断是否存在合法方案

  • 网格图上的问题都试着黑白染色一下

发现黑白染色后因为每个点度数为 2,就转化为了二分图匹配,网络流即可

正解可以扔链接嘛……

  • 凸费用流的建边
点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 17000
#define fir first
#define sec second
#define pb push_back
#define ll long long
// #define int long long

char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
inline int read() {
	int ans=0, f=1; char c=getchar();
	while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
	while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
	return ans*f;
}

int n, m;
int mp[155][35], w[155][35];
// unordered_map<int, int> mp2;
// int bit[N], sta[2][N], cnt[2], val[2][N], now;
// inline void chkmax(int& a, int b) {a=max(a, b);}
// struct hash_map{
// 	static const int SIZE=1000000;
// 	int head[SIZE], rub[N], ecnt, rtop;
// 	struct edge{int val, next, dat;}e[N];
// 	hash_map(){memset(head, -1, sizeof(head));}
// 	inline int end() {return -1;}
// 	inline int find(int t) {
// 		int h=t*13131%SIZE;
// 		for (int i=head[h]; ~i; i=e[i].next)
// 			if (e[i].val==t) return 1;
// 		return -1;
// 	}
// 	inline int& operator [] (int t) {
// 		int h=t*13131%SIZE;
// 		for (int i=head[h]; ~i; i=e[i].next)
// 			if (e[i].val==t) return e[i].dat;
// 		e[++ecnt]={t, head[h], 0}; head[h]=ecnt;
// 		rub[++rtop]=h;
// 		return e[ecnt].dat;
// 	}
// 	inline void clear() {
// 		ecnt=0;
// 		while (rtop) head[rub[rtop--]]=-1;
// 	}
// }mp2;

// namespace force{
// 	int type[10], a[N], tab[155][35], ans=-1;
// 	const int dlt[][2]={{0,1},{1,0},{0,-1},{-1,0}};
// 	vector<pair<int, int>> buc;
// 	void dfs(int u) {
// 		if (u==buc.size()) {
// 			int sum=0;
// 			for (int i=0; i<buc.size(); ++i) {
// 				int x=buc[i].fir, y=buc[i].sec;
// 				for (int j=0; j<4; ++j) if (type[a[i]]&(1<<j)) {
// 					int tx=x+dlt[j][0], ty=y+dlt[j][1];
// 					if (!(type[tab[tx][ty]]&(1<<((j+2)%4)))) return ;
// 				}
// 			}
// 			for (int i=0; i<buc.size(); ++i) if (a[i]>2) sum+=w[buc[i].fir][buc[i].sec];
// 			// cout<<"sum: "<<sum<<endl;
// 			// cout<<"a: "; for (int i=0; i<buc.size(); ++i) cout<<a[i]<<' '; cout<<endl;
// 			ans=max(ans, sum);
// 			return ;
// 		}
// 		for (int i=1; i<=6; ++i) {
// 			a[u]=i; tab[buc[u].fir][buc[u].sec]=i;
// 			dfs(u+1);
// 		}
// 	}
// 	void solve() {
// 		for (int i=1; i<=n; ++i) for (int j=1; j<=m; ++j) if (mp[i][j]) buc.pb({i, j});
// 		type[1] |= (1<<3) | (1<<1);
// 		type[2] |= (1<<0) | (1<<2);
// 		type[3] |= (1<<3) | (1<<0);
// 		type[4] |= (1<<3) | (1<<2);
// 		type[5] |= (1<<0) | (1<<1);
// 		type[6] |= (1<<2) | (1<<1);
// 		dfs(0);
// 		cout<<ans<<endl;
// 	}
// }

// namespace task1{
// 	void ins(int s, int t) {
// 		// cout<<"ins: "<<endl;
// 		if (cnt[now^1]>N-5) cerr<<"error"<<endl;
// 		if (mp2.find(s)==mp2.end()) sta[now^1][mp2[s]=++cnt[now^1]]=s, val[now^1][cnt[now^1]]=t;
// 		else chkmax(val[now^1][mp2[s]], t);
// 	}
// 	int ql(int s, int pos) {
// 		// cout<<"ql: "<<endl;
// 		int cnt=1, t;
// 		for (--pos; ; --pos) {
// 			t=(s>>bit[pos])&3;
// 			if (t==1) --cnt;
// 			else if (t==2) ++cnt;
// 			if (!cnt) return pos;
// 		}
// 		assert(0);
// 		return 0;
// 	}
// 	int qr(int s, int pos) {
// 		// cout<<"qr: "<<endl;
// 		int cnt=1, t;
// 		for (++pos; ; ++pos) {
// 			t=(s>>bit[pos])&3;
// 			if (t==1) ++cnt;
// 			else if (t==2) --cnt;
// 			if (!cnt) return pos;
// 		}
// 		assert(0);
// 		return 0;
// 	}
// 	void solve() {
// 		cnt[now]=1; sta[now][1]=0; val[now][1]=0;
// 		for (int i=1; i<=m; ++i) bit[i]=i<<1;
// 		for (int i=1; i<=n; ++i) {
// 			for (int j=1; j<=cnt[now]; ++j) sta[now][j]<<=2;
// 			for (int j=1; j<=m; ++j,now^=1) {
// 				// cout<<"ij: "<<i<<' '<<j<<endl;
// 				// cout<<"---sta---"<<endl;
// 				// for (int k=1; k<=cnt[now]; ++k) {
// 				// 	for (int l=0; l<=m; ++l) cout<<((sta[now][k]>>bit[l])&3); cout<<' '<<val[now][k]<<endl;
// 				// }
// 				// cout<<"---end---"<<endl;
// 				cnt[now^1]=0; mp2.clear();
// 				for (int k=1,pos; k<=cnt[now]; ++k) {
// 					// cout<<"k: "<<k<<endl;
// 					// for (int l=0; l<=m; ++l) cout<<((sta[now][k]>>bit[l])&3); cout<<' '<<val[now][k]<<endl;
// 					int l=(sta[now][k]>>bit[j-1])&3, u=(sta[now][k]>>bit[j])&3, s=sta[now][k], t=val[now][k]+w[i][j];
// 					if (!mp[i][j]) {
// 						if (!l&&!u) ins(s, val[now][k]);
// 					}
// 					else if (!l&&!u) {
// 						if (mp[i+1][j]&&mp[i][j+1]) ins(s+(1ll<<bit[j-1])+(2ll<<bit[j]), t);
// 					}
// 					else if (l&&!u) {
// 						if (mp[i+1][j]) ins(s, t);
// 						if (mp[i][j+1]) ins(s-(l<<bit[j-1])+(l<<bit[j]), val[now][k]);
// 					}
// 					else if (!l&&u) {
// 						if (mp[i][j+1]) ins(s, t);
// 						if (mp[i+1][j]) ins(s-(u<<bit[j])+(u<<bit[j-1]), val[now][k]);
// 					}
// 					else {
// 						if (l==1&&u==1) pos=qr(s, j), ins(s-(l<<bit[j-1])-(u<<bit[j])+((-2ll+1ll)<<bit[pos]), t);
// 						if (l==1&&u==2) ins(s-(l<<bit[j-1])-(u<<bit[j]), t);
// 						if (l==2&&u==1) ins(s-(l<<bit[j-1])-(u<<bit[j]), t);
// 						if (l==2&&u==2) pos=ql(s, j-1), ins(s-(l<<bit[j-1])-(u<<bit[j])+((-1ll+2ll)<<bit[pos]), t);
// 					}
// 				}
// 				// cout<<"out"<<endl;
// 			}
// 		}
// 		if (!cnt[now]) puts("-1");
// 		else {
// 			int ans=0;
// 			for (int i=1; i<=cnt[now]; ++i) ans=max(ans, val[now][i]);
// 			printf("%lld\n", ans);
// 		}
// 	}
// }

namespace task{
	bool vis[N];
	int dis[N], back[N], inc[N];
	int head[N], id[155][35][3], ecnt=1, s, t, tot, cnt, sum, ans1, ans2;
	const int dlt[][2]={{1,0},{-1,0},{0,1},{0,-1}};
	struct edge{int to, next, flw, cst;}e[N*5];
	inline void add(int s, int t, int f, int c) {e[++ecnt]={t, head[s], f, c}; head[s]=ecnt;}
	bool spfa(int s, int t) {
		// cout<<"spfa: "<<s<<' '<<t<<endl;
		memset(dis, 0x3f, sizeof(dis));
		memset(back, -1, sizeof(back));
		queue<int> q;
		dis[s]=0; inc[s]=INF;
		q.push(s);
		while (q.size()) {
			int u=q.front(); q.pop();
			vis[u]=0;
			for (int i=head[u],v; ~i; i=e[i].next) {
				v = e[i].to;
				if (e[i].flw && dis[u]+e[i].cst<dis[v]) {
					dis[v]=dis[u]+e[i].cst;
					back[v]=i; inc[v]=min(inc[u], e[i].flw);
					if (!vis[v]) q.push(v), vis[v]=1;
				}
			}
		}
		return ~back[t];
	}
	void solve() {
		s=++tot; t=++tot;
		memset(head, -1, sizeof(head));
		for (int i=1; i<=n; ++i) for (int j=1; j<=m; ++j) for (int k=0; k<3; ++k) id[i][j][k]=++tot;
		for (int i=1; i<=n; ++i) {
			for (int j=1; j<=m; ++j) if (mp[i][j]) {
				++cnt; sum+=w[i][j];
				if ((i+j)&1) {
					add(s, id[i][j][0], 2, 0), add(id[i][j][0], s, 0, 0);
					add(id[i][j][0], id[i][j][1], 1, w[i][j]), add(id[i][j][1], id[i][j][0], 0, -w[i][j]);
					add(id[i][j][0], id[i][j][1], 1, 0), add(id[i][j][1], id[i][j][0], 0, 0);
					add(id[i][j][0], id[i][j][2], 1, w[i][j]), add(id[i][j][2], id[i][j][0], 0, -w[i][j]);
					add(id[i][j][0], id[i][j][2], 1, 0), add(id[i][j][2], id[i][j][0], 0, 0);
				}
				else {
					add(id[i][j][0], t, 2, 0), add(t, id[i][j][0], 0, 0);
					add(id[i][j][1], id[i][j][0], 1, w[i][j]), add(id[i][j][0], id[i][j][1], 0, -w[i][j]);
					add(id[i][j][1], id[i][j][0], 1, 0), add(id[i][j][0], id[i][j][1], 0, 0);
					add(id[i][j][2], id[i][j][0], 1, w[i][j]), add(id[i][j][0], id[i][j][2], 0, -w[i][j]);
					add(id[i][j][2], id[i][j][0], 1, 0), add(id[i][j][0], id[i][j][2], 0, 0);
				}
			}
		}
		for (int i=1; i<=n; ++i) {
			for (int j=1; j<=m; ++j) if ((i+j)&1) {
				for (int k=0; k<2; ++k) {
					int x=i+dlt[k][0], y=j+dlt[k][1];
					if (mp[x][y]) add(id[i][j][1], id[x][y][1], 1, 0), add(id[x][y][1], id[i][j][1], 0, 0);
				}
				for (int k=2; k<4; ++k) {
					int x=i+dlt[k][0], y=j+dlt[k][1];
					if (mp[x][y]) add(id[i][j][2], id[x][y][2], 1, 0), add(id[x][y][2], id[i][j][2], 0, 0);
				}
			}
		}
		while (spfa(s, t)) {
			ans1+=inc[t];
			ans2+=dis[t]*inc[t];
			for (int u=t; u!=s; u=e[back[u]^1].to) {
				e[back[u]].flw-=inc[t];
				e[back[u]^1].flw+=inc[t];
			}
		}
		// cout<<ans1<<endl;
		if (ans1!=cnt) puts("-1");
		else printf("%d\n", sum-ans2);
	}
}

signed main()
{
	freopen("roller.in", "r", stdin);
	freopen("roller.out", "w", stdout);

	// cout<<double(sizeof(bit)*5+sizeof(mp2))/1000/1000<<endl; exit(0);
	n=read(); m=read();
	for (int i=1; i<=n; ++i) for (int j=1; j<=m; ++j) mp[i][j]=read()^1;
	for (int i=1; i<=n; ++i) for (int j=1; j<=m; ++j) w[i][j]=read();
	// force::solve();
	task::solve();
	
	return 0;
}
posted @ 2022-03-06 21:27  Administrator-09  阅读(2)  评论(0编辑  收藏  举报