题解 [CF1654D] Potion Brewing Class
因为想了一个小时所以还是写一下
首先有一个直接树形 DP 的思路,但是数太大了处理不了
那么考虑用一个数表示出其它所有数,这些数都必须是整数
令单位 1 为 \(t\),根节点为 \(kt\)
那么将每个节点表示为 \(\frac{y}{x}\),发现 \(k\) 必须是所有 \(x\) 的 \(\tt lcm\)
那么 \(\tt lcm\) 即为所有质因子指数取 max
dfs 一遍维护 \(\frac{y}{x}\) 每个质因子的指数,那么同时维护这个数组的历史最大值就好了
最后一遍 dfs 用根节点的答案求出每个点的值即可
复杂度 \(O(n\log n)\)
点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 200010
#define ll long long
#define int long long
char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
inline int read() {
int ans=0, f=1; char c=getchar();
while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
return ans*f;
}
int n;
bool npri[N];
int head[N], ecnt;
const ll mod=998244353;
int a[N], his[N], inv[N], base, ans;
int pri[N], low[N], lowp[N], lowc[N], pcnt;
struct edge{int to, next, x, y;}e[N<<1];
inline void add(int s, int t, int x, int y) {e[++ecnt]={t, head[s], x, y}; head[s]=ecnt;}
inline int gcd(int a, int b) {return !b?a:gcd(b, a%b);}
inline ll qpow(ll a, ll b) {ll ans=1; for (; b; a=a*a%mod,b>>=1) if (b&1) ans=ans*a%mod; return ans;}
void dfs1(int u, int fa) {
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
if (v==fa) continue;
for (int t=e[i].x; t>1; t/=lowp[t])
a[low[t]]+=lowc[t], his[low[t]]=max(his[low[t]], a[low[t]]);
for (int t=e[i].y; t>1; t/=lowp[t])
a[low[t]]-=lowc[t], his[low[t]]=max(his[low[t]], a[low[t]]);
dfs1(v, u);
for (int t=e[i].x; t>1; t/=lowp[t])
a[low[t]]-=lowc[t], his[low[t]]=max(his[low[t]], a[low[t]]);
for (int t=e[i].y; t>1; t/=lowp[t])
a[low[t]]+=lowc[t], his[low[t]]=max(his[low[t]], a[low[t]]);
}
}
void dfs2(int u, int fa, int now) {
ans=(ans+now)%mod;
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
if (v==fa) continue;
dfs2(v, u, now*e[i].y%mod*inv[e[i].x]%mod);
}
}
signed main()
{
int T=read();
inv[0]=inv[1]=1;
for (int i=2; i<N; ++i) inv[i]=(mod-mod/i)*inv[mod%i]%mod;
for (int i=2; i<N; ++i) {
if (!npri[i]) pri[++pcnt]=low[i]=lowp[i]=i, lowc[i]=1;
for (int j=1,x; j<=pcnt&&i*pri[j]<N; ++j) {
npri[x=i*pri[j]]=1;
if (!(i%pri[j])) {
low[x]=pri[j];
lowp[x]=lowp[i]*pri[j];
lowc[x]=lowc[i]+1;
break;
}
else low[x]=lowp[x]=pri[j], lowc[x]=1;
}
}
while (T--) {
n=read(); ecnt=ans=0; base=1;
for (int i=1; i<=n; ++i) head[i]=-1;
for (int i=1; i<=n; ++i) a[i]=his[i]=0;
for (int i=1,u,v,x,y,t; i<n; ++i) {
u=read(); v=read(); x=read(); y=read();
t=gcd(x, y); x/=t; y/=t;
add(u, v, x, y); add(v, u, y, x);
}
dfs1(1, 0);
for (int i=1; i<=n; ++i) base=base*qpow(i, his[i])%mod;
dfs2(1, 0, base);
printf("%lld\n", ans);
}
return 0;
}