【luogu P7293】Sum of Distances P(线段树)(图论)

Sum of Distances P

题目链接:luogu P7293

题目大意

给你 k 个图,然后构建一个新图,新图点数是前面几个图点数的乘积。
每个点用 k 元组表示,然后如果有两个 k 元组,它们每一元在对于的图上都有边,那这两个点之间就连边。
然后问你跟 (1,1,...,1) 在同一个连通块的点每个点到它的最短路径之和。

思路

首先不难想到是可以在一个边反复横跳以等另一个图的边走到某个点的。
所以我们可以求出每个图每个点的奇偶最短路。

然后你考虑要怎么求答案,对于一个点 \((a_1,a_2,...,a_k)\),就是 \(\min\{\max\{dis_{a_i,0}\},\max\{dis_{a_i,1}\}\}\)

然后你发现两个 \(\max\) 可以单独搞,但是这个 \(\min\) 就不太友好,考虑搞一搞,变成:
\(\max\{dis_{a_i,0}\}+\max\{dis_{a_i,1}\}-\max\{\max\{dis_{a_i,0},dis_{a_i,1}\}\}\)

那三个可以分别搞,那要怎么搞呢?
你可以枚举 \(k\),使得第 \(k\) 元贡献的最终的答案。
那你就是枚举这一元选的,然后在其他里面看有多少个小于它的,然后个数乘积起来就是贡献。
考虑搞一个线段树维护每个图里面现在有多少个点小于,然后区间统计的就是区间值的乘积。
然后一开始全部插入,然后从大到小枚举点,统计之后就把贡献删掉。

然后就可以了。
(记得求第三个的时候如果取 \(\max\)\(INF\) 的话这个点不要算进去)

代码

#include<queue>
#include<cstdio>
#include<vector>
#include<cstring>
#include<iostream>
#define mo 1000000007
#define ll long long 
#define INF 0x3f3f3f3f3f3f3f3f

using namespace std;

struct node {
	int to, nxt;
}e[800001];
int k, n[50001], m[50001], count[500001], mxn[50001];
int le[200001], KK, x, y, dis[200001], cnt[500001];
int nmb[200001];
vector <ll> minn[50001][2];
vector <pair<int, int> > tong[400001];
bool in[200001];
ll ans;

struct abab {
	int x;
};
bool operator <(abab x, abab y) {
	return mxn[x.x] > mxn[y.x];
}
priority_queue <abab> mg;

void add(int x, int y) {
	e[++KK] = (node){y, le[x]}; le[x] = KK;
	e[++KK] = (node){x, le[y]}; le[y] = KK;
}

struct ztzt {
	int dis, now;
};
bool operator <(ztzt x, ztzt y) {
	return x.dis > y.dis;
}
priority_queue <ztzt> q;

void dij(int x) {//求奇偶最短路
	while (!q.empty()) q.pop();
	dis[0] = INF;
	for (int i = 1; i <= n[x]; i++) {
		dis[i] = dis[i + n[x]] = INF;
		in[i] = in[i + n[x]] = 0;
	}
	dis[1] = 0; q.push((ztzt){0, 1});
	while (!q.empty()) {
		int now = q.top().now;
		q.pop();
		if (in[now]) continue;
		in[now] = 1;
		for (int i = le[now]; i; i = e[i].nxt)
			if (!in[e[i].to] && dis[e[i].to] > dis[now] + 1) {
				dis[e[i].to] = dis[now] + 1;
				q.push((ztzt){dis[e[i].to], e[i].to});
			}
	}
}

struct XDtree {//线段树维护
	ll val[500001 << 2];
	
	void clean() {
		memset(val, 0, sizeof(val));
	}
	
	void up(int now) {
		val[now] = val[now << 1] * val[now << 1 | 1] % mo;
	} 
	
	void insert(int now, int l, int r, int pl, int va) {
		if (l == r) {
			val[now] = (val[now] + va + mo) % mo;
			return ;
		}
		
		int mid = (l + r) >> 1;
		if (pl <= mid) insert(now << 1, l, mid, pl, va);
			else insert(now << 1 | 1, mid + 1, r, pl, va);
		
		up(now); 
	}
	
	ll query(int now, int l, int r, int L, int R) {
		if (L > R) return 1;
		if (L <= l && r <= R) {
			return val[now];
		}
		
		int mid = (l + r) >> 1;
		ll re = 1;
		if (L <= mid) re = (re * query(now << 1, l, mid, L, R)) % mo;
		if (mid < R) re = (re * query(now << 1 | 1, mid + 1, r, L, R)) % mo;
		return re;
	}
}T;

ll get_max(int op) {
	ans = 0; int maxn = -1;
	for (int i = 1; i <= k; i++)
		for (int j = 0; j < n[i]; j++) {
			int vl = -1;
			if (op == 1) vl = minn[i][0][j];
				else if (op == 2) vl = minn[i][1][j];
					else {
						vl = max(minn[i][0][j], minn[i][1][j]);
//						if (vl == dis[0]) vl = min(minn[i][0][j], minn[i][1][j]);
					}
			if (vl == dis[0]) continue;//记得不要把这种给统计进去 nmb 里面
			nmb[i]++;
			tong[vl].push_back(make_pair(i, j));
			maxn = max(maxn, vl);
		}
	
	for (int i = 1; i <= k; i++) {
		T.insert(1, 1, k, i, nmb[i]);
	}
	for (int i = maxn; i >= 1; i--) {
		for (int j = 0; j < tong[i].size(); j++) {
			ans = (ans + i * (T.query(1, 1, k, 1, tong[i][j].first - 1) * T.query(1, 1, k, tong[i][j].first + 1, k) % mo) % mo) % mo;//剩余部分的乘积
			T.insert(1, 1, k, tong[i][j].first, -1);//减少
		}
	}
	
	for (int i = 1; i <= maxn; i++)
		tong[i].clear();
	T.clean();
	memset(nmb, 0, sizeof(nmb));
	return ans;
}

int main() {
	scanf("%d", &k);
	for (int i = 1; i <= k; i++) {
		scanf("%d %d", &n[i], &m[i]);
		
		KK = 0; for (int j = 1; j <= 2 * n[i]; j++) le[j] = 0;
		for (int j = 1; j <= m[i]; j++) {
			scanf("%d %d", &x, &y);
			add(x, y + n[i]); add(x + n[i], y);
		}
		
		dij(i);
		for (int j = 1; j <= n[i]; j++)
			minn[i][0].push_back(dis[j]), minn[i][1].push_back(dis[j + n[i]]);
	}
	
	printf("%lld", (get_max(1) + get_max(2) - get_max(3) + mo) % mo);
	
	return 0;
}
posted @ 2021-08-24 09:36  あおいSakura  阅读(30)  评论(0编辑  收藏  举报