题解 P6624 【[省选联考 2020 A 卷] 作业题】
好题啊,真的好题啊!
莫比乌斯反演+矩阵树定理(甚至还算加了一点点的生成函数的原理)
【分析】
根据题意,求对所有的树 \(T\) 的值:
\(\displaystyle ans=\sum_{T}val(T)=\sum_{T}(\sum_{i=1}^{n-1}w_{e_i})\gcd_{i=1}^{n-1} w_{e_i}\)
\(e_i\) 表示树 \(T\) 的边
由欧拉反演得:\(\displaystyle n=\sum_{d\mid n}\boldsymbol \varphi(d)\)
因此 \(\displaystyle \gcd_{i=1}^{n-1} w_{e_i}=\sum_{d\mid \gcd_{i=1}^{n-1}w_{e_i}}\boldsymbol \varphi(d)=\sum_{\forall i\to d\mid w_{e_i}}\boldsymbol \varphi(d)\)
因此 \(\displaystyle ans=\sum_T(\sum_{i=1}^{n-1}w_{e_i})\sum_{\forall i\to d\mid w_{e_i}}\boldsymbol \varphi(d)\)
调换枚举顺序得 \(\displaystyle ans=\sum_{d\geq 1}\boldsymbol \varphi(d)\sum_T[\forall i\to d\mid w_{e_i}](\sum_{i=1}^{n-1}w_{e_i})\)
于是对于每个 \(w_{e_i}\) ,我们将它塞进它各因数的 vector
中。之后,每个数 \(d\) 的 vector
中,存着所有边权为它倍数的边
现考虑给定一个 vector
,里面存着若干边,用这些边生成树的权值和为多少,即 \(\displaystyle f(d)=\sum_{T\in v_d}(\sum_{i=1}^{n-1}w_{e_i})\)
\(\because \displaystyle \sum_{i=1}^{n-1}w_{e_i}=\sum_{i=1}^{n-1}w_{e_i}\cdot 1\)
若令 \(P_i(x)=1+w_{e_i}x\)
则 \(\displaystyle \sum_{i=1}^{n-1}w_{e_i}\cdot 1=\sum_{c_1+c_2+c_3+\cdots+c_{n-1}=1}\prod_{i=1}^{n-1}P_i[c_i]\)
后面这玩意儿本质上就是卷积,而且还是 \(1\) 次项的系数
\(\therefore \displaystyle \sum_{i=1}^{n-1}w_{e_i}=(\prod_{i=1}^{n-1}P_i(x))[1]\)
\(\therefore \displaystyle f(d)=\sum_{T\in v_d}((\prod_{i=1}^{n-1}P_i(x))[1])=(\sum_{T\in v_d}\prod_{i=1}^{n-1}P_i(x))[1]\)
于是这就构成了矩阵树的形式。我们将 vector
中的边,构造成对应的一次函数,放进矩阵中,求一下行列式在模 \(x^2\) 意义下的值。再把一次项系数返回,就是 \(f(d)\) 的结果
消元方法见下
这样一来,\(\displaystyle ans=\sum_{d\geq 1}\boldsymbol \varphi(d)f(d)\)
如果每个 \(d\) 都求解,复杂度将达到 \(O(m\cdot n^3)\) ,其中 \(m\) 为边的最大值。复杂度直接爆炸。
考虑将 vector
中的边跑一下并查集。如果最后没有并成一个连通分量,那一定不会生成树,贡献为 \(0\) ,直接不跑高斯消元。
这样的剪枝相当于额外增加一个 \(O(mn\cdot \alpha(n))\) ,但感觉是能扣除一大部分高斯消元。试试能不能过(最后当然是过了)。
简单讲一下多项式怎么在模 \(x^2\) 意义下高斯消元。将 \(P(x)=a+bx\) 记为 \((a,b)\) 则:
\((a,b)\pm(c,d)=(a\pm c,b\pm d)\)
\((a+bx)(c+dx)=ac+(ad+bc)x+bdx^2\equiv ac+(ad+bc)x\pmod {x^2}\)
即 \((a,b)\cdot (c,d)=(ac,ad+bc)\)
若 \((a+bx)(c+dx)\equiv 1\pmod {x^2}\)
则 \(ac=1,ad+bc=0\)
故 \(c=a^{-1},d=-(bc)\cdot a^{-1}=-bc^2\)
因此 \((a,b)^{-1}=(a^{-1},-a^{-2}b)\)
故 \((a,b)/(c,d)=(a,b)\cdot (c,d)^{-1}\)
这样一来就可以重载运算符了:
typedef pair<int,int> pii;
#define fi first
#define se second
inline pii operator ~ (const pii &p) { ll tmp=inv(p.fi); return pii(tmp,dis(0,tmp*tmp%MOD*p.se%MOD)); }
//逆元
inline pii operator - (const pii &p) { return pii( dis(0,p.fi) , dis(0,p.se) ); }
inline pii operator + (const pii &a,const pii &b) { return pii( add(a.fi,b.fi) , add(a.se,b.se) ); }
inline pii operator += (pii &a,const pii &b) { return a=a+b; }
inline pii operator - (const pii &a,const pii &b) { return pii( dis(a.fi,b.fi) , dis(a.se,b.se) ); }
inline pii operator -= (pii &a,const pii &b) { return a=a-b; }
inline pii operator * (const pii &a,const pii &b) { return pii( (ll)a.fi*b.fi%MOD , ((ll)a.fi*b.se+(ll)a.se*b.fi)%MOD ); }
inline pii operator *= (pii &a,const pii &b) { return a=a*b; }
inline pii operator / (const pii &a,const pii &b) { return a*(~b); }
inline pii operator /= (pii &a,const pii &b) { return a=a/b; }
【代码】
最后奉上本蒟蒻的代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
#define fi first
#define se second
const int MOD=998244353,MAXN=32,Lim=152501;
inline ll fpow(ll a,ll x) { ll ans=1; for(;x;x>>=1,a=a*a%MOD) if(x&1) ans=ans*a%MOD; return ans; }
//快速幂
inline ll inv(ll a) { return fpow(a,MOD-2); }
inline int add(int a,int b) { return (a+b>=MOD)?(a+b-MOD):(a+b); }
inline int dis(int a,int b) { return (a-b<0)?(a-b+MOD):(a-b); }
//求余意义下的逆元、加减法
inline pii operator ~ (const pii &p) { ll tmp=inv(p.fi); return pii(tmp,dis(0,tmp*tmp%MOD*p.se%MOD)); }
//逆元
inline pii operator - (const pii &p) { return pii( dis(0,p.fi) , dis(0,p.se) ); }
inline pii operator + (const pii &a,const pii &b) { return pii( add(a.fi,b.fi) , add(a.se,b.se) ); }
inline pii operator += (pii &a,const pii &b) { return a=a+b; }
inline pii operator - (const pii &a,const pii &b) { return pii( dis(a.fi,b.fi) , dis(a.se,b.se) ); }
inline pii operator -= (pii &a,const pii &b) { return a=a-b; }
inline pii operator * (const pii &a,const pii &b) { return pii( (ll)a.fi*b.fi%MOD , ((ll)a.fi*b.se+(ll)a.se*b.fi)%MOD ); }
inline pii operator *= (pii &a,const pii &b) { return a=a*b; }
inline pii operator / (const pii &a,const pii &b) { return a*(~b); }
inline pii operator /= (pii &a,const pii &b) { return a=a/b; }
//重载运算符
struct Matrix{
pii Num[MAXN][MAXN];
inline pii det(int N){//求行列式
pii res=pii(1,0);
for(register int i=1;i<=N;++i){
int pos=i;
while(pos<=N&&Num[pos][i]==pii(0,0)) ++pos;
if(pos>N) return pii(0,0);
if(pos!=i) swap(Num[pos],Num[i]),res=-res;
res*=Num[i][i];
for(register int j=N;j>=i;j--) Num[i][j]/=Num[i][i];
for(register int k=i+1;k<=N;++k)
for(register int j=N;j>=i;--j)
Num[k][j]-=Num[i][j]*Num[k][i];
}
return res;
}
}M;
struct Edge{
int u,v,w;
Edge(int u_=0,int v_=0,int w_=0):u(u_),v(v_),w(w_) {}
};
int n,Pa[MAXN],CntSet;
int find(int u) { return (u==Pa[u])?u:(Pa[u]=find(Pa[u])); }
inline void merge(int u,int v){
if(find(u)==find(v)) return ;
Pa[find(u)]=find(v);
--CntSet;
}
//并查集相关操作
inline int calc(const vector<Edge> &v){
//给定 vector ,求答案
if(v.size()<n-1) return 0;
//一定不成树
iota(Pa+1,Pa+n+1,1); CntSet=n;
memset(M.Num,0,sizeof(M.Num));
//并查集和矩阵初始化
for(const auto &e : v)
merge(e.u,e.v),
M.Num[e.u][e.v]-=pii(1,e.w),
M.Num[e.v][e.u]-=pii(1,e.w),
M.Num[e.u][e.u]+=pii(1,e.w),
M.Num[e.v][e.v]+=pii(1,e.w);
if(CntSet!=1) return 0;//无法生成树
return M.det(n-1).se;
}
vector<Edge> e[Lim+1];
ll fc[Lim+1],phi[Lim+1],prime[Lim+1],cntprime;
inline void sieve(){
//欧拉筛积性函数
phi[1]=1;
for(int i=2;i<=Lim;i++){
if(!fc[i]) phi[i]=(fc[i]=prime[++cntprime]=i)-1;
for(int j=1;j<=cntprime;j++)
if(i*prime[j]>Lim) break;
else if(fc[i]>prime[j]) fc[i*prime[j]]=prime[j],phi[i*prime[j]]=phi[i]*(prime[j]-1);
else{
fc[i*prime[j]]=prime[j];
phi[i*prime[j]]=phi[i]*prime[j];
break;
}
}
}
inline void pre(){
int m;
cin>>n>>m;
for(int i=1,u,v,w;i<=m;i++){
cin>>u>>v>>w;
for(int d=1;d*d<=w;d++) if(w%d==0) {
e[d].push_back( Edge(u,v,w) );
if(w/d!=d) e[w/d].push_back( Edge(u,v,w) );
}
//给因数打上标记
}
}
int main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
pre();
sieve();
ll ans=0;
for(int i=1;i<=Lim;i++) ans+=phi[i]*calc(e[i])%MOD;
ans%=MOD;
cout<<ans;
cout.flush();
return 0;
}