题解 苯为
唔嗯……基环树染色?啊啊,那树点就是直接乘若干个 \(k-1\) 嘛!
给环染色?……容斥一下?
断环为链的话,第一个点有 \(k\) 种选法,剩下的点有 \(k-1\) 种选法
再减去第一个点和最后一个点颜色相同的情况
那么把这两个点合成一个,就是减去 \(f_{n-1}\)
所以
\[f_n=k(k-1)^{n-1}-f_{n-1}
\]
啊啊过样例了好耶,交一下!爆零了好耶!
哦,原来 \(n=2\) 的时候不成环要特判 \(f_2=k(k-1)\) 啊
这个 \(f_n\) 怎么快速求远项呢?
把式子展开!变成
\[k\sum\limits_{i=2}^{n-1}(-1)^{n-i+1}(k-1)^i
\]
再加减一个 \(f_2\)
然后这个东西可以等比数列求和
然后点分治 + NTT 算每种距离的方案数
然后发现模数是 \(2^{14}\times 3\times 5\times 17\times 101+1\)
然后考虑翻集训队论文
- 关于图染色/环染色/特殊色多项式:
使用最后一个式子,就可以换根 DP 了
复杂度 \(O(n)\)
点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 1000010
#define fir first
#define sec second
#define pb push_back
#define ll long long
#define int128 __int128
//#define int long long
char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
inline ll read() {
ll ans=0, f=1; char c=getchar();
while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
return ans*f;
}
ll n, A, k;
vector<int> to[N];
const ll mod=421969921, phi=mod-1;
inline ll qpow(ll a, ll b) {assert(b>=0); ll ans=1; for (; b; a=a*a%mod,b>>=1) if (b&1) ans=ans*a%mod; return ans;}
inline ll qpow(ll a, int b) {assert(b>=0); ll ans=1; for (; b; a=a*a%mod,b>>=1) if (b&1) ans=ans*a%mod; return ans;}
inline ll qpow(ll a, int128 b) {assert(b>=0); ll ans=1; for (; b; a=a*a%mod,b>>=1) if (b&1) ans=ans*a%mod; return ans;}
namespace force{
ll f[N], ans;
ll qval(int128 n) {
// if (n<=2) return n==1?k:k*(k-1)%mod;
// ll ans=0, val=k-1, sqr=val*val%mod, inv=qpow(sqr-1, mod-2);
// ans=(ans+(n&1?-1:1)*(qpow(val, ((n|1)-2)+2)-qpow(val, 3))*inv)%mod;
// ans=(ans+(n&1?1:-1)*(qpow(val, (((n-1)>>1)<<1)+2)-sqr)*inv)%mod;
// ans=(k*ans+(n&1?-1:1)*k%mod*(k-1))%mod;
// return ans;
return (qpow(k-1, n)+(n&1?-1:1)*(k-1))%mod;
}
ll F(ll len) {
if (f[len]!=mod+1) return f[len];
return f[len]=qval(len*(int128)(A+1))*qpow(k-1, (n-len)*(int128)(A+1))%mod;
}
void dfs(int u, int fa, int dis) {
ans=(ans+F(dis))%mod;
for (auto v:to[u]) if (v!=fa)
dfs(v, u, dis+1);
}
void solve() {
for (int i=1; i<=n; ++i) f[i]=mod+1;
for (int s=1; s<=n; ++s) dfs(s, 0, 1);
printf("%lld\n", (ans%mod+mod)%mod);
}
}
namespace task1{
ll ans;
ll F(ll len) {return qpow(-1, (A+1)*len)*qpow(k-1, (A+1)*(n-len))%mod;}
void dfs(int u, int fa, int dis) {
ans=(ans+F(dis))%mod;
for (auto& v:to[u]) if (v!=fa)
dfs(v, u, dis+1);
}
void solve() {
for (int s=1; s<=n; ++s) dfs(s, 0, 1);
ans=(ans*(k-1)+qpow(k-1, (A+1)*n)*n%mod*n)%mod;
printf("%lld\n", (ans%mod+mod)%mod);
}
}
namespace task{
ll f[N], g[N], step, ans;
void dfs1(int u, int fa) {
for (auto& v:to[u]) if (v!=fa) {
dfs1(v, u);
f[u]=(f[u]+f[v]*step)%mod;
}
ans=(ans+f[u])%mod;
}
void dfs2(int u, int fa) {
ans=(ans+g[u])%mod;
for (auto& v:to[u]) if (v!=fa) {
g[v]=(g[u]+f[u]-f[v]*step)%mod*step%mod;
dfs2(v, u);
}
}
void solve() {
step=qpow(-1, A+1)*qpow(qpow(k-1, A+1), mod-2)%mod;
ll val=qpow(-1, A+1)*qpow(k-1, (n-1)*(A+1))%mod;
for (int i=1; i<=n; ++i) f[i]=val;
dfs1(1, 0); dfs2(1, 0);
ans=(ans*(k-1)+qpow(k-1, (A+1)*n)*n%mod*n)%mod;
printf("%lld\n", (ans%mod+mod)%mod);
}
}
signed main()
{
freopen("ber.in", "r", stdin);
freopen("ber.out", "w", stdout);
n=read(); A=read()%phi; k=read()%mod;
for (int i=1; i<n; ++i) {
int x=read(), y=read();
to[x].pb(y); to[y].pb(x);
}
// force::solve();
task::solve();
return 0;
}