题解 求和

传送门

保分题又爆零了,数不清第几次了
以后保分题无论如何要跑对拍! 三道题辛辛苦苦骗来的分抵不住一道傻逼题爆零

树上 LCA 直接求就好了,就是细节有点多:对每个 \(k\) 预处理深度 \(k\) 次方的前缀和,再以 LCA 为界把路径拆成两段分别相减即可。

2021/06/23 upd: 被洛谷上的 hack 数据卡掉了……原来是倍增数组里 2 的次幂开小了
不过发现求树上路径上深度幂之和的公式可以优化一下,若 \(a\)、\(b\) 的 LCA 为 \(t\),则有

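\[ans = \Big(\big(sum[dep[a]-1]-sum[dep[t]-1]+sum[dep[b]-1]-sum[\max(dep[t]-2,\,0)]\big) \bmod \mathit{mod} + \mathit{mod}\Big) \bmod \mathit{mod} \]

其中 \(sum[i]=\sum_{j=1}^{i} j^k\),根节点的 \(dep\) 记为 \(1\)(即真实深度为 \(dep-1\))。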

Code:

#include <bits/stdc++.h>
using namespace std;
// Shorthand constants and type abbreviations used by the rest of the solution.
#define INF 0x3f3f3f3f
// Capacity of the per-node global arrays (head, dep, fa, lg, sum rows).
#define N 300010
#define ll long long 
#define ld long double
#define usd unsigned
#define ull unsigned long long
//#define int long long 
// Function-like max: defined after the includes, so every later max(...) call in
// this file uses this macro. It evaluates the selected argument twice.
#define max(a, b) ((a)>(b)?(a):(b))

// Buffered input: getchar() is redefined to read from a 2 MiB buffer that is
// refilled with fread from stdin whenever it is exhausted (p1 == p2); it yields
// EOF once fread returns no more data. p1 is the next unread byte, p2 the end of
// the valid data in buf.
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
char buf[1<<21], *p1=buf, *p2=buf;
// Reads one decimal integer from the (buffered) input.
// Non-digit characters before the number are skipped; every '-' among them
// flips the sign. Returns 0 (with the accumulated sign) if the input ends
// before any digit is seen, instead of spinning forever on EOF.
inline int read() {
	int ans=0, f=1;
	// Keep the character as int: it can be EOF, and comparing ranges avoids
	// passing negative (high-bit) bytes to isdigit, which is undefined behavior.
	int c=getchar();
	while (c!=EOF && (c<'0' || c>'9')) {if (c=='-') f=-f; c=getchar();}
	while (c>='0' && c<='9') {ans=ans*10+(c-'0'); c=getchar();}
	return ans*f;
}

// n: number of tree nodes, m: number of queries.
int n, m;
// head[u]: index in e of u's most recently added edge (0 = none);
// cnt: number of edges stored so far (edges live at e[1..cnt]).
// The counter is deliberately NOT named `size`: with `using namespace std;`
// under C++17 a global `size` collides with std::size and every use becomes
// ambiguous.
// dep[u]: depth with the root stored at 1; fa[u][i]: 2^i-th ancestor of u
// (0 when it does not exist); lg[x]: floor(log2 x) + 1; mdep: largest dep of
// any non-root node.
int head[N], cnt, dep[N], fa[N][22], lg[N], mdep;
// sum[k][i] = (1^k + 2^k + ... + i^k) mod `mod`, built lazily per k in main.
ll sum[53][N];
const ll mod=998244353;
// vis[k]: whether sum[k] has been built.
bool vis[55];
// Adjacency-list edge: target node and index of the next edge of the same source.
struct edge{int to, next;}; edge* e;
// Appends directed edge s -> t to the front of s's adjacency list.
inline void add(int s, int t) {edge *k=&e[++cnt]; k->to=t; k->next=head[s]; head[s]=cnt;}

// Modular exponentiation: returns a^b mod `mod` by square-and-multiply.
// b must be non-negative (negative exponents are not supported).
ll qpow(ll a, ll b) {
	// Reduce the base first so a*a and ans*a stay within long long for any a,
	// and map negative bases into [0, mod).
	a %= mod;
	if (a < 0) a += mod;
	ll ans=1;
	while (b) {
		if (b&1) ans=ans*a%mod;
		a=a*a%mod; b>>=1;
	}
	return ans;
}

// Traverses the subtree of u (whose parent is pa), filling dep[] and fa[][0]
// for every descendant, the binary-lifting table fa[][i] for every visited
// node, and raising mdep to the largest descendant depth.
// Precondition: dep[u] (and fa[u][0]) are already set by the caller.
// Uses an explicit stack instead of recursion so a path-shaped tree with
// hundreds of thousands of nodes cannot overflow the call stack.
void dfs(int u, int pa) {
	std::vector<std::pair<int, int>> stk;   // pending (node, parent) pairs
	stk.emplace_back(u, pa);
	while (!stk.empty()) {
		const int x = stk.back().first;
		const int p = stk.back().second;
		stk.pop_back();
		// All ancestors of x were processed before x was pushed, so their
		// tables are complete. fa[x][i] is filled only while a 2^i jump is
		// possible (dep[x] >= 2^i); higher levels stay 0.
		for (int i=1; i<=19; ++i)
			if (dep[x]>=(1<<i)) fa[x][i] = fa[fa[x][i-1]][i-1];
			else break;
		for (int i=head[x],v; i; i=e[i].next) {
			v = e[i].to;
			if (v==p) continue;
			dep[v]=dep[x]+1;
			fa[v][0]=x;
			if (dep[v]>mdep) mdep=dep[v];
			stk.emplace_back(v, x);
		}
	}
}

// Lowest common ancestor of a and b using the binary-lifting table fa[][].
// lg[x] - 1 equals floor(log2 x) (as filled in main), so every lifting step
// below jumps by the largest power of two that fits the remaining distance.
int lca(int a, int b) {
	int x = a, y = b;
	if (dep[x] < dep[y]) { int tmp = x; x = y; y = tmp; }
	// Lift the deeper node until both nodes sit on the same level.
	for (int gap = dep[x] - dep[y]; gap > 0; gap = dep[x] - dep[y])
		x = fa[x][lg[gap] - 1];
	if (x == y) return x;
	// Climb both nodes in lockstep, keeping them strictly below their LCA.
	for (int j = lg[dep[x]] - 1; j >= 0; --j) {
		if (fa[x][j] == fa[y][j]) continue;
		x = fa[x][j];
		y = fa[y][j];
	}
	return fa[x][0];
}

signed main()
{
	#ifdef DEBUG
	freopen("1.in", "r", stdin);
	#endif
	int a, b, k, t;
	
	n=read();
	e = new edge[n*2+10];
	for (int i=1,u,v; i<n; ++i) {u=read(); v=read(); add(u, v); add(v, u);}
	for (int i=1; i<=n; ++i) lg[i]=lg[i-1]+(1<<lg[i-1]==i);
	dep[1]=1;
	dfs(1, 0);
	m=read();
	for (int i=1; i<=m; ++i) {
		a=read(); b=read(); k=read();
		//cout<<"ab: "<<a<<' '<<b<<endl;
		t=lca(a, b);
		//cout<<"t: "<<t<<endl;
		if (!vis[k]) {
			sum[k][1]=1;
			for (int i=2; i<=mdep; ++i) sum[k][i]=(sum[k][i-1]+qpow(i, k))%mod;
			vis[k]=1;
		}
		printf("%lld\n", ((sum[k][dep[a]-1]-sum[k][dep[t]-1]+sum[k][dep[b]-1]-sum[k][max(dep[t]-2, 0)])%mod+mod)%mod);
	}

	return 0;
}
posted @ 2021-06-22 21:28  Administrator-09  阅读(29)  评论(0编辑  收藏  举报