[题解] CF715C Digit Tree

[题解] CF715C Digit Tree

点分治的一道好题。

难点在于去重，也就是 \(calc\) 函数：以重心为根统计时，同一棵子树内两个点之间的路径并不真正经过重心，却也会被计入；因此对每棵子树，再以子树根为起点、带上它连向重心那条边上的数字统计一次，并从答案中减去。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <map>
#include <utility>
using namespace std;
// Capacity for every vertex-indexed array: 1e5 plus a little slack.
const int maxn = 100000 + 100;
// Chain-forward-star adjacency list heads: head[u] is the index in e[] of the most
// recently added edge leaving u (0 means "no edge"); cnt counts stored edges.
int head[maxn];
int cnt = 0;
#define LL long long
// One directed half of an undirected tree edge, stored in chain-forward-star form.
struct edge {
	int to;   // vertex this half-edge points to
	int nxt;  // index of the next edge in the same source vertex's list (0 ends the list)
	int w;    // digit written on the edge
} e[maxn << 1];  // two half-edges are stored per undirected edge
// Add the half-edge u -> v carrying digit w, pushing it to the front of u's list.
inline void link(int u,int v,int w){
	const int id = ++cnt;
	e[id].to = v;
	e[id].w = w;
	e[id].nxt = head[u];
	head[u] = id;
}
// power[i] = 10^i mod m (filled once in main).
LL power[maxn];
// vis[u]: u has already served as a centroid, so traversals must not cross it.
bool vis[maxn];
// Centroid-search state used by find_rt:
//   mx[u] - size of the largest piece left if u were removed
//   sz[u] - size of u's subtree in the current find_rt traversal
//   rt    - best centroid candidate found so far
//   S     - size of the component being searched, as set by the caller
LL mx[maxn];
LL sz[maxn];
LL rt;
LL S;
// Extended Euclid: on return, a*x + b*y == gcd(a, b).
// Iterative form of the classic recursion; it applies the same quotient sequence
// and therefore produces exactly the same coefficients.
void exgcd(LL a,LL b,LL &x,LL &y){
	// Invariant: current a == x0*A + y0*B and current b == x1*A + y1*B,
	// where A, B are the original arguments.
	LL x0 = 1, x1 = 0;
	LL y0 = 0, y1 = 1;
	while (b != 0) {
		const LL q = a / b;
		LL t = a - q * b;   // same as a % b (C++ division truncates)
		a = b;
		b = t;
		t = x0 - q * x1;
		x0 = x1;
		x1 = t;
		t = y0 - q * y1;
		y0 = y1;
		y1 = t;
	}
	x = x0;
	y = y0;
}
// Modular inverse of a modulo P via exgcd, normalised into [0, P).
// NOTE(review): only meaningful when gcd(a, P) == 1; this is not checked here.
LL inv(LL a,LL P){
	LL coef = 0;
	LL other = 0;
	exgcd(a, P, coef, other);
	LL r = coef % P;
	if (r < 0) r += P;
	return r;
}
void find_rt(int u,int fa){
	sz[u]=1;mx[u]=0;
	for(int i=head[u],v;i;i=e[i].nxt){
		if(vis[v=e[i].to] || v==fa)continue;
		find_rt(v,u);
		sz[u]+=sz[v];
		mx[u]=max(mx[u],sz[v]);
	}
	mx[u]=max(mx[u],S-sz[u]);
	if(mx[rt]>mx[u])rt=u;
}
// Scratch state for calc() (cleared at the start of every calc call):
//   mp    - count of each d1 value (digits read from a vertex towards the walk's origin)
//   digit - per recorded vertex: (d2 value read from the origin downwards, depth index de)
//   num   - number of entries used in digit[] (1-based)
map<LL, LL> mp;
pair<LL, LL> digit[maxn];
LL num = 0;
// Input: n = number of vertices (n - 1 edges are read), m = the modulus.
LL n, m;
void dfs(int u,int fa,LL d1,LL d2,LL de){//正着 VS 反着
	if(de>=0)mp[d1]++,digit[++num]=make_pair(d2,de);
	for(int i=head[u],v;i;i=e[i].nxt){
		if(vis[v=e[i].to] || v==fa)continue;
		int w=e[i].w;
		int d3=(d1+w*power[de+1])%m;
		int d4=(d2*10LL+w)%m;
		dfs(v,u,d3,d4,de+1);
	}
}
// Count ordered pairs (x, y) of vertices recorded by a walk from u such that the number
// formed by the digits on x -> origin -> y is divisible by m.
//   d == 0 : u is the centroid. u itself is not recorded (de starts at -1); pairs with
//            u as an endpoint are added separately below.
//   d != 0 : u is a neighbour of the centroid and d is the digit on the connecting edge.
//            The walk is seeded as if it started at the centroid, which counts the
//            same-subtree pairs that the d == 0 call over-counted (solve subtracts it).
// A pair joins x's upward value with y's downward value:
//   d1(x) * 10^(de_y + 1) + d2(y) == 0 (mod m)  <=>  d1(x) == -d2(y) * 10^-(de_y + 1).
// NOTE(review): this requires 10 to be invertible mod m (gcd(10, m) == 1) -- confirm.
// NOTE(review): d == 0 doubles as the "centroid" flag, so an edge digit of 0 would be
// mis-handled; this assumes edge digits are non-zero -- confirm against the problem.
LL calc(LL u,LL d){
	mp.clear();
	num=0;
	LL res=0;
	if(!d)dfs(u,0,0,0,-1);
	else dfs(u,0,d%m,d%m,0);
	for(int i=1;i<=num;i++){
		const LL down=digit[i].first;
		const LL len=digit[i].second+1;  // number of digits on the origin -> y path
		// Residue that x's upward value must have: -down * 10^(-len) mod m.
		const LL need=(m-down*inv(power[len],m)%m)%m;
		// Single lookup, and no default-inserted keys on a miss.
		const map<LL,LL>::iterator it=mp.find(need);
		if(it!=mp.end())res+=it->second;
		if(!d && !down)res++;  // pair (centroid, y)
	}
	if(!d){
		// Pairs (x, centroid): x's upward value must itself be 0 mod m.
		const map<LL,LL>::iterator it=mp.find(0);
		if(it!=mp.end())res+=it->second;
	}
	return res;
}
// Running answer accumulated by solve(); kase is a debug counter not used in the computation.
LL ans = 0;
LL kase = 0;
void solve(int u){
	vis[u]=true;
	//printf("%lld rt: %lld\n",++kase,rt);
	ans+=calc(u,0);
	for(int i=head[u],v;i;i=e[i].nxt){
		if(vis[v=e[i].to])continue;
		ans-=calc(v,e[i].w);
		S=sz[v],mx[rt=0]=n;
		find_rt(v,0);
		solve(rt);
	}
}
int main(){
	scanf("%lld%lld",&n,&m);
	for(int i=1,u,v,w;i<n;i++){
		scanf("%d%d%d",&u,&v,&w);
		link(u+1,v+1,w);link(v+1,u+1,w);
	}
	power[0]=1;
	for(int i=1;i<=n;i++)power[i]=power[i-1]*10%m;
	//for(int i=1;i<=n;i++)printf("power[%d]: %lld\n",i,power[i]);
	mx[rt]=S=n;
	find_rt(1,0);solve(rt);
	printf("%lld\n",ans);
	return 0;
}
posted @ 2021-08-12 17:17  ¶凉笙  阅读(29)  评论(0)  编辑  收藏  举报