[NOI2020省选]作业题

题目

点这里看题目。

分析

这道题其实是两道题目。

首先可以娴熟地变换一下柿子:

\[\begin{aligned} \sum_T val(T) &=\sum_T \left(\sum_{e\in T} w_e\times \gcd_{e\in T}w_e\right)\\ &=\sum_{d=1} d\times \left(\sum_T \sum_{e\in T}w_e\left[\gcd_{e\in T}w_e=d\right]\right)\\ \end{aligned} \]

尝试定义:

\[\begin{aligned} f(d)&=\sum_T \sum_{e\in T}w_e\left[\gcd_{e\in T}w_e=d\right]\\ F(n)&=\sum_{n|d} f(d)=\sum_{T} \sum_{e\in T}w_e\left[n\ |\gcd_{e\in T}w_e\right] \end{aligned} \]

然后发现\(F(n)\)相对而言好求一些。知道了\(F\),就可以用莫比乌斯反演或者逆推方法计算出\(f\),就解决了问题。

发现艾弗森括号内的限制可以通过限制边来解决。对于\(F(n)\),我们只用\(n|w_e\)的边\(e\)来计算,即可。

于是考虑一个图中,求出:

\[\sum_T \sum_{e\in T}w_e \]

好像是一个老 trick 了,可是我从没听说过

考虑把边权写成一次函数的形式:\(w\Rightarrow 1+wx\)

然后发现:\((1+w_1x)(1+w_2x)=1+(w_1+w_2)x+w_1w_2x^2\)

忽略二次项,它们的一次项系数加起来了!

然后就可以使用一次函数代替原来的,并在\(\bmod{x^2}\)的意义下计算,利用“乘法”和矩阵树定理计算边权和。

定义这个新的类型的基本形式为\(a+bx\),我们还可以定义“四则运算”:

加法:\((a+bx)+(c+dx)=(a+c)+(b+d)x\)

减法:\((a+bx)-(c+dx)=(a-c)+(b-d)x\)

乘法:\((a+bx)(c+dx)=ac+(bc+ad)x\)

除法:我们只需要计算逆元:\((a+bx)^{-1}=(a^{-1}-ba^{-2}x)\),推导过程不难;

然后就可以枚举所有因子计算了。这里有一个剪枝,即当当前的图不连通的时候,我们直接跳掉。时间复杂度大概是\(O(n^4d)\),其中\(d≈144\)

代码

#include <cstdio>
#include <vector>
#include <iostream>
using namespace std;

const int mod = 998244353;
const int MAXN = 35, MAXM = MAXN * MAXN, MAXV = 2e5 + 5;

template<typename _T>
void read( _T &x )
{
	x = 0;char s = getchar();int f = 1;
	while( s > '9' || s < '0' ){if( s == '-' ) f = -1; s = getchar();}
	while( s >= '0' && s <= '9' ){x = ( x << 3 ) + ( x << 1 ) + ( s - '0' ), s = getchar();}
	x *= f;
}

template<typename _T>
void write( _T x )
{
	if( x < 0 ){ putchar( '-' ); x = ( ~ x ) + 1; }
	if( 9 < x ){ write( x / 10 ); }
	putchar( x % 10 + '0' );
}

template<typename _T>
_T MAX( const _T a, const _T b )
{
	return a > b ? a : b;
}

int inv( const int );
int qkpow( int, int );

struct linear
{
	int k, b;
	//f(x)=ax+b
	linear() { k = 0, b = 1; }
	linear( const int B ) { k = 0, b = B; }
	linear( const int K, const int B ) { k = K, b = B; }
	
	linear getInv() const
	{
		int tmp = inv( b );
		return linear( ( mod - 1ll * tmp * tmp % mod * k % mod + mod ) % mod, tmp );
	}
	
	linear operator / ( const linear &g ) const { return g * getInv(); }
	linear operator + ( const linear &g ) const { return linear( ( k + g.k ) % mod, ( b + g.b ) % mod ); }
	linear operator - ( const linear &g ) const { return linear( ( k - g.k + mod ) % mod, ( b - g.b + mod ) % mod ); }
	linear operator * ( const linear &g ) const { return linear( ( 1ll * k * g.b % mod + 1ll * b * g.k % mod ) % mod, 1ll * b * g.b % mod ); }

