题解 鼠
- 遇到“计算所有点对之间的贡献之和”这类问题时，优先考虑分治
注意到这样一个事情:
因为只有三行,所以若在一个点 A 和 B 中间画一条线 mid 的话
那么从 A 到 B 的最短路一定恰好跨过 mid 一次
于是可以分治：用 Dijkstra 求出第 \(mid\) 列的三个点到 \([l, mid]\) 中所有点的距离，以及第 \(mid+1\) 列的三个点到 \([mid+1, r]\) 中所有点的距离
但一个常数更小的做法是：对每个分治区间，计算第 \(l\) 列的三个点和第 \(r\) 列的三个点到区间内所有点的距离
一个区间的这些距离可以由左右两个子区间的结果合并得到（枚举在分界处经过哪一行中转）；另外，同一列中第 1 行到第 3 行的最短路可能绕到区间外面，这部分预先用前缀/后缀递推求出（代码中的 tl、tr）
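然后统计跨过 \(mid\) 的点对：对左半的点 \(A\) 和右半的点 \(B\)，\(\mathrm{dis}(A,B)=\min_{k\in\{1,2,3\}}\left(d_{mid}(k,A)+d_{mid+1}(k,B)\right)\)，其中 \(d_{c}(k,P)\) 表示点 \((k,c)\) 到 \(P\) 的距离（两个端点的权值都计入）。枚举取到最小值的行 \(k\)（相等时取编号最小的），“\(k\) 不劣于另一行 \(o\)”可以写成 \(d_{mid}(k,A)-d_{mid}(o,A)\le d_{mid+1}(o,B)-d_{mid+1}(k,B)\)（\(o<k\) 时取严格小于）。对另外两行各有一个这样的不等式，于是问题变成二维偏序，排序后用树状数组统计即可。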
复杂度 \(O(n\log^2 n)\)
代码如下:
#include <bits/stdc++.h>
using namespace std;
// Large sentinel distance. Every byte is 0x3f, so memset(arr, 0x3f, sizeof arr) on a
// long long array produces exactly this value.
#define INF 0x3f3f3f3f3f3f3f3f
// Capacity of the per-column arrays; columns are 1-indexed, so n must stay below N.
#define N 100010
// Shorthands. These are textual macros: anything spelled `fir`/`sec` (including struct
// member names) really is `first`/`second`.
#define fir first
#define sec second
#define pb push_back
#define ll long long
//#define int long long
// 2 MiB input buffer; [p1, p2) is the not-yet-consumed part.
char buf[1<<21], *p1=buf, *p2=buf;
// Shadows getchar(): when the buffer is exhausted it is refilled with fread. The expression
// yields EOF if fread returned nothing, otherwise the next buffered char (promoted to int).
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
// Reads one decimal integer from the input stream.
// Non-digit characters before the number are skipped; every '-' among them flips the sign.
// Returns 0 if the stream ends before any digit is seen (previously the skip loop spun
// forever on EOF, since EOF is not a digit).
// Digits are checked with an explicit range instead of isdigit(), which is undefined for
// negative char values other than EOF.
inline int read() {
	int ans=0, f=1;
	int c=getchar();
	while (c!=EOF && (c<'0' || c>'9')) {if (c=='-') f=-f; c=getchar();}
	while (c>='0' && c<='9') {ans=ans*10+(c-'0'); c=getchar();}
	return ans*f;
}
// Number of columns of the 3-row grid.
int n;
// Cell weights, stored 1-indexed as a[1..3][1..n].
ll a[4][N];
// Modulus for the answer (10^9 + 7).
const ll mod=1000000007;
namespace force{
bool vis[N];
ll dis[N], val[N], ans;
int head[N], id[4][N], tot, ecnt;
struct edge{int to, next; ll val;}e[N<<1];
const int dlt[][2]={{-1,0},{0,1},{1,0},{0,-1}};
inline void add(int s, int t, ll w) {e[++ecnt]={t, head[s], w}; head[s]=ecnt;}
priority_queue<pair<ll, int>> q;
void solve() {
memset(head, -1, sizeof(head));
for (int i=1; i<=3; ++i) for (int j=1; j<=n; ++j) val[id[i][j]=++tot]=a[i][j];
for (int i=1; i<=3; ++i) {
for (int j=1; j<=n; ++j) {
for (int k=0; k<4; ++k) {
int x=i+dlt[k][0], y=j+dlt[k][1];
if (x>=1&&x<=3&&y>=1&&y<=n) add(id[i][j], id[x][y], a[x][y]);
}
}
}
for (int i=1; i<=tot; ++i) {
for (int j=1; j<=tot; ++j) dis[j]=INF, vis[j]=0;
dis[i]=0; q.push({0, i});
while (q.size()) {
int u=q.top().sec; q.pop();
if (vis[u]) continue;
vis[u]=1;
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
if (dis[v]>dis[u]+e[i].val) {
dis[v]=dis[u]+e[i].val;
q.push({-dis[v], v});
}
}
}
for (int j=1; j<=tot; ++j) if (j!=i) ans=(ans+val[i]+dis[j])%mod;
}
cout<<ans<<endl;
}
}
namespace task{
#undef unix
int xsiz, ysiz, top;
vector<pair<ll, ll>> add[N*3], que[N*3];
struct point{ll fir, sec, val;}sta[N*3];
ll unix[N*3], uniy[N*3], bit1[N*3], bit2[N*3];
ll disl[21][4][4][N], disr[21][4][4][N], tl[N], tr[N], ans;
inline void upd(ll* bit, int i, ll dat) {for (; i<=ysiz+1; i+=i&-i) bit[i]=(bit[i]+dat)%mod;}
inline ll query(ll* bit, int i) {ll ans=0; for (; i; i-=i&-i) ans=(ans+bit[i])%mod; return ans;}
void solve(int l, int r, int dep) {
// cout<<"solve: "<<l<<' '<<r<<endl;
if (l==r) {
for (int i=1; i<=3; ++i) disl[dep][1][i][l]=disl[dep][1][i-1][l]+a[i][l];
for (int i=1; i<=3; ++i) disl[dep][2][i][l]=a[i][l]+(i==2?0:a[2][l]);
for (int i=3; i; --i) disl[dep][3][i][l]=(i==3?0:disl[dep][3][i+1][l])+a[i][l];
for (int i=1; i<=3; ++i) disr[dep][1][i][r]=disr[dep][1][i-1][r]+a[i][r];
for (int i=1; i<=3; ++i) disr[dep][2][i][r]=a[i][r]+(i==2?0:a[2][r]);
for (int i=3; i; --i) disr[dep][3][i][r]=(i==3?0:disr[dep][3][i+1][r])+a[i][r];
disl[dep][1][3][l]=min(disl[dep][1][3][l], min(tl[l], tr[r]));
disl[dep][3][1][l]=min(disl[dep][3][1][l], min(tl[l], tr[r]));
disr[dep][1][3][r]=min(disr[dep][1][3][r], min(tl[l], tr[r]));
disr[dep][3][1][r]=min(disr[dep][3][1][r], min(tl[l], tr[r]));
for (int i=1; i<=3; ++i) for (int j=i+1; j<=3; ++j) ans=(ans+disl[dep][i][j][l])%mod;
return ;
}
int mid=(l+r)>>1;
solve(l, mid, dep+1); solve(mid+1, r, dep+1);
// for (int i=1; i<=3; ++i)
// for (int j=1; j<=3; ++j)
// for (int k=l; k<=mid; ++k)
// printf("disl[%d][%d][%d]=%lld\n", i, j, k, disl[dep+1][i][j][k]);
for (int now=1; now<=3; ++now) {
// cout<<"now: "<<now<<endl;
pair<int, int> other={0, 0};
for (int j=1; j<=3; ++j) if (j!=now) {
if (other.fir) other.sec=j;
else other.fir=j;
}
// for (int x1=1; x1<=3; ++x1)
// for (int y1=l; y1<=mid; ++y1)
// for (int x2=1; x2<=3; ++x2)
// for (int y2=mid+1; y2<=r; ++y2) {
// if (disr[dep+1][now][x1][y1]-disr[dep+1][other.fir][x1][y1]<=disl[dep+1][other.fir][x2][y2]-disl[dep+1][now][x2][y2]-(now>other.fir)
// && disr[dep+1][now][x1][y1]-disr[dep+1][other.sec][x1][y1]<=disl[dep+1][other.sec][x2][y2]-disl[dep+1][now][x2][y2]-(now>other.sec))
// ans=(ans+disr[dep+1][now][x1][y1]+disl[dep+1][now][x2][y2])%mod;
// }
top=xsiz=ysiz=0;
for (int x=1; x<=3; ++x)
for (int y=l; y<=mid; ++y)
sta[++top]={disr[dep+1][now][x][y]-disr[dep+1][other.fir][x][y], disr[dep+1][now][x][y]-disr[dep+1][other.sec][x][y], disr[dep+1][now][x][y]};
// cout<<"sta: "; for (int i=1; i<=top; ++i) cout<<"("<<sta[i].fir<<','<<sta[i].sec<<','<<sta[i].val<<") "; cout<<endl;
for (int i=1; i<=top; ++i) unix[++xsiz]=sta[i].fir, uniy[++ysiz]=sta[i].sec;
sort(unix+1, unix+xsiz+1); xsiz=unique(unix+1, unix+xsiz+1)-unix-1;
sort(uniy+1, uniy+ysiz+1); ysiz=unique(uniy+1, uniy+ysiz+1)-uniy-1;
for (int i=1; i<=top; ++i) add[lower_bound(unix+1, unix+xsiz+1, sta[i].fir)-unix].pb({sta[i].sec, sta[i].val});
// cout<<"unix: "; for (int i=1; i<=xsiz; ++i) cout<<unix[i]<<' '; cout<<endl;
for (int x=1; x<=3; ++x) {
for (int y=mid+1; y<=r; ++y) {
pair<ll, ll> tem={disl[dep+1][other.fir][x][y]-disl[dep+1][now][x][y]-(now>other.fir), disl[dep+1][other.sec][x][y]-disl[dep+1][now][x][y]-(now>other.sec)};
que[upper_bound(unix+1, unix+xsiz+1, tem.fir)-unix-1].pb({tem.sec, disl[dep+1][now][x][y]});
// cout<<"tem.fir = "<<tem.fir<<" and it's inserted at "<<upper_bound(unix+1, unix+xsiz+1, tem.fir)-unix-1<<endl;
// for (int i=1; i<=top; ++i) if (sta[i].fir<=tem.fir && sta[i].sec<=tem.sec)
// ans=(ans+sta[i].val+disl[dep+1][now][x][y])%mod;
}
}
for (int i=1; i<=xsiz+1; ++i) {
// cout<<"i: "<<i<<' '<<unix[i]<<endl;
// cout<<"que: "; for (auto it:que[i]) cout<<it<<' ' ; cout<<endl;
for (auto it:add[i]) {
upd(bit1, lower_bound(uniy+1, uniy+ysiz+1, it.fir)-uniy, it.sec);
upd(bit2, lower_bound(uniy+1, uniy+ysiz+1, it.fir)-uniy, 1);
}
for (auto it:que[i]) ans=(ans+query(bit1, upper_bound(uniy+1, uniy+ysiz+1, it.fir)-uniy-1)+it.sec%mod*query(bit2, upper_bound(uniy+1, uniy+ysiz+1, it.fir)-uniy-1))%mod;
// for (auto it:que[i]) cout<<"que for "<<it<<" result "<<query(bit1, upper_bound(uniy+1, uniy+ysiz+1, it)-uniy-1)<<' '<<query(bit2, upper_bound(uniy+1, uniy+ysiz+1, it)-uniy-1)<<endl;
}
for (int i=1; i<=xsiz+1; ++i) for (auto it:add[i]) {
upd(bit1, lower_bound(uniy+1, uniy+ysiz+1, it.fir)-uniy, -it.sec);
upd(bit2, lower_bound(uniy+1, uniy+ysiz+1, it.fir)-uniy, -1);
}
for (int i=1; i<=xsiz+1; ++i) add[i].clear(), que[i].clear();
}
// for (int x1=1; x1<=3; ++x1)
// for (int y1=l; y1<=mid; ++y1)
// for (int x2=1; x2<=3; ++x2)
// for (int y2=mid+1; y2<=r; ++y2) {
// ll dis=INF;
// for (int k=1; k<=3; ++k)
// dis=min(dis, disr[dep+1][k][x1][y1]+disl[dep+1][k][x2][y2]);
// ans=(ans+dis)%mod;
// }
for (int i=1; i<=3; ++i)
for (int x=1; x<=3; ++x)
for (int y=mid+1; y<=r; ++y)
disr[dep][i][x][y]=disr[dep+1][i][x][y];
for (int i=1; i<=3; ++i)
for (int x=1; x<=3; ++x)
for (int y=l; y<=mid; ++y)
disl[dep][i][x][y]=disl[dep+1][i][x][y];
for (int i=1; i<=3; ++i)
for (int x=1; x<=3; ++x)
for (int y=mid+1; y<=r; ++y) {
disl[dep][i][x][y]=INF;
for (int j=1; j<=3; ++j)
disl[dep][i][x][y]=min(disl[dep][i][x][y], disl[dep+1][i][j][mid]+disl[dep+1][j][x][y]);
}
for (int i=1; i<=3; ++i)
for (int x=1; x<=3; ++x)
for (int y=l; y<=mid; ++y) {
disr[dep][i][x][y]=INF;
for (int j=1; j<=3; ++j)
disr[dep][i][x][y]=min(disr[dep][i][x][y], disr[dep+1][i][j][mid+1]+disr[dep+1][j][x][y]);
}
}
void solve() {
// cout<<double(sizeof(disl)*2)/1000/1000<<endl;
memset(tl, 0x3f, sizeof(tl));
memset(tr, 0x3f, sizeof(tr));
for (int i=1; i<=n; ++i) tl[i]=a[1][i]+a[3][i]+min(tl[i-1], a[2][i]);
for (int i=n; i; --i) tr[i]=a[1][i]+a[3][i]+min(tr[i+1], a[2][i]);
solve(1, n, 1);
printf("%lld\n", ans*2%mod);
}
}
int main()
{
	// Input comes from mouse.in, output goes to mouse.out.
	freopen("mouse.in", "r", stdin);
	freopen("mouse.out", "w", stdout);
	n=read();
	// Weights row by row: all n cells of row 1, then row 2, then row 3.
	for (int row=1; row<=3; ++row)
		for (int col=1; col<=n; ++col)
			a[row][col]=read();
	// force::solve();  // brute-force checker, only viable for small n
	task::solve();
	return 0;
}