[点分治] Codeforces 715C Digit Tree
题目大意
题目链接
大致题意:给定一棵 \(n(n\leq 10^5)\) 个点的边权为 \(1\sim 9\) 的树,询问有多少有序点对 \((u,v)\) 满足从 \(u\) 到 \(v\) 的路径上的权值组成的数被 \(m\) 整除(像字符串一样拼在一起)。
题解
点分治。
设当前分治到的根为 \(root\),其子树内 \(u\) 的深度为 \(deep[u]\),深度从 \(0\) 开始。设 \(a(u)\) 表示从 \(u\) 到 \(root\) 的路径构成的数字,设 \(b(u)\) 表示从 \(root\) 到 \(u\) 的路径上构成的数字。则对于过 \(root\) 的一条路径 \(u\sim v\),它构成的数字为 \(a(u)\times 10^{deep[v]}+b(v)\)。则 \(a(u)\times 10^{deep[v]}+b(v)\equiv 0\pmod m\)。两边同乘 \(10^{-deep[v]}\),得 \(a(u)+b(v)\times 10^{-deep[v]}\equiv 0 \pmod m\)。移项后得 \(b(v)\times 10^{-deep[v]}\equiv -a(u)\pmod m\)。所以只需要统计 \(b(v)\times 10^{-deep[v]}\) 同余 \(-a(u)\) 的数量即可。使用一个map保存每个 \(-a(u)\) 的数量,然后把所有 \(b(v)\times 10^{-deep[v]}\) 存入一个vector,遍历该vector累加map中对应的 \(-a(u)\) 的数量即可。 注意 \(m\) 不一定是质数,所以此处不能使用费马小定理求 \(10\) 的乘法逆元,而题目规定 \(\gcd(10,m)=1\),所以可以使用扩欧求 \(10\) 的乘法逆元。
Code
#include <bits/stdc++.h>
using namespace std;
#define RG register int
#define LL long long
template<typename elemType>
inline void Read(elemType &T){
elemType X=0,w=0; char ch=0;
while(!isdigit(ch)) {w|=ch=='-';ch=getchar();}
while(isdigit(ch)) X=(X<<3)+(X<<1)+(ch^48),ch=getchar();
T=(w?-X:X);
}
struct Graph{
struct edge{int Next,to,w;};
edge G[200010];
int head[100010];
int cnt;
Graph():cnt(2){}
void clear(int node_num=0){
cnt=2;
if(node_num==0) memset(head,0,sizeof(head));
else fill(head,head+node_num+5,0);
}
void add_edge(int u,int v,int w){
G[cnt].w=w;
G[cnt].to=v;
G[cnt].Next=head[u];
head[u]=cnt++;
}
};
Graph G;
const int maxn=100010;
int Size[maxn],DisA[maxn],DisB[maxn],Deep[maxn];
LL Inv[maxn],Ten[maxn];
vector<LL> B;
map<LL,LL> A;
bool vis[maxn];
int N,Root,CurSize,MaxSize;
LL InvTen,K,Ans;
inline LL mo(LL x){
if(x<0) return (x%K+K)%K;
if(x>=K) return x%K;
return x;
}
LL ex_GCD(LL a,LL b,LL &x,LL &y){
if(b==0){x=1;y=0;return a;}
LL Res=ex_GCD(b,a%b,x,y);
LL temp=x;x=y;y=temp-a/b*y;
return Res;
}
LL inv(LL a,LL p){LL x,y;ex_GCD(a,p,x,y);return x;}
void GetRoot(int u,int fa){
Size[u]=1;
int mx=0;
for(int i=G.head[u];i;i=G.G[i].Next){
int v=G.G[i].to;
if(v==fa || vis[v]) continue;
GetRoot(v,u);
Size[u]+=Size[v];
mx=max(mx,Size[v]);
}
mx=max(mx,CurSize-Size[u]);
if(mx<MaxSize){MaxSize=mx;Root=u;}
}
void GetDis(int u,int fa,LL a,LL b){
++A[mo(-a)];
B.push_back(mo(b*Inv[Deep[u]]));
Size[u]=1;
for(int i=G.head[u];i;i=G.G[i].Next){
int v=G.G[i].to;
if(v==fa || vis[v]) continue;
Deep[v]=Deep[u]+1;
GetDis(v,u,mo(a+Ten[Deep[u]]*G.G[i].w),mo(b*10LL+G.G[i].w));
Size[u]+=Size[v];
}
}
LL Calc(int u,int len){
A.clear();B.clear();
if(!len){Deep[u]=0;GetDis(u,0,0,0);}
else{Deep[u]=1;GetDis(u,0,len%K,len%K);}
LL Res=0;
for(auto x:B) Res+=A[x];
return Res;
}
void Divide(int u){
vis[u]=true;
Ans+=Calc(u,0);
for(int i=G.head[u];i;i=G.G[i].Next){
int v=G.G[i].to;
if(vis[v]) continue;
Ans-=Calc(v,G.G[i].w);
CurSize=MaxSize=Size[v];
Root=0;GetRoot(v,0);Divide(Root);
}
}
int main(){
Read(N);Read(K);
for(int i=1;i<=N-1;++i){
int u,v,w;
Read(u);Read(v);Read(w);
++u;++v;
G.add_edge(u,v,w);
G.add_edge(v,u,w);
}
if(K==1){
cout<<(LL)N*(LL)(N-1)<<endl;
return 0;
}
InvTen=inv(10,K);
Inv[0]=Ten[0]=1;
for(int i=1;i<=N;++i){
Inv[i]=mo(Inv[i-1]*InvTen);
Ten[i]=mo(Ten[i-1]*10LL);
}
CurSize=MaxSize=N;
Root=0;GetRoot(1,0);Divide(Root);
printf("%I64d\n",Ans-N);
return 0;
}