	void operator += ( const linear &g ) { *this = *this + g; }
	void operator -= ( const linear &g ) { *this = *this - g; }
	void operator *= ( const linear &g ) { *this = *this * g; }
	void operator /= ( const linear &g ) { *this = *this / g; }
	void operator ++ () { b = ( b + 1 >= mod ? b + 1 - mod : b + 1 ); }
	
	operator bool() const { return k || b; }
};

vector<int> vec[MAXV];

linear D[MAXN][MAXN], G[MAXN][MAXN], K[MAXN][MAXN];
int fr[MAXM], to[MAXM], w[MAXM];
int F[MAXV];
int fa[MAXN];
int N, M;

int qkpow( int base, int indx )
{
	int ret = 1;
	while( indx )
	{
		if( indx & 1 ) ret = 1ll * ret * base % mod;
		base = 1ll * base * base % mod, indx >>= 1;
	}
	return ret;
}

int inv( const int a ) { return qkpow( a, mod - 2 ); }

int det( linear T[][MAXN], const int n )
{
	linear ans, tmp, inver; int indx;
	for( int i = 1 ; i <= n ; i ++ )
	{
		indx = -1;
		for( int j = i ; j <= n ; j ++ )
			if( T[j][i] )
			{ indx = j; break; }
		if( indx == -1 ) return 0;
		if( indx ^ i ) ans.k *= -1;
		swap( T[i], T[indx] ), inver = T[i][i].getInv();
		for( int j = i + 1 ; j <= n ; j ++ )
			if( T[j][i] )
			{
				tmp = T[j][i] * inver;
				for( int k = i ; k <= n ; k ++ )
					T[j][k] -= T[i][k] * tmp;
			}
		ans *= T[i][i];
	}
	return ans.k;
}

void makeSet() { for( int i = 1 ; i <= N ; i ++ ) fa[i] = i; }
int findSet( const int u ) { return fa[u] = ( fa[u] == u ? u : findSet( fa[u] ) ); }
void unionSet( const int u, const int v ) { fa[findSet( u )] = findSet( v ); }

int main()
{
	int mxv = -1;
	read( N ), read( M );
	for( int i = 1 ; i <= M ; i ++ ) read( fr[i] ), read( to[i] ), read( w[i] );
	for( int i = 1 ; i <= M ; i ++ )
		for( int x = 1 ; x * x <= w[i] ; x ++ )
			if( ! ( w[i] % x ) )
			{
				vec[x].push_back( i );
				if( x * x != w[i] ) vec[w[i] / x].push_back( i );
			}
	for( int i = 1 ; i <= M ; i ++ ) mxv = MAX( mxv, w[i] );
	for( int i = 1, cnt ; i <= mxv ; i ++ )
	{
//		makeSet();
//		for( int j = 0 ; j < vec[i].size() ; j ++ )
//			unionSet( fr[vec[i][j]], to[vec[i][j]] );
//		cnt = 0;
//		for( int j = 1 ; j <= N ; j ++ )
//			if( fa[j] == j ) cnt ++;
//		if( cnt > 1 ) continue;
		if( vec[i].size() < N - 1 ) continue;
		for( int i = 1 ; i <= N ; i ++ )
			for( int j = 1 ; j <= N ; j ++ )
				D[i][j] = G[i][j] = K[i][j] = 0;
		for( int j = 0 ; j < vec[i].size() ; j ++ )
			G[fr[vec[i][j]]][to[vec[i][j]]] += linear( w[vec[i][j]], 1 ),
			G[to[vec[i][j]]][fr[vec[i][j]]] += linear( w[vec[i][j]], 1 );
		for( int i = 1 ; i <= N ; i ++ )
			for( int j = 1 ; j <= N ; j ++ )
				D[i][i] += G[i][j];
		for( int i = 1 ; i <= N ; i ++ )
			for( int j = 1 ; j <= N ; j ++ )
				K[i][j] = D[i][j] - G[i][j];
		F[i] = det( K, N - 1 );
	}
	for( int i = mxv ; i ; i -- )
		for( int j = i << 1 ; j <= mxv ; j += i )
			F[i] = ( F[i] - F[j] + mod ) % mod;
	int ans = 0;
	for( int i = 1 ; i <= mxv ; i ++ ) ans = ( ans + 1ll * F[i] * i % mod ) % mod;
	write( ans ), putchar( '\n' );
	return 0;
}
posted @ 2020-07-04 12:42  crashed  阅读(239)  评论(1编辑  收藏  举报