题解 ABC248G【GCD cost on the tree】

posted on 2022-10-13 23:09:06 | under 题解 | source
updated on 2022-11-22 修正了方法二的代码实现。

problem

\(n\) 个点的树有点权 \(a\),求

\[\sum_{\text{路径 P 在树上}}{\gcd(a_{P_1},a_{P_2},\cdots,a_{P_{|P|}})\cdot |P|}. \]

(原题要求路径 \(P\) 的长度至少为 \(2\),请注意)

solution 1

考虑枚举 \(\operatorname{lca}=k\) 怎么做。

\(k\) 的子树中,爆搜所有路径,存到一个 std::map 里。

考虑对于子树 \(k\) 有多少种不同的 \(\gcd\)?由于这些路径中含有 \(a_k\),意味着 \(\gcd|a_k\),那么一共应该有 \(O(\sigma(a_k))\)\(\gcd\)\(10^5\) 中最大的 \(\sigma\) 大概是 \(90\) 左右(\(2\times 3\times 5\times 7\times 11\times 13\times 17=510510\))。

于是有这么一种想法:维护一个子树的路径的前缀和,每次对于一个子树算出的 std::map 中的值和这个前缀和卷起来(\(O(90^2)\)),然后扔进去前缀和(\(O(90)\))。

点分治会凭空多出一只 \(\log\),这道题直接一遍 dfs,每次像上面这样暴力算一下即可。复杂度 \(O(90^2n\log n)\),包含 \(\gcd\)\(\log\)std::map\(\log\log\)

code

没有讲完。这样写的细节很多。

考虑怎么把两个 std::map 卷起来?你需要记录,当 \(\gcd=d\) 时,有多少条路径,它们的路径总长度是多少。然后使用乘法原理。

在继承到父亲时,还需要注意路径总长度应该加上路径条数。

#include <map>
#include <cstdio>
#include <cstring>
#include <utility>
#include <algorithm>
using namespace std;
typedef long long LL;
const int P=998244353;
typedef pair<int,int> node;
int gcd(int x,int y){return !y?x:gcd(y,x%y);}
template<int N,int M,class T=int> struct graph{
    int head[N+10],nxt[M*2+10],cnt;
    struct edge{
        int u,v;T w;
        edge(int u=0,int v=0,T w=0):u(u),v(v),w(w){}
    } e[M*2+10];
    graph(){memset(head,cnt=0,sizeof head);}
    edge operator[](int i){return e[i];}
    void add(int u,int v,T w=0){e[++cnt]=edge(u,v,w),nxt[cnt]=head[u],head[u]=cnt;}
    void link(int u,int v,T w=0){add(u,v,w),add(v,u,w);}
};
int n,a[100010],ans;
graph<100010,100010> g;
node o(node a,node b){return node((a.first+b.first)%P,(a.second+b.second)%P);}
LL calc(map<int,node> &a,map<int,node> &b){
	LL res=0; 
//	printf("a:");for(pair<int,node> x: a) printf("(%d,{%d,%d}),",x.first,x.second.first,x.second.second);puts("");
//	printf("b:");for(pair<int,node> x: b) printf("(%d,{%d,%d}),",x.first,x.second.first,x.second.second);puts("");
	for(pair<int,node> x: a){
		for(pair<int,node> y: b){
			res=(res+1ll*gcd(x.first,y.first)*(1ll*x.second.second*y.second.first%P+1ll*x.second.first*y.second.second%P+1ll*x.second.second*y.second.second%P)%P)%P;
		}
	}
	return res;
}
void merge(map<int,node> &a,map<int,node> &b){for(pair<int,node> y: b) a[y.first]=o(a[y.first],y.second);}
map<int,node> add(map<int,node> a,int w){
	map<int,node> r;
	for(pair<int,node> x: a){
		r[gcd(x.first,w)]=o(r[gcd(x.first,w)],node(x.second.first+x.second.second,x.second.second));
	}
	return r;
}
map<int,node> dfs(int u,int fa=0){
	map<int,node> tot={{a[u],{0,1}}};
	for(int i=g.head[u];i;i=g.nxt[i]){
		int v=g[i].v; if(v==fa) continue;
		map<int,node> res=add(dfs(v,u),a[u]);
//		printf(">u=%d,v=%d\n",u,v);
		ans=(ans+calc(tot,res))%P;
//		printf("%lld\n",ans);
		merge(tot,res);
	}
	return tot;
}
int main(){
//	#ifdef LOCAL
//	 	freopen("input.in","r",stdin);
//	#endif
	scanf("%d",&n);
	for(int i=1;i<=n;i++) scanf("%d",&a[i]);
	for(int i=1,u,v;i<n;i++) scanf("%d%d",&u,&v),g.link(u,v);
	dfs(1,0);
	printf("%d\n",ans);
	return 0;
}

solution 2

枚举 \(\gcd=d\)

考虑求出 \(g_i\) 表示路径 \(\gcd\)\(i\) 的倍数的路径长度和。

考虑经过一条边,若要保证 \(\gcd\)\(i\) 的倍数,必有 \(i|\gcd(a_u,a_v)\)

对每个因数建一棵树,在因数 \(d\) 的树中放入 \(d|\gcd(a_u,a_v)\) 的所有边,对每个因数的树单独 dfs,求出其路径长度和。

现在有 \(g_i\),考虑容斥:

\[f_i=g_i-\sum_{i|j,i\neq j}f_j. \]

倒序算出 \(f_i\) 即为 \(\gcd\) 恰好为 \(i\) 的路径长度和。

复杂度:时间 \(O(nd+n\log n)\),空间 \(O(nd)\),其中 \(d=\max_{i\leq 10^5}\sigma(i)=128\)。枚举因子部分可以预处理,或者根号也行。

code

#include <map>
#include <vector>
#include <cstdio>
#include <cstring>
#include <utility>
#include <cassert>
#include <algorithm>
using namespace std;
#ifdef LOCAL
#define debug(...) fprintf(stderr,##__VA_ARGS__)
#else
#define debug(...) void(0)
#endif
typedef long long LL;
const int P=998244353;
template<int N> struct siever{
	vector<int> d[N+10];
	siever(){
		for(int i=1;i<=N;i++){
			for(int j=i;j<=N;j+=i) d[j].push_back(i);
		}
	}
};
template<int N,int M,class T=int> struct graph{
    int head[N+10],nxt[M*2+10],cnt;
    int vis[N+10],tag; 
    struct edge{
        int u,v;T w;
        edge(int u=0,int v=0,T w=0):u(u),v(v),w(w){}
    } e[M*2+10];
    graph(){memset(head,cnt=0,sizeof head),memset(vis,tag=0,sizeof vis);}
    edge&operator[](int i){return e[i];}
    void add(int u,int v,T w=0){
		if(vis[u]!=tag) vis[u]=tag,head[u]=0;
		e[++cnt]=edge(u,v,w),nxt[cnt]=head[u],head[u]=cnt;
	}
    void link(int u,int v,T w=0){add(u,v,w),add(v,u,w);}
};
int gcd(int x,int y){return !y?x:gcd(y,x%y);}
int n,a[100010];
LL ans,cnt,f[100010];
int vis[100010];
graph<100010,100010> t,g;
vector<int> e[100010];
siever<100010> s; 
LL calc(pair<LL,LL> a,pair<LL,LL> b){
	return ((a.first*b.second+b.first*a.second-a.second*b.second%P+P))%P;
}
pair<LL,LL> dfs(int u,int fa,LL &ans){
	vis[u]=g.tag;
	pair<LL,LL> tot={1,1};
	for(int i=g.head[u];i;i=g.nxt[i]){
		int v=g[i].v; if(v==fa) continue;
		pair<LL,LL> res=dfs(v,u,ans);
		res={(res.first+res.second)%P,res.second};
		ans=(ans+calc(tot,res))%P;
		tot={(tot.first+res.first)%P,(tot.second+res.second)%P};
	}
	return tot;
}
void dp(int d){
	g.tag++,g.cnt=0;
	for(int i:e[d]) g.link(t[i].u,t[i].v),debug("link (%d,%d)\n",t[i].u,t[i].v);
	for(int i:e[d]) if(vis[t[i].u]!=g.tag) dfs(t[i].u,0,f[d]);
	for(int i:e[d]) if(vis[t[i].v]!=g.tag) dfs(t[i].v,0,f[d]);
	if(f[d]) debug("f[%d]=%lld\n",d,f[d]);
}
int main(){
//	#ifdef LOCAL
//	 	freopen("input.in","r",stdin);
//	#endif
	scanf("%d",&n);
	for(int i=1;i<=n;i++) scanf("%d",&a[i]);
	for(int i=1,u,v;i<n;i++){
		scanf("%d%d",&u,&v);
		t.add(u,v);
		for(int d:s.d[gcd(a[u],a[v])]) e[d].push_back(i),debug("(%d,%d) belongs to tree %d\n",u,v,d); 
	}
	for(int i=1;i<=1e5;i++) dp(i);
	for(int i=1e5;i>=1;i--){
		for(int j=i+i;j<=1e5;j+=i) f[i]=(f[i]-f[j]+P)%P;
		ans=(ans+f[i]*i)%P;
	}
	printf("%lld\n",ans);
	return 0;
}

posted @ 2022-11-22 16:23  caijianhong  阅读(48)  评论(0编辑  收藏  举报