题解 鼠

传送门

  • 遇到求所有点对之间贡献之和的问题,优先考虑分治

注意到这样一个事情:
因为只有三行,所以如果在点 A 和点 B 之间(沿列方向)画一条竖直分界线 mid,
那么从 A 到 B 的最短路一定恰好跨过 mid 一次
那么分治时,可以用 Dijkstra 求出第 \(mid\) 列的三个点到 \([l, mid]\) 中各点的距离,以及第 \(mid+1\) 列的三个点到 \([mid+1, r]\) 中各点的距离
但常数更小的做法是:对每个分治区间,计算第 \(l\) 列的三个点和第 \(r\) 列的三个点到区间内所有点的距离
一个区间的这个东西可以从左右两个子区间合并上来
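然后统计跨过 mid 的点对:
(原文此处为配图。)设左半的点 \(P\)、右半的点 \(Q\),记 \(R_k(P)\) 为 \((k, mid)\) 到 \(P\) 的距离,\(L_k(Q)\) 为 \((k, mid+1)\) 到 \(Q\) 的距离,则 \(dist(P, Q)=\min_k \left(R_k(P)+L_k(Q)\right)\)。枚举取到最小值的行 \(k\)(平局取较小的行号),对另外两行 \(o\) 都要满足 \(R_k(P)-R_o(P)\le L_o(Q)-L_k(Q)\)(当 \(k>o\) 时取严格小于)。这是一个二维偏序,离线排序后用树状数组统计满足条件的点数与距离和即可。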
复杂度 \(O(n\log^2 n)\)

点击查看代码
#include <bits/stdc++.h>
using namespace std;
// Large sentinel (about 4.56e18) used as "infinity" for 64-bit distances.
#define INF 0x3f3f3f3f3f3f3f3f
// Capacity for per-column arrays; columns are indexed 1..n.
#define N 100010
// Shorthands. fir/sec are plain token macros, so they also rename any member
// spelled `fir`/`sec` (e.g. in task::point) to `first`/`second`.
#define fir first
#define sec second
#define pb push_back
#define ll long long
//#define int long long

char buf[1<<21], *p1=buf, *p2=buf;
// Buffered replacement for getchar(): refills buf with fread() when it is
// exhausted and yields EOF once stdin has no more data. Bytes are returned as
// unsigned char values so that a 0xFF data byte can never be mistaken for EOF.
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:(unsigned char)*p1++)
// Reads one (optionally negative) decimal integer from stdin.
// Characters before the first digit are skipped, and each '-' among them flips
// the sign. If input ends before any digit is seen, returns 0 instead of
// spinning forever on EOF.
inline int read() {
	int ans=0, f=1;
	int c=getchar();  // int, not char: must be able to hold EOF distinctly
	// Plain range checks avoid isdigit()'s undefined behaviour on negative values.
	while (c!=EOF && !(c>='0'&&c<='9')) {if (c=='-') f=-f; c=getchar();}
	while (c>='0'&&c<='9') {ans=ans*10+(c-'0'); c=getchar();}
	return ans*f;
}

// Number of grid columns (the first value read from the input).
int n;
// Cell weights a[row][col] with row in 1..3 and col in 1..n; row 0 is unused.
ll a[4][N];
// Modulus applied to the accumulated answer.
const ll mod=1000000007;

namespace force{
	bool vis[N];
	ll dis[N], val[N], ans;
	int head[N], id[4][N], tot, ecnt;
	struct edge{int to, next; ll val;}e[N<<1];
	const int dlt[][2]={{-1,0},{0,1},{1,0},{0,-1}};
	inline void add(int s, int t, ll w) {e[++ecnt]={t, head[s], w}; head[s]=ecnt;}
	priority_queue<pair<ll, int>> q;
	void solve() {
		memset(head, -1, sizeof(head));
		for (int i=1; i<=3; ++i) for (int j=1; j<=n; ++j) val[id[i][j]=++tot]=a[i][j];
		for (int i=1; i<=3; ++i) {
			for (int j=1; j<=n; ++j) {
				for (int k=0; k<4; ++k) {
					int x=i+dlt[k][0], y=j+dlt[k][1];
					if (x>=1&&x<=3&&y>=1&&y<=n) add(id[i][j], id[x][y], a[x][y]);
				}
			}
		}
		for (int i=1; i<=tot; ++i) {
			for (int j=1; j<=tot; ++j) dis[j]=INF, vis[j]=0;
			dis[i]=0; q.push({0, i});
			while (q.size()) {
				int u=q.top().sec; q.pop();
				if (vis[u]) continue;
				vis[u]=1;
				for (int i=head[u],v; ~i; i=e[i].next) {
					v = e[i].to;
					if (dis[v]>dis[u]+e[i].val) {
						dis[v]=dis[u]+e[i].val;
						q.push({-dis[v], v});
					}
				}
			}
			for (int j=1; j<=tot; ++j) if (j!=i) ans=(ans+val[i]+dis[j])%mod;
		}
		cout<<ans<<endl;
	}
}

namespace task{
	// Divide-and-conquer solution, O(n log^2 n).
	//
	// For the segment [l, r] handled at recursion depth dep:
	//   disl[dep][i][x][y]: path cost from cell (i, l) to cell (x, y), y in [l, r]
	//   disr[dep][i][x][y]: path cost from cell (i, r) to cell (x, y), y in [l, r]
	// Both endpoint weights are included. Segments at the same depth are
	// disjoint, so they write disjoint y ranges of the same depth slice. Rows
	// are 1..3, and row index 0 is never written, so it stays 0.
	// Unordered pairs of distinct cells are summed into ans: same-column pairs
	// at the leaves, and pairs straddling mid|mid+1 at each merge. solve()
	// doubles the total to count ordered pairs.
	#undef unix  // `unix` may be predefined as a macro by GCC on Unix-like targets
	int xsiz, ysiz, top;
	vector<pair<ll, ll>> add[N*3], que[N*3];
	struct point{ll fir, sec, val;}sta[N*3];
	ll unix[N*3], uniy[N*3], bit1[N*3], bit2[N*3];
	ll disl[21][4][4][N], disr[21][4][4][N], tl[N], tr[N], ans;
	// Fenwick tree over compressed second keys 1..ysiz; values kept mod `mod`.
	inline void upd(ll* bit, int i, ll dat) {for (; i<=ysiz+1; i+=i&-i) bit[i]=(bit[i]+dat)%mod;}
	inline ll query(ll* bit, int i) {ll ans=0; for (; i; i-=i&-i) ans=(ans+bit[i])%mod; return ans;}
	void solve(int l, int r, int dep) {
		if (l==r) {
			// Single column: costs of walking straight up or down the column.
			for (int i=1; i<=3; ++i) disl[dep][1][i][l]=disl[dep][1][i-1][l]+a[i][l];
			for (int i=1; i<=3; ++i) disl[dep][2][i][l]=a[i][l]+(i==2?0:a[2][l]);
			for (int i=3; i; --i) disl[dep][3][i][l]=(i==3?0:disl[dep][3][i+1][l])+a[i][l];
			for (int i=1; i<=3; ++i) disr[dep][1][i][r]=disr[dep][1][i-1][r]+a[i][r];
			for (int i=1; i<=3; ++i) disr[dep][2][i][r]=a[i][r]+(i==2?0:a[2][r]);
			for (int i=3; i; --i) disr[dep][3][i][r]=(i==3?0:disr[dep][3][i+1][r])+a[i][r];
			// Row 1 <-> row 3 may be cheaper by detouring left (tl) or right (tr).
			disl[dep][1][3][l]=min(disl[dep][1][3][l], min(tl[l], tr[r]));
			disl[dep][3][1][l]=min(disl[dep][3][1][l], min(tl[l], tr[r]));
			disr[dep][1][3][r]=min(disr[dep][1][3][r], min(tl[l], tr[r]));
			disr[dep][3][1][r]=min(disr[dep][3][1][r], min(tl[l], tr[r]));
			// Pairs inside this column: (1,2), (1,3), (2,3).
			for (int i=1; i<=3; ++i) for (int j=i+1; j<=3; ++j) ans=(ans+disl[dep][i][j][l])%mod;
			return ;
		}
		int mid=(l+r)>>1;
		solve(l, mid, dep+1); solve(mid+1, r, dep+1);
		// Cross pairs: P=(x1,y1) with y1 in [l,mid], and Q=(x2,y2) with y2 in
		// [mid+1,r]. Crossing between columns mid and mid+1 on row k costs
		//   disr[dep+1][k][x1][y1] + disl[dep+1][k][x2][y2],
		// and the distance is the minimum over k. For each candidate k = now,
		// sum the pairs whose minimum is attained at now (ties go to the smaller
		// row). That requires, for each other row o,
		//   disr[now](P) - disr[o](P) <= disl[o](Q) - disl[now](Q) - (now>o).
		// The two other rows make this a 2-D dominance query, answered offline
		// by sweeping the first key and using a Fenwick tree on the second.
		for (int now=1; now<=3; ++now) {
			pair<int, int> other={0, 0};
			for (int j=1; j<=3; ++j) if (j!=now) {
				if (other.fir) other.sec=j;
				else other.fir=j;
			}
			top=xsiz=ysiz=0;
			// Points from the left half, keyed by their two differences.
			for (int x=1; x<=3; ++x)
				for (int y=l; y<=mid; ++y)
					sta[++top]={disr[dep+1][now][x][y]-disr[dep+1][other.fir][x][y], disr[dep+1][now][x][y]-disr[dep+1][other.sec][x][y], disr[dep+1][now][x][y]};
			for (int i=1; i<=top; ++i) unix[++xsiz]=sta[i].fir, uniy[++ysiz]=sta[i].sec;
			sort(unix+1, unix+xsiz+1); xsiz=unique(unix+1, unix+xsiz+1)-unix-1;
			sort(uniy+1, uniy+ysiz+1); ysiz=unique(uniy+1, uniy+ysiz+1)-uniy-1;
			for (int i=1; i<=top; ++i) add[lower_bound(unix+1, unix+xsiz+1, sta[i].fir)-unix].pb({sta[i].sec, sta[i].val});
			// Queries from the right half, bucketed at the largest first key <= bound.
			for (int x=1; x<=3; ++x) {
				for (int y=mid+1; y<=r; ++y) {
					pair<ll, ll> tem={disl[dep+1][other.fir][x][y]-disl[dep+1][now][x][y]-(now>other.fir), disl[dep+1][other.sec][x][y]-disl[dep+1][now][x][y]-(now>other.sec)};
					int pos=upper_bound(unix+1, unix+xsiz+1, tem.fir)-unix-1;
					// pos==0: no left point qualifies, so the query contributes
					// nothing. Bucket 0 is never swept or cleared, so pushing into
					// it would only accumulate memory across the whole recursion.
					if (pos) que[pos].pb({tem.sec, disl[dep+1][now][x][y]});
				}
			}
			// Sweep first keys in increasing order: insert the left points that
			// have this key, then answer the queries bucketed here.
			for (int i=1; i<=xsiz+1; ++i) {
				for (const auto& it:add[i]) {
					int k=lower_bound(uniy+1, uniy+ysiz+1, it.fir)-uniy;
					upd(bit1, k, it.sec);
					upd(bit2, k, 1);
				}
				for (const auto& it:que[i]) {
					int k=upper_bound(uniy+1, uniy+ysiz+1, it.fir)-uniy-1;
					// (sum of matching disr values) + (their count) * (this disl value)
					ans=(ans+query(bit1, k)+it.sec%mod*query(bit2, k))%mod;
				}
			}
			// Undo the insertions so both Fenwick trees are all-zero again.
			for (int i=1; i<=xsiz+1; ++i) for (const auto& it:add[i]) {
				int k=lower_bound(uniy+1, uniy+ysiz+1, it.fir)-uniy;
				upd(bit1, k, -it.sec);
				upd(bit2, k, -1);
			}
			for (int i=1; i<=xsiz+1; ++i) add[i].clear(), que[i].clear();
		}
		// Build this segment's tables from the children's.
		// disr over the right half and disl over the left half carry over unchanged.
		for (int i=1; i<=3; ++i)
			for (int x=1; x<=3; ++x)
				for (int y=mid+1; y<=r; ++y)
					disr[dep][i][x][y]=disr[dep+1][i][x][y];
		for (int i=1; i<=3; ++i)
			for (int x=1; x<=3; ++x)
				for (int y=l; y<=mid; ++y)
					disl[dep][i][x][y]=disl[dep+1][i][x][y];
		// From column l into the right half: reach (j, mid), then step to (j, mid+1).
		for (int i=1; i<=3; ++i)
			for (int x=1; x<=3; ++x)
				for (int y=mid+1; y<=r; ++y) {
					disl[dep][i][x][y]=INF;
					for (int j=1; j<=3; ++j)
						disl[dep][i][x][y]=min(disl[dep][i][x][y], disl[dep+1][i][j][mid]+disl[dep+1][j][x][y]);
				}
		// From column r into the left half: reach (j, mid+1), then step to (j, mid).
		for (int i=1; i<=3; ++i)
			for (int x=1; x<=3; ++x)
				for (int y=l; y<=mid; ++y) {
					disr[dep][i][x][y]=INF;
					for (int j=1; j<=3; ++j)
						disr[dep][i][x][y]=min(disr[dep][i][x][y], disr[dep+1][i][j][mid+1]+disr[dep+1][j][x][y]);
				}
	}
	void solve() {
		// tl[i]: cost from (1,i) to (3,i), both ends included, going either
		// straight through (2,i) or detouring left via tl[i-1]. tr[i] is the
		// same with a rightward detour via tr[i+1]. Out-of-range entries stay INF.
		memset(tl, 0x3f, sizeof(tl));
		memset(tr, 0x3f, sizeof(tr));
		for (int i=1; i<=n; ++i) tl[i]=a[1][i]+a[3][i]+min(tl[i-1], a[2][i]);
		for (int i=n; i; --i) tr[i]=a[1][i]+a[3][i]+min(tr[i+1], a[2][i]);
		// With no columns there are no pairs; solve(1, 0, 1) would recurse forever.
		if (n>=1) solve(1, n, 1);
		printf("%lld\n", ans*2%mod);
	}
}

signed main()
{
	// Input and output are redirected to fixed files.
	freopen("mouse.in", "r", stdin);
	freopen("mouse.out", "w", stdout);

	n = read();
	for (int row = 1; row <= 3; ++row) {
		for (int col = 1; col <= n; ++col) {
			a[row][col] = read();
		}
	}

	// Replace with force::solve() to cross-check small inputs by brute force.
	task::solve();
	return 0;
}
posted @ 2022-06-04 21:16  Administrator-09  阅读(3)  评论(0编辑  收藏  举报