day9T1改错记

题面

给出一棵\(n\)个节点的树和\(m\)条非树边,第\(i\)条非树边连接\(u_i\)和\(v_i\),有\(p_i\)的概率不出现,问期望有多少条非树边至多出现在一个简单环中。答案对\(998244353\)取模,输入的\(p_i\)也是取模后的概率。

\(n, m \le 10^6\)

解析瞎扯

对着标程看了半天没看懂,问了人才反应过来:节点上的标记实际上表示的是它到父亲的那条树边(父边)的标记……

于是就变成了树上差分加前缀和(这里准确地说是前缀积):在路径端点和 LCA 上打标记后求子树积,得到每条树边被哪些非树边覆盖;在路径最顶端的边上打标记后求根到节点的前缀积,得到路径上以某条边为顶的非树边。

代码

#include <cstring>
#include <cstdio>
#include <iostream>
#include <map>
// Prime modulus for all arithmetic; the answer and the input probabilities
// are residues modulo this value. A typed constant instead of a lowercase
// macro, so the name cannot silently rewrite unrelated identifiers.
constexpr long long mod = 998244353ll;

using LL = long long;
using pii = std::pair<int, int>;
// Array bound for nodes and non-tree edges (n, m <= 10^6, plus a little slack).
constexpr int maxn = 1000000 + 6;

// Binary exponentiation: returns x^y reduced modulo `mod`.
// Expects y >= 0 (a negative y never reaches zero under >>=).
LL qpower(LL x, LL y) {
	LL acc = 1;
	for (; y; y >>= 1, x = x * x % mod)
		if (y & 1) acc = acc * x % mod;
	return acc;
}
// Adjacency-list edge: `to` is the head vertex, `nxt` the index of the next
// edge leaving the same tail vertex (0 terminates the list).
struct Edge {
	int to;
	int nxt;
	Edge(int target = 0, int next = 0) : to(target), nxt(next) {}
} edge[maxn << 1];  // each undirected tree edge is stored twice
// Sizes and the number of stored directed edges.
int n, m, cnt;
// Adjacency heads, node depths, binary-lifting ancestor table.
int head[maxn], dep[maxn], f[maxn][21];
// Per non-tree edge: endpoints (x deeper), LCA (left 0 when y is an ancestor of x).
int x[maxn], y[maxn], lca[maxn];
// Per non-tree edge: top node of the path on each side, input probability.
int befx[maxn], befy[maxn], prob[maxn];
// Zero-aware modular number: represents val * 0^zr (mod `mod`). Zero factors
// are counted in zr instead of being multiplied in, so they can later be
// divided back out; zr may become negative after division.
struct Tag {
	int val, zr;
	Tag(int v, int z) : val(v), zr(z) {}
	// Implicit conversion from an int (relied on by expressions like
	// `tag * prob[i]`). The argument is reduced into [0, mod) first, so negative
	// inputs such as 1 - p, and multiples of mod, are normalised before the
	// zero test. Previously they were stored as negative or unreduced residues.
	Tag(int x_ = 1) {
		long long r = x_ % mod;
		if (r < 0) r += mod;
		if (!r) val = zr = 1;
		else val = static_cast<int>(r), zr = 0;
	}
	// Plain residue: 0 whenever zr != 0, otherwise val.
	inline int v() const { return zr ? 0 : val; }
	inline Tag operator *(const Tag &rhs) const { return Tag(1LL * val * rhs.val % mod, zr + rhs.zr); }
	// Division via Fermat inverse of the non-zero part; zero counts subtract.
	inline Tag operator /(const Tag &rhs) const { return Tag(1LL * val * qpower(rhs.val, mod - 2) % mod, zr - rhs.zr); }
} t[maxn], sum[maxn], res[maxn];
// mp[l][(a, b)], a < b: correction factor for non-tree edges whose LCA is l
// and whose two top path nodes are a and b.
std::map<pii, Tag> mp[maxn];
// Iterator alias over one mp[] entry (not referenced in this file).
using iter = std::map<pii, Tag>::iterator;

// Forward declarations; the definitions follow main().
void add(int bgn, int end);
void dfs(int x, int fa);
void calc(int x);
char gc();
int read();

// Entry point. Reads a tree (n nodes, n - 1 edges) and m non-tree edges
// "x y p" from cactus.in and writes one number to cactus.out.
// Convention: a mark stored on node v stands for the tree edge (v, parent(v)).
// For each non-tree edge i the accumulated term is (1 - p_i) times the product
// of p_j over the other non-tree edges j whose tree path shares a tree edge
// with i's path. NOTE(review): this reading is derived from the marking scheme
// below; confirm against the intended formula.
int main() {
	freopen("cactus.in", "r", stdin);
	freopen("cactus.out", "w", stdout);

	n = read(), m = read();
	// Tree edges, stored in both directions.
	for (int i = 1; i < n; ++i) {
		int u = read(), v = read();
		add(u, v), add(v, u);
	}
	// Depths and binary-lifting table, rooted at node 1.
	dfs(1, 0);
	for (int i = 1; i <= m; ++i) {
		x[i] = read(), y[i] = read(), prob[i] = read();
		// inv = 1 / p_i as a zero-aware Tag (p_i == 0 is tracked in zr).
		Tag inv = Tag(1) / Tag(prob[i]);
		// Make x[i] the deeper endpoint.
		if (dep[x[i]] < dep[y[i]]) std::swap(x[i], y[i]);
		int tmpx = x[i], tmpy = y[i];
		// Kept for the answer pass below.
		res[i] = inv;
		// Lift tmpx to the highest ancestor of x that is still strictly deeper than y.
		for (int j = 20; ~j; --j) if (dep[f[tmpx][j]] > dep[tmpy]) tmpx = f[tmpx][j];
		if (f[tmpx][0] == tmpy) {
			// Case 1: y is an ancestor of x; tmpx is the child of y on the path,
			// i.e. the top tree edge of the path. Record it in sum[] (prefix products).
			befx[i] = tmpx, sum[tmpx] = sum[tmpx] * prob[i];
			// Tree difference: p_i at x, 1/p_i at y, so after calc() the subtree
			// product t[v] picks up p_i exactly for the edges on the path.
			t[x[i]] = t[x[i]] * prob[i], t[y[i]] = t[y[i]] * inv;
		} else {
			// Case 2: general path. Bring tmpx to y's depth, then lift both until
			// they are the two children of the LCA on the path.
			if (dep[tmpx] > dep[tmpy]) tmpx = f[tmpx][0];
			for (int j = 20; ~j; --j) if (f[tmpx][j] != f[tmpy][j]) tmpx = f[tmpx][j], tmpy = f[tmpy][j];
			befx[i] = tmpx, befy[i] = tmpy;
			lca[i] = f[tmpx][0];
			// Tree difference: p_i at both endpoints, 1/p_i^2 at the LCA.
			t[x[i]] = t[x[i]] * prob[i], t[y[i]] = t[y[i]] * prob[i];
			t[lca[i]] = t[lca[i]] * inv * inv;
			// Record both top tree edges of the path in sum[].
			sum[tmpx] = sum[tmpx] * prob[i], sum[tmpy] = sum[tmpy] * prob[i];
			// mp[lca][(a, b)], a < b, accumulates 1/p_j over all edges j with the
			// same pair of top nodes; used later to cancel their double counting.
			if (tmpx > tmpy) std::swap(tmpx, tmpy);
			Tag &T = mp[lca[i]][std::make_pair(tmpx, tmpy)];
			T = T * inv;
		}
	}
	// Turn the marks into root-to-node prefix products (sum) and subtree products (t).
	calc(1);
	LL ans = 0;
	for (int i = 1; i <= m; ++i) {
		// ret = (1 - p_i) / p_i: the probability that edge i is present, with its
		// own p_i (which the path products below include) divided back out.
		Tag inv = res[i], ret = Tag(1 - prob[i]) * inv;
		// Case 1: sum ratio = edges whose top edge lies on the path strictly below
		// befx; t[befx] = edges covering the top edge befx.
		if (!lca[i]) ans = (ans + (ret * sum[x[i]] / sum[befx[i]] * t[befx[i]]).v()) % mod;
		else {
			// Case 2: the same for both halves. Edges covering both top edges
			// (including i itself) appear in both l and r; d cancels one copy.
			Tag l = sum[x[i]] / sum[befx[i]] * t[befx[i]], r = sum[y[i]] / sum[befy[i]] * t[befy[i]];
			Tag d = mp[lca[i]][std::make_pair(std::min(befy[i], befx[i]), std::max(befx[i], befy[i]))];
			ans = (ans + (ret * l * r * d).v()) % mod;
		}
	}
	// Accumulated terms may be negative residues; shift into [0, mod) for output.
	printf("%lld\n", ans < 0 ? ans + mod : ans);

	return 0;
}
void dfs(int x, int fa) {
	f[x][0] = fa, dep[x] = dep[fa] + 1;
	for (int i = 1; i <= 20; ++i) f[x][i] = f[f[x][i - 1]][i - 1];
	for (int p = head[x]; p; p = edge[p].nxt) {
		int y = edge[p].to;
		if (y == fa) continue;
		dfs(y, x);
	}
}
void calc(int x) {
	if (f[x][0]) sum[x] = sum[f[x][0]] * sum[x];
	for (int p = head[x]; p; p = edge[p].nxt) {
		int y = edge[p].to;
		if (y == f[x][0]) continue;
		calc(y);
		t[x] = t[x] * t[y];
	}
}
inline void add(int bgn, int end) {
	edge[++cnt] = edge(end, head[bgn]);
	head[bgn] = cnt;
}
// Buffered single-character reader over stdin. Refills a 1,000,000-byte buffer
// with fread when it is exhausted and returns EOF (converted to char) once no
// more input is available.
inline char gc() {
	static char buf[1000000], *end_ptr, *cur;
	if (cur == end_ptr) {
		cur = buf;
		end_ptr = cur + fread(buf, 1, 1000000, stdin);
	}
	if (cur == end_ptr) return EOF;
	return *cur++;
}
// Reads the next non-negative decimal integer from the gc() stream, skipping
// any non-digit characters before it. If input ends before a digit is found,
// returns 0 instead of looping forever: gc() keeps returning EOF, and the
// original skip loop never checked for it.
inline int read() {
	int res = 0;
	char ch = gc();
	while (ch < '0' || ch > '9') {
		if (ch == static_cast<char>(EOF)) return res;
		ch = gc();
	}
	while (ch >= '0' && ch <= '9') {
		res = res * 10 + (ch - '0');
		ch = gc();
	}
	return res;
}
//Rhein_E
posted @ 2019-03-12 18:43  Rhein_E  阅读(131)  评论(0)  编辑  收藏  举报