题解 求和
保分题又爆零了,数不清第几次了
以后保分题无论如何要跑对拍! 三道题辛辛苦苦骗来的分抵不住一道傻逼题爆零
树上 LCA:求出两点的 LCA 后,用"深度的 \(k\) 次幂"的前缀和相减即可,只是细节有点多
2021/06/23 upd: 被洛谷上hack数据卡掉了……原来是倍增2的次幂开小了
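不过发现求树上路径权值和的公式可以优化一下:记 \(sum_k[i]=\sum_{j=1}^{i} j^k \bmod mod\)(代码中根的 \(dep\) 为 1,所以题目中的深度是 \(dep-1\)),若\(a\), \(b\)的lca为\(t\),则有
\[ans = \bigl((sum_k[dep_a-1]-sum_k[dep_t-1]+sum_k[dep_b-1]-sum_k[\max(dep_t-2,\ 0)]) \bmod mod + mod\bigr) \bmod mod
\]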
Code:
#include <bits/stdc++.h>
using namespace std;
// Large "infinity" sentinel value.
#define INF 0x3f3f3f3f
// Capacity of every per-node / per-depth array in this file.
#define N 300010
// Short spellings for built-in arithmetic types.
#define ll long long
#define ld long double
#define usd unsigned
#define ull unsigned long long
// Uncomment to turn every `int` into a 64-bit integer
// (presumably the reason main is declared `signed` rather than `int`).
//#define int long long
// Function-like max. Each argument can be evaluated twice, so do not pass
// expressions with side effects.
#define max(a, b) ((a)>(b)?(a):(b))
// Buffered replacement for getchar(): when the read cursor p1 reaches the end
// of the buffered data p2, the buffer is refilled with one fread call; if that
// yields no bytes the macro evaluates to EOF, otherwise to the next byte.
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
// Input buffer of 1<<21 bytes; p1 is the read cursor, p2 the end of valid data.
char buf[1<<21], *p1=buf, *p2=buf;
/**
 * Reads the next decimal integer from the input.
 *
 * Every character before the first digit is skipped; each '-' met while
 * skipping flips the sign of the result. Digits are then accumulated until
 * the first non-digit character, which is consumed as well.
 *
 * @return the parsed value, or 0 if the input ends before any digit is seen.
 */
inline int read() {
    int value = 0, sign = 1;
    // Keep the character in an int so that EOF is never confused with a byte,
    // and compare against '0'..'9' directly: no <cctype> call ever receives a
    // negative char value.
    int c = getchar();
    // Stop skipping at EOF; otherwise an exhausted input would spin forever.
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') sign = -sign;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        value = value * 10 + (c - '0');
        c = getchar();
    }
    return value * sign;
}
// n: number of tree nodes; m: number of queries (both are read in main).
int n, m;
// head[u]: index in `e` of the most recently added edge leaving u (0 = none).
// size:    number of edges stored in `e` so far (edge indices start at 1).
// dep[u]:  depth of u, with the root (node 1) at depth 1.
// fa[u][i]: the 2^i-th ancestor of u, or 0 when there is no such ancestor.
// lg[x]:   floor(log2(x)) + 1 for x >= 1, lg[0] = 0 (filled in main).
// mdep:    largest depth assigned to a node while running dfs.
int head[N], size, dep[N], fa[N][22], lg[N], mdep;
// sum[k][d]: (1^k + 2^k + ... + d^k) mod `mod`; row k is filled the first
// time a query uses that k.
ll sum[53][N];
// Modulus applied to every prefix sum and every printed answer.
const ll mod=998244353;
// vis[k]: true once row sum[k] has been computed.
bool vis[55];
// One adjacency-list entry: `to` is the target node, `next` is the index of
// the next entry with the same source (0 ends the list). The array `e` is
// allocated in main once n is known.
struct edge{int to, next;}; edge* e;
inline void add(int s, int t) {edge *k=&e[++size]; k->to=t; k->next=head[s]; head[s]=size;}
/**
 * Computes a^b modulo `mod` by binary exponentiation.
 *
 * @param a base; it is reduced modulo `mod` before any multiplication
 * @param b exponent; values <= 0 yield 1
 * @return a^b mod `mod`
 */
ll qpow(ll a, ll b) {
    // Reduce the base up front so that every product below is bounded by
    // mod^2 (< 2^63) no matter how large the caller's `a` is.
    a %= mod;
    ll result = 1;
    // Looping on `b > 0` (rather than `b != 0`) also guarantees termination
    // for a negative exponent, whose right shift never reaches zero.
    for (; b > 0; b >>= 1) {
        if (b & 1) result = result * a % mod;
        a = a * a % mod;
    }
    return result;
}
/**
 * Traverses the tree from u, filling dep[], fa[][] and mdep.
 *
 * For every node reached, fa[node][i] (i >= 1) is derived from the already
 * known lower levels; for every child, dep[] and fa[child][0] are set before
 * the child is processed. dep[u] must be set by the caller.
 *
 * The traversal keeps its own stack instead of recursing, so a path-shaped
 * tree with hundreds of thousands of nodes cannot exhaust the call stack.
 *
 * @param u  node to start from
 * @param pa parent of u (edges leading back to it are skipped)
 */
void dfs(int u, int pa) {
    // Number of jump levels stored per node, taken from the table itself so
    // the loop bound can never fall out of step with the declaration of fa.
    const int levels = static_cast<int>(sizeof(fa[0]) / sizeof(fa[0][0]));

    // Parallel stacks: the node, its parent, and the index of the next edge
    // to examine (-1 means the node has not been entered yet).
    std::vector<int> nodes, parents, cursors;
    nodes.push_back(u);
    parents.push_back(pa);
    cursors.push_back(-1);

    while (!nodes.empty()) {
        const int node = nodes.back();
        const int parent = parents.back();

        if (cursors.back() == -1) {
            // First visit: build the higher ancestors. A 2^i jump is only
            // filled while dep[node] >= 2^i; deeper levels keep their zero.
            for (int i = 1; i < levels && dep[node] >= (1 << i); ++i)
                fa[node][i] = fa[fa[node][i - 1]][i - 1];
            cursors.back() = head[node];
        }

        const int id = cursors.back();
        if (id == 0) {
            // Adjacency list exhausted: this node is finished.
            nodes.pop_back();
            parents.pop_back();
            cursors.pop_back();
            continue;
        }

        // Advance the cursor before descending so the edge is not revisited.
        cursors.back() = e[id].next;
        const int child = e[id].to;
        if (child == parent) continue;

        dep[child] = dep[node] + 1;
        fa[child][0] = node;
        if (mdep < dep[child]) mdep = dep[child];

        nodes.push_back(child);
        parents.push_back(node);
        cursors.push_back(-1);
    }
}
/**
 * Returns the lowest common ancestor of nodes a and b using the
 * binary-lifting table fa[][] and the depths in dep[].
 *
 * lg[x] - 1 is used as the index of the largest power of two not
 * exceeding x.
 */
int lca(int a, int b) {
    // Make `a` the deeper of the two nodes.
    if (dep[a] < dep[b]) {
        std::swap(a, b);
    }

    // Lift `a` by the largest fitting power of two until the depths match.
    for (int gap = dep[a] - dep[b]; gap > 0; gap = dep[a] - dep[b]) {
        a = fa[a][lg[gap] - 1];
    }

    if (a == b) {
        return a;
    }

    // Lift both nodes together, from the largest jump down to the smallest,
    // taking a jump only when it keeps them on different nodes.
    int level = lg[dep[a]] - 1;
    while (level >= 0) {
        const int up_a = fa[a][level];
        const int up_b = fa[b][level];
        if (up_a != up_b) {
            a = up_a;
            b = up_b;
        }
        --level;
    }

    // Both nodes now sit directly below their common ancestor.
    return fa[a][0];
}
signed main()
{
#ifdef DEBUG
freopen("1.in", "r", stdin);
#endif
int a, b, k, t;
n=read();
e = new edge[n*2+10];
for (int i=1,u,v; i<n; ++i) {u=read(); v=read(); add(u, v); add(v, u);}
for (int i=1; i<=n; ++i) lg[i]=lg[i-1]+(1<<lg[i-1]==i);
dep[1]=1;
dfs(1, 0);
m=read();
for (int i=1; i<=m; ++i) {
a=read(); b=read(); k=read();
//cout<<"ab: "<<a<<' '<<b<<endl;
t=lca(a, b);
//cout<<"t: "<<t<<endl;
if (!vis[k]) {
sum[k][1]=1;
for (int i=2; i<=mdep; ++i) sum[k][i]=(sum[k][i-1]+qpow(i, k))%mod;
vis[k]=1;
}
printf("%lld\n", ((sum[k][dep[a]-1]-sum[k][dep[t]-1]+sum[k][dep[b]-1]-sum[k][max(dep[t]-2, 0)])%mod+mod)%mod);
}
return 0;
}