JZOJ 100019.A

\(\text{Problem}\)


\(\text{Solution}\)

把形如 \((a,ka)\) 的路径提出来
那么覆盖这些路径的路径为不合法路径
如果能不重不漏的找出这些路径,然后用总路径减去就是答案
为了方便计算,我们限定路径用 \(dfn\) 序表示 \((x,y)\) ,并规定 \(x < y\)
即树上两点构成的路径 \((x,y)\) 满足 \(dfn[x] < dfn[y]\)

然后如何确定那些路径 \((a,b)\) 覆盖了最先找出来的路径 \((u,v)\)
其实很好办,自己画画图就知道了
其中要分两类讨论,记 \(end_x\) 为子树 \(x\)\(dfn\) 序最大的点的 \(dfn\) 序,即 \(end_x = dfn_x + siz_x - 1\)
那么

于是我们确定了不合法路径 \((a,b)\) 的范围,那怎么去掉重复路径呢?
很妙啊!
因为路径像是平面上的有序数对,于是我们把它弄到平面上,然后发现不合法路径的范围是一个又一个矩阵
那么总数就是矩阵面积的并
扫描线解决即可

\(\text{Code}\)

#include<cstdio>
#include<algorithm>
#define LL long long
#define ls (p << 1)
#define rs (ls | 1)
using namespace std;

const int N = 1e5 + 5;
int n, h[N], m;

struct line{
	int x, y0, y1, v;
}l[4000005];
inline bool cmp(line x, line y){return x.x < y.x ? 1 :(x.x == y.x ? x.v < y.v : 0);}

struct edge{int to, nxt;}e[N * 2];
inline void add(int x, int y)
{
	static int tot = 0;
	e[++tot] = edge{y, h[x]}, h[x] = tot;
}

int dep[N], f[N][20], dfn[N], siz[N];
void dfs(int x)
{
	static int dfc = 0;
	dfn[x] = ++dfc, siz[x] = 1;
	for(int i = 1; i <= 17; i++)
	if (f[x][i - 1]) f[x][i] = f[f[x][i - 1]][i - 1];
	else break;
	for(int i = h[x]; i; i = e[i].nxt)
	{
		int v = e[i].to;
		if (dep[v]) continue;
		dep[v] = dep[x] + 1, f[v][0] = x, dfs(v), siz[x] += siz[v];
	}
}

int sum[N << 2], tag[N << 2];
inline void pushup(int l, int r, int p)
{
	if (tag[p] > 0) sum[p] = r - l + 1;
	else if (l == r) sum[p] = 0;
	else sum[p] = sum[ls] + sum[rs];
}
void update(int l, int r, int p, int x, int y, int v)
{
	if (x > r || y < l) return;
	if (x <= l && r <= y)
	{
		tag[p] += v;
		pushup(l, r, p);
		return;
	}
	int mid = (l + r) >> 1;
	if (x <= mid) update(l, mid, ls, x, y, v);
	if (y > mid) update(mid + 1, r, rs, x, y, v);
	pushup(l, r, p);
} 

int main() 
{
	freopen("a.in", "r", stdin), freopen("a.out", "w", stdout);
	scanf("%d", &n);
	for(int i = 1, x, y; i < n; i++) scanf("%d%d", &x, &y), add(x, y), add(y, x);
	dep[1] = 1, dfs(1);
	for(int i = 1, x, y, t; i <= n; i++)
		for(int j = i + i; j <= n; j += i)
		{
			x = i, y = j;
			if (dfn[x] > dfn[y]) swap(x, y);
			if (dfn[x] + siz[x] - 1 >= dfn[y])
			{
				t = y;
				for(int k = 17; k >= 0; k--)
				if (f[t][k] && dep[f[t][k]] > dep[x]) t = f[t][k];
				if (dfn[t] > 1)
				{
					l[++m] = line{1, dfn[y], dfn[y] + siz[y] - 1, 1};
					l[++m] = line{dfn[t], dfn[y], dfn[y] + siz[y] - 1, -1};
				}
				if (dfn[t] + siz[t] <= n)
				{
					l[++m] = line{dfn[y], dfn[t] + siz[t], n, 1};
					l[++m] = line{dfn[y] + siz[y], dfn[t] + siz[t], n, -1};
				}
			}
			else{
				l[++m] = line{dfn[x], dfn[y], dfn[y] + siz[y] - 1, 1};
				l[++m] = line{dfn[x] + siz[x], dfn[y], dfn[y] + siz[y] - 1, -1};
			}
		} 
	sort(l + 1, l + m + 1, cmp);
	LL ans = 0;
	for(int i = 1, j; i <= m; i++)
	{	
		ans += 1LL * sum[1] * (l[i].x - l[i - 1].x);
		for(j = i; j <= m && l[j].x == l[i].x; j++) update(1, n, 1, l[j].y0, l[j].y1, l[j].v);
		i = j - 1;
	}
	printf("%lld\n", 1LL * n * (n - 1) / 2 - ans);
}
posted @ 2021-06-07 21:05  leiyuanze  阅读(41)  评论(0编辑  收藏  举报