题解 苯为

传送门

唔嗯……基环树染色?啊啊,那树点就是直接乘若干个 \(k-1\) 嘛!
给环染色?……容斥一下?
断环为链的话,第一个点有 \(k\) 种选法,剩下的点有 \(k-1\) 种选法
再减去第一个点和最后一个点颜色相同的情况
那么把这两个点合成一个,就是减去 \(f_{n-1}\)
所以

\[f_n=k(k-1)^{n-1}-f_{n-1} \]

啊啊过样例了好耶,交一下!爆零了好耶!
哦,原来 \(n=2\) 的时候不成环要特判 \(f_2=k(k-1)\)
这个 \(f_n\) 怎么快速求远项呢?
把式子展开!变成

\[k\sum\limits_{i=2}^{n-1}(-1)^{n-i+1}(k-1)^i \]

再加减一个 \(f_2\)
然后这个东西可以等比数列求和
然后点分治 + NTT 算每种距离的方案数
然后发现模数是 \(2^{14}\times 3\times 5\times 17\times 101+1\)

然后考虑翻集训队论文

  • 关于图染色/环染色/特殊色多项式:

使用最后一个式子,就可以换根 DP 了
复杂度 \(O(n)\)

点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 1000010
#define fir first
#define sec second
#define pb push_back
#define ll long long
#define int128 __int128
//#define int long long

char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
inline ll read() {
	ll ans=0, f=1; char c=getchar();
	while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
	while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
	return ans*f;
}

ll n, A, k;
vector<int> to[N];
const ll mod=421969921, phi=mod-1;
inline ll qpow(ll a, ll b) {assert(b>=0); ll ans=1; for (; b; a=a*a%mod,b>>=1) if (b&1) ans=ans*a%mod; return ans;}
inline ll qpow(ll a, int b) {assert(b>=0); ll ans=1; for (; b; a=a*a%mod,b>>=1) if (b&1) ans=ans*a%mod; return ans;}
inline ll qpow(ll a, int128 b) {assert(b>=0); ll ans=1; for (; b; a=a*a%mod,b>>=1) if (b&1) ans=ans*a%mod; return ans;}

namespace force{
	ll f[N], ans;
	ll qval(int128 n) {
		// if (n<=2) return n==1?k:k*(k-1)%mod;
		// ll ans=0, val=k-1, sqr=val*val%mod, inv=qpow(sqr-1, mod-2);
		// ans=(ans+(n&1?-1:1)*(qpow(val, ((n|1)-2)+2)-qpow(val, 3))*inv)%mod;
		// ans=(ans+(n&1?1:-1)*(qpow(val, (((n-1)>>1)<<1)+2)-sqr)*inv)%mod;
		// ans=(k*ans+(n&1?-1:1)*k%mod*(k-1))%mod;
		// return ans;
		return (qpow(k-1, n)+(n&1?-1:1)*(k-1))%mod;
	}
	ll F(ll len) {
		if (f[len]!=mod+1) return f[len];
		return f[len]=qval(len*(int128)(A+1))*qpow(k-1, (n-len)*(int128)(A+1))%mod;
	}
	void dfs(int u, int fa, int dis) {
		ans=(ans+F(dis))%mod;
		for (auto v:to[u]) if (v!=fa)
			dfs(v, u, dis+1);
	}
	void solve() {
		for (int i=1; i<=n; ++i) f[i]=mod+1;
		for (int s=1; s<=n; ++s) dfs(s, 0, 1);
		printf("%lld\n", (ans%mod+mod)%mod);
	}
}

namespace task1{
	ll ans;
	ll F(ll len) {return qpow(-1, (A+1)*len)*qpow(k-1, (A+1)*(n-len))%mod;}
	void dfs(int u, int fa, int dis) {
		ans=(ans+F(dis))%mod;
		for (auto& v:to[u]) if (v!=fa)
			dfs(v, u, dis+1);
	}
	void solve() {
		for (int s=1; s<=n; ++s) dfs(s, 0, 1);
		ans=(ans*(k-1)+qpow(k-1, (A+1)*n)*n%mod*n)%mod;
		printf("%lld\n", (ans%mod+mod)%mod);
	}
}

namespace task{
	ll f[N], g[N], step, ans;
	void dfs1(int u, int fa) {
		for (auto& v:to[u]) if (v!=fa) {
			dfs1(v, u);
			f[u]=(f[u]+f[v]*step)%mod;
		}
		ans=(ans+f[u])%mod;
	}
	void dfs2(int u, int fa) {
		ans=(ans+g[u])%mod;
		for (auto& v:to[u]) if (v!=fa) {
			g[v]=(g[u]+f[u]-f[v]*step)%mod*step%mod;
			dfs2(v, u);
		}
	}
	void solve() {
		step=qpow(-1, A+1)*qpow(qpow(k-1, A+1), mod-2)%mod;
		ll val=qpow(-1, A+1)*qpow(k-1, (n-1)*(A+1))%mod;
		for (int i=1; i<=n; ++i) f[i]=val;
		dfs1(1, 0); dfs2(1, 0);
		ans=(ans*(k-1)+qpow(k-1, (A+1)*n)*n%mod*n)%mod;
		printf("%lld\n", (ans%mod+mod)%mod);
	}
}

signed main()
{
	freopen("ber.in", "r", stdin);
	freopen("ber.out", "w", stdout);

	n=read(); A=read()%phi; k=read()%mod;
	for (int i=1; i<n; ++i) {
		int x=read(), y=read();
		to[x].pb(y); to[y].pb(x);
	}
	// force::solve();
	task::solve();

	return 0;
}
posted @ 2022-08-02 20:32  Administrator-09  阅读(3)  评论(0编辑  收藏  举报