codeforces716E/715C 树上点分治
https://codeforces.com/contest/715/problem/C
点分治——一道细节很多的点分治题。
我用欧拉函数 + 欧拉定理（费马小定理在合数模下的推广）求模 M 意义下的乘法逆元：当 $\gcd(a, M) = 1$ 时，$a^{-1} \equiv a^{\varphi(M)-1} \pmod{M}$。
后来发现题目保证 M 与 10 互质，其实不需要这么麻烦：直接用扩展欧几里得算法求出 10 的逆元即可。
这题统计的信息依然满足可减性，所以每层分治只需容斥一次：先把整个连通块内的点对都当作经过重心来统计，再减去每棵子树内部按同样方式统计出的点对。
// Codeforces 715C / 716E "Digit Tree" -- centroid decomposition.
//
// Counts ordered pairs (u, v), u != v, whose path u -> v spells a decimal number
// divisible by M. The problem guarantees gcd(M, 10) == 1, so 10 is invertible mod M.
//
// For a centroid c let up(u) be the number read along u -> c, down(v) the number read
// along c -> v, and depth(v) the number of edges from c to v. The path u -> c -> v reads
//     up(u) * 10^depth(v) + down(v),
// which is 0 (mod M)  <=>  up(u) == -down(v) * 10^(-depth(v))  (mod M).
// Pairs that lie inside a single child subtree of c are counted with the same formula
// and then subtracted (inclusion-exclusion), leaving only pairs whose real path goes
// through c.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

namespace {

// Returns x in [0, mod) with a * x == 1 (mod mod), via the iterative extended
// Euclidean algorithm. Requires gcd(a, mod) == 1; for mod == 1 it returns 0.
long long mod_inverse(long long a, long long mod) {
    long long r0 = mod, r1 = a % mod;
    long long t0 = 0, t1 = 1;  // invariant: t_i * a == r_i (mod mod)
    while (r1 != 0) {
        const long long q = r0 / r1;
        const long long r2 = r0 - q * r1;
        const long long t2 = t0 - q * t1;
        r0 = r1;
        r1 = r2;
        t0 = t1;
        t1 = t2;
    }
    t0 %= mod;
    return t0 < 0 ? t0 + mod : t0;
}

// Tree with digit-labelled edges plus the centroid-decomposition counter.
// Vertices are 0-indexed. All traversals are iterative (BFS), so path-shaped trees
// with ~1e5 vertices do not risk a stack overflow.
class DigitTree {
public:
    // n: number of vertices; m: the modulus M (must satisfy gcd(m, 10) == 1).
    DigitTree(int n, long long m)
        : n_(n), m_(m), adj_(n), pw_(n + 1), inv_pw_(n + 1), removed_(n, 0),
          parent_(n), parent_digit_(n), sz_(n), heavy_(n), depth_(n), up_(n), down_(n) {
        const long long inv10 = mod_inverse(10, m_);
        pw_[0] = inv_pw_[0] = 1 % m_;
        for (int i = 1; i <= n_; ++i) {
            pw_[i] = pw_[i - 1] * 10 % m_;              // 10^i mod M
            inv_pw_[i] = inv_pw_[i - 1] * inv10 % m_;   // 10^(-i) mod M
        }
    }

    // Adds the undirected edge a -- b labelled with `digit` (stored reduced mod M).
    void add_edge(int a, int b, long long digit) {
        const long long d = digit % m_;
        adj_[a].push_back({b, d});
        adj_[b].push_back({a, d});
    }

    // Returns the number of ordered pairs (u, v), u != v, whose path number is
    // divisible by M. Intended to be called once per instance.
    long long solve() {
        if (n_ == 0) return 0;
        long long answer = 0;
        std::vector<int> pending{0};  // one representative vertex per unprocessed component
        while (!pending.empty()) {
            const int root = pending.back();
            pending.pop_back();

            const int c = find_centroid(root);
            removed_[c] = 1;
            answer += count_pairs(c, 0, 0, 0);
            for (const Edge& e : adj_[c]) {
                if (removed_[e.to]) continue;
                // The subtree root sits at depth 1 with up == down == edge digit -- the
                // depth is passed explicitly, never inferred from the digit's value.
                answer -= count_pairs(e.to, e.digit, e.digit, 1);
                pending.push_back(e.to);
            }
        }
        return answer;
    }

private:
    struct Edge {
        int to;
        long long digit;  // already reduced mod M
    };

    // Fills order_ with the vertices reachable from `root` without entering removed
    // vertices, in BFS order (parents before children). Records each vertex's BFS
    // parent (-1 for root) and the digit on the edge to that parent.
    void gather(int root) {
        order_.clear();
        order_.push_back(root);
        parent_[root] = -1;
        for (std::size_t i = 0; i < order_.size(); ++i) {
            const int v = order_[i];
            for (const Edge& e : adj_[v]) {
                if (e.to == parent_[v] || removed_[e.to]) continue;
                parent_[e.to] = v;
                parent_digit_[e.to] = e.digit;
                order_.push_back(e.to);
            }
        }
    }

    // Returns a centroid of the component containing `root`: the vertex whose largest
    // remaining piece (biggest child subtree, or everything above it) is smallest.
    int find_centroid(int root) {
        gather(root);
        const int total = static_cast<int>(order_.size());
        for (const int v : order_) {
            sz_[v] = 1;
            heavy_[v] = 0;
        }
        // Reverse BFS order visits children before parents.
        for (int i = total - 1; i > 0; --i) {
            const int v = order_[i];
            const int p = parent_[v];
            sz_[p] += sz_[v];
            heavy_[p] = std::max(heavy_[p], sz_[v]);
        }
        int best = root;
        int best_part = total;
        for (const int v : order_) {
            const int part = std::max(heavy_[v], total - sz_[v]);
            if (part < best_part) {
                best_part = part;
                best = v;
            }
        }
        return best;
    }

    // Residue that up(u) must equal for the pair (u, v) to be divisible by M.
    long long need(int v) const {
        return (m_ - down_[v]) % m_ * inv_pw_[depth_[v]] % m_;
    }

    // Counts ordered pairs (u, v), u != v, in the component of `start`, treating every
    // path as if it went through the current centroid. `start` is given the supplied
    // up/down values and depth, and the values of its descendants are derived from
    // those.
    long long count_pairs(int start, long long up0, long long down0, int depth0) {
        gather(start);
        up_[start] = up0;
        down_[start] = down0;
        depth_[start] = depth0;

        keys_.clear();
        for (std::size_t i = 0; i < order_.size(); ++i) {
            const int v = order_[i];
            if (i > 0) {
                const int p = parent_[v];
                const long long dg = parent_digit_[v];
                depth_[v] = depth_[p] + 1;
                down_[v] = (down_[p] * 10 + dg) % m_;            // append digit at the end
                up_[v] = (dg * pw_[depth_[p]] + up_[p]) % m_;    // prepend digit at the front
            }
            keys_.push_back(need(v));
        }
        std::sort(keys_.begin(), keys_.end());

        long long pairs = 0;
        for (const int u : order_) {
            const auto range = std::equal_range(keys_.begin(), keys_.end(), up_[u]);
            pairs += range.second - range.first;
            if (up_[u] == need(u)) --pairs;  // drop the degenerate pair (u, u)
        }
        return pairs;
    }

    int n_;
    long long m_;
    std::vector<std::vector<Edge>> adj_;
    std::vector<long long> pw_;      // pw_[i] = 10^i mod M
    std::vector<long long> inv_pw_;  // inv_pw_[i] = 10^(-i) mod M
    std::vector<char> removed_;      // vertices already used as centroids
    std::vector<int> parent_;
    std::vector<long long> parent_digit_;
    std::vector<int> sz_;
    std::vector<int> heavy_;         // size of the largest child subtree
    std::vector<int> depth_;
    std::vector<long long> up_;      // number read v -> centroid, mod M
    std::vector<long long> down_;    // number read centroid -> v, mod M
    std::vector<int> order_;         // scratch: BFS order of the current component
    std::vector<long long> keys_;    // scratch: sorted need() values
};

}  // namespace

// Reads "n M" followed by n-1 lines "u v digit" (0-indexed vertices) and prints the
// number of ordered vertex pairs whose path number is divisible by M.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n = 0;
    long long m = 0;
    if (!(std::cin >> n >> m)) return 0;

    DigitTree tree(n, m);
    for (int i = 1; i < n; ++i) {
        int a = 0, b = 0;
        long long digit = 0;
        std::cin >> a >> b >> digit;
        tree.add_edge(a, b, digit);
    }
    std::cout << tree.solve() << '\n';
    return 0;
}