Codeforces 715C. Digit Tree
给定一棵有$n$个点的树,每条边上写有一个非零数字(边权$1 \sim 9$)
给定$m$,求有多少个有序点对$(u,v)$($u \neq v$),满足从$u$到$v$的路径上依次经过的边权拼成的十进制整数是$m$的倍数
保证$\gcd(10,m)=1$
$1 \le n \le 10^5$
$1 \le m \le 10^9$
一看就是点分治毒瘤题……
欧拉定理告诉我们,如果$\gcd(a,p)=1$,那么$a^{\varphi(p)} \equiv 1 \pmod p$,即$a^{-1} \equiv a^{\varphi(p)-1} \pmod p$。由于$\gcd(10,m)=1$,$10$在模$m$意义下的逆元可以这样求出($\varphi(m)$用试除法在$O(\sqrt m)$内求得)
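点分治时,设分治中心为$c$,$a$到$c$的路径拼成的数为$x$,$c$到$b$的路径(共$d$条边)拼成的数为$y$,则$a$到$b$拼成的数为$x \cdot 10^{d}+y$,它是$m$的倍数当且仅当$x \equiv -y \cdot 10^{-d} \pmod m$。于是用一个计数表记录$x$,枚举$b$查询即可。为了不统计同一子树内的点对,按子树顺序"先查询、后插入",再把子树顺序反过来做一遍,两个方向就都统计到了。然后就没啥细节了……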
// Codeforces 715C "Digit Tree": centroid decomposition.
//
// Counts ordered pairs (u, v), u != v, such that the decimal number spelled by
// the edge digits along the path u -> v is divisible by m, given gcd(10, m) = 1.
//
// For a centroid c and a path a -> b through c (a, b in different subtrees):
//   up(a)   = number read along a -> c,
//   down(b) = number read along c -> b, which has d(b) edges,
// so the path spells up(a) * 10^d(b) + down(b); divisibility by m means
//   up(a) == -down(b) * 10^(-d(b))  (mod m).
// 10^(-1) mod m exists because gcd(10, m) = 1 and equals 10^(phi(m) - 1) by Euler.
#include <algorithm>
#include <cstdio>
#include <map>
#include <vector>
using namespace std;

const int N = 1e5 + 10;
int n, m;

// sz[u]: subtree size in the latest traversal; f[u]: largest piece left if u
// were removed. f[0] is a sentinel (vertices are 1-indexed) so that the
// initial root = 0 loses every comparison in getsz. ban[u] marks vertices
// already used as centroids.
int sz[N], f[N] = { 0x3f3f3f3f }, root, sizeses, ban[N];
struct E { int to, w; };
vector<E> g[N];

// Computes sz[] for the unbanned component reachable from u, rooted at u.
void calcsz(int u, int fa) {
    sz[u] = 1;
    for (const auto& e : g[u]) {
        if (ban[e.to] || e.to == fa) continue;
        calcsz(e.to, u);
        sz[u] += sz[e.to];
    }
}

// Computes sz[]/f[] over the unbanned component containing u (whose total
// size must already be in `sizeses`) and leaves its centroid in `root`.
void getsz(int u, int fa) {
    sz[u] = 1, f[u] = 0;
    for (const auto& e : g[u]) {
        const int v = e.to;
        if (ban[v] || v == fa) continue;
        getsz(v, u);
        sz[u] += sz[v];
        f[u] = max(f[u], sz[v]);
    }
    f[u] = max(f[u], sizeses - sz[u]);
    if (f[u] < f[root]) root = u;
}

// up(a) values (mod m) -> multiplicity, for subtrees already inserted in the
// current pass around the current centroid.
map<int, int> cnt;

// a^b mod m by binary exponentiation; the final % m makes m = 1 yield 0.
int pw(int a, int b) {
    int r = 1;
    for ( ; b ; b /= 2, a = 1ll * a * a % m)
        if (b % 2 == 1) r = 1ll * r * a % m;
    return r % m;
}

// fac[i] = 10^i mod m; inv[i] = 10^(-i) mod m (i >= 1).
int inv[N], fac[N];

long long ans;

// Euler's totient of x by trial division, O(sqrt(x)).
int getphi(int x) {
    int r = x;
    for (int i = 2; 1ll * i * i <= x; ++i) {
        if (x % i == 0) r = r / i * (i - 1);
        while (x % i == 0) x /= i;
    }
    if (x > 1) r = r / x * (x - 1);
    return r;
}

// Walks the subtree entered at u (parent fa, `dep` edges below the centroid).
//   type == 0: len = down(u) = number read centroid -> u (mod m); adds the
//              number of recorded up(a) values that make a -> u divisible.
//   type == 1: len = up(u) = number read u -> centroid (mod m); records it.
// On pass 1 it also counts the pairs that have the centroid as an endpoint:
// centroid -> u during type 0 and u -> centroid during type 1.
void dfs(int u, int fa, int len, int type, int dep, int pass) {
    if (pass == 1) ans += len == 0;

    if (type == 0) {
        // Need up(a) == -len * 10^(-dep) (mod m); len is already in [0, m).
        const int need = static_cast<int>(1ll * ((m - len) % m) * inv[dep] % m);
        // find() rather than operator[]: a query must not insert zero entries.
        const auto it = cnt.find(need);
        if (it != cnt.end()) ans += it->second;
    } else {
        ++cnt[len];
    }

    for (const auto& e : g[u]) {
        const int v = e.to;
        if (ban[v] || v == fa) continue;
        if (type == 0)
            dfs(v, u, (1ll * len * 10 + e.w) % m, type, dep + 1, pass);
        else
            dfs(v, u, (1ll * len + 1ll * e.w * fac[dep]) % m, type, dep + 1, pass);
    }
}

// Counts every pair whose path goes through centroid u, then recurses into
// each component left after removing u.
void sol(int u) {
    ban[u] = 1;
    // Pass 1 visits u's subtrees in adjacency order, pass 2 in reverse order.
    // Querying a subtree before inserting it counts a -> b only for a in an
    // earlier subtree, so the two passes together cover both directions.
    // The second reverse restores g[u]'s original order.
    for (int pass = 1; pass <= 2; ++pass) {
        cnt.clear();
        for (const auto& e : g[u]) {
            if (ban[e.to]) continue;
            dfs(e.to, u, e.w % m, 0, 1, pass);
            dfs(e.to, u, e.w % m, 1, 1, pass);
        }
        reverse(g[u].begin(), g[u].end());
    }
    // The getsz call that picked u was rooted elsewhere, so sz[] of the
    // neighbour on the old root's side is stale (it covers u's whole previous
    // component). Recompute rooted at u so each sz[v] is v's true piece size.
    calcsz(u, 0);
    for (const auto& e : g[u]) {
        const int v = e.to;
        if (ban[v]) continue;
        root = 0, sizeses = sz[v], getsz(v, 0);
        sol(root);
    }
}

int main() {
    if (scanf("%d%d", &n, &m) != 2) return 0;

    const int p = getphi(m);
    const int inv10 = pw(10 % m, p - 1);  // 10^(-1) mod m, by Euler's theorem
    inv[0] = 1, fac[0] = 1;
    for (int i = 1; i <= n; ++i) {
        fac[i] = 1ll * fac[i - 1] * 10 % m;
        inv[i] = 1ll * inv[i - 1] * inv10 % m;
    }

    for (int i = 1; i < n; ++i) {
        int u, v, w;
        scanf("%d%d%d", &u, &v, &w);
        // Input vertices are 0-indexed; shift by one so index 0 is free to
        // act as the "no parent" / sentinel vertex.
        ++u, ++v, w %= m;
        g[u].push_back(E{v, w});
        g[v].push_back(E{u, w});
    }

    sizeses = n, root = 0, getsz(1, 0);
    sol(root);
    printf("%lld\n", ans);
}