Rikka with Intersections of Paths

传送


翻译:Rikka有一棵包含\(n(1 \leqslant n \leqslant 3 \times 10 ^ 5)\)个节点的树,节点编号为\(1\)\(n\)。树上也标记了\(m(2 \leqslant m \leqslant 3 \times 10 ^ 5)\)条简单路径,第\(i\)条路径连接\(x_i\)\(y_i\)\((1\leqslant x_i,y_i \leqslant n)\),这些路径可能会有重复的。如果要在其中选择\(k\)路径,要求这\(k\)条路径至少有一个公共点,计算有多少种选择方案。输出其模\(10 ^ 9 + 7\)的结果。


一个直观的思路是记录每一个点被路径覆盖的次数\(c_i\),那么答案就是\(\sum_{i = 1} ^ {n} C_{c_i} ^ {k}\)
但这样统计显然会有重复的,所以可以用容斥的思想。
记一条路径的两个端点的lca为\(x_i\)。那么如果两条路径有交集,\(x_i\)\(x_j\)必定在交集之中。
\(p_i\)表示\(m\)条路径中lca在点\(i\)的数量,那么答案就是\(\sum _ {i = 1} ^ {n} C_{c_i} ^ {k} - C_{c_i - p_i} ^ {k}\)
注意,不是\(\sum_{i = 1} ^ {n} C_{p_i} ^ {k}\),因为这个式子表示的是路径必须相交在lca上,而实际上是一条路径上的某一个点和另一条路径的lca相交。


代码

#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<vector>
#include<queue>
#include<assert.h>
#include<ctime>
using namespace std;
#define enter puts("") 
#define space putchar(' ')
#define Mem(a, x) memset(a, x, sizeof(a))
#define In inline
#define forE(i, x, y) for(int i = head[x], y; ~i && (y = e[i].to); i = e[i].nxt)
typedef long long ll;
typedef double db;
const int INF = 0x3f3f3f3f;
const db eps = 1e-8;
const int maxn = 3e5 + 5;
const ll mod = 1e9 + 7;
const int N = 19;
In ll read()
{
	ll ans = 0;
	char ch = getchar(), las = ' ';
	while(!isdigit(ch)) las = ch, ch = getchar();
	while(isdigit(ch)) ans = (ans << 1) + (ans << 3) + ch - '0', ch = getchar();
	if(las == '-') ans = -ans;
	return ans;
}
In void write(ll x)
{
	if(x < 0) x = -x, putchar('-');
	if(x >= 10) write(x / 10);
	putchar(x % 10 + '0');
}
In void MYFILE()
{
#ifndef mrclr
	freopen(".in", "r", stdin);
	freopen(".out", "w", stdout);
#endif
}

int n, m, K;
struct Edge
{
	int nxt, to;
}e[maxn << 1];
int head[maxn], ecnt = -1;
In void addEdge(int x, int y)
{
	e[++ecnt] = (Edge){head[x], y};
	head[x] = ecnt;
}

ll fac[maxn], inv[maxn];
In ll quickpow(ll a, ll b)
{
	ll ret = 1;
	for(; b; b >>= 1, a = a * a % mod)
		if(b & 1) ret = ret * a % mod;
	return ret;
}
In ll C(int n, int m)
{
	if(m > n) return 0;
	return fac[n] * inv[m] % mod * inv[n - m] % mod;
}

int dep[maxn], fa[N + 2][maxn];
In void dfs(int now, int _f)
{
	for(int i = 1; (1 << i) <= dep[now]; ++i)
		fa[i][now] = fa[i - 1][fa[i - 1][now]];
	forE(i, now, v)
	{
		if(v == _f) continue;
		dep[v] = dep[now] + 1, fa[0][v] = now;
		dfs(v, now);
	}
}
In int lca(int x, int y)
{
	if(dep[x] < dep[y]) swap(x, y);
	for(int i = N; i >= 0; --i)
		if(dep[fa[i][x]] >= dep[y]) x = fa[i][x];
	if(x == y) return x;
	for(int i = N; i >= 0; --i)
		if(fa[i][x] ^ fa[i][y]) x = fa[i][x], y = fa[i][y];
	return fa[0][x];
}

int dif[maxn], num[maxn];
In void dfs2(int now, int _f)
{
	forE(i, now, v)
		if(v ^ _f) dfs2(v, now), dif[now] += dif[v];
}

In void work()
{
	dep[1] = 1, dfs(1, 0);
	for(int i = 1; i <= m; ++i)
	{
		int x = read(), y = read();
		int z = lca(x, y);
		dif[x]++, dif[y]++, dif[z]--, dif[fa[0][z]]--;
		num[z]++;
	}
	dfs2(1, 0);
	ll ans = 0;
	for(int i = 1; i <= n; ++i)
		ans = (ans + C(dif[i], K) - C(dif[i] - num[i], K) + mod) % mod;
	write(ans), enter;
}

In void init()
{
	Mem(head, -1), ecnt = -1;
	Mem(num, 0), Mem(dif, 0), Mem(fa, 0);

}

int main()
{
//	MYFILE();
	int T = read();
	fac[0] = inv[0] = 1;
	for(int i = 1; i < maxn; ++i)
	{
		fac[i] = fac[i - 1] * i % mod;	
		inv[i] = quickpow(fac[i], mod - 2);
	}	
	while(T--)
	{
		n = read(), m = read(), K = read();
		init();
		for(int i = 1; i < n; ++i)
		{
			int x = read(), y = read();
			addEdge(x, y), addEdge(y, x);
		}
		work();
	}
	return 0;	
}
posted @ 2020-10-03 15:21  mrclr  阅读(363)  评论(0编辑  收藏  举报