ZR#959

ZR#959

解法:

对于一个询问,设路径 $ (u, v) $ 经过的所有边的 $ gcd $ 为 $ g $,这可以倍增求出。
考虑 $ g $ 的所有质因子 $ p_1, p_2, \cdots , p_k $ ,因为 $ g \leq 10^6 $ ,所以 $ k \leq 7 $ 。
则最终的路径的 $ gcd $ 为 $ 1 $,等价于对于每个 $ 1 \leq i \leq k $ ,存在至少一条路径上的边不是 $ p_i $ 的倍数。我们要求 $ l $ 的最小值,即等价于对于每个 $ 1 \leq i \leq k $ ,计算出最长的不满足条件的 $ l′ $,则最终答案即为所有 $ i $ 对应的 $ l′ $ 的最大值加一(无解的情况除外)。
考虑对于某个 $ p_i $ 而言,我们如何求出这样的 $ l′ $ 。我们考虑将所有满足 $ p_i | w $的边拿出来,并只保留这些边。则 $ l′ $ 等价于在这样得到的森林中,经过 $ (u, v) $ 的最长路径。
使用简单的树形DP即可求出某个点向子树方向以及向祖先方向延伸的最长路径,分类讨论即可对于每个 $ (u, v) $ 求出对应的 $ l′ $ 。
接下来考虑无解的情况。事实上,无解等价于刚刚求出的某个 $ l′ $ 和经过 $ (u, v) $ 的最长路径相同。经过 $ (u, v) $ 的最长路径和刚刚是同样的问题,直接对整棵树都做一遍树形 DP 即可。
接下来考虑复杂度。求出每个询问的 $ gcd $ 的复杂度为 $ O(q \log_2 n \log_2 w) $ ,而求出最长路的部分是与边数成线性的,而每条边至多出现 $ 7 $ 次,因此该做法的总复杂度即为 $ O(q \log_2 n \log_2 w + nω(w)) $ 。

CODE:

//9.10补全
#pragma GCC optimize(2)
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>

using namespace std;

#define LL long long

const int N = 2e5 + 100;
const int V = 2e6 + 100;
const int M = 18;

struct Edge {
    int to,from;
    int data;
}e[N<<1];

int head[N],cnt,tot,n,q;
int f[M + 1][N],g[M + 1][N];
int prime[V],d[V],deep[N];
bool p[V];

inline void prework() {
    d[1] = 1;
    for(int i = 2 ; i <= V - 1 ; i++) {
        if(!p[i]) prime[++tot] = i,d[i] = i;
        for(int j = 1 ; j <= tot && i * prime[j] < V ; ++j) {
            p[i * prime[j]] = true;
            d[i * prime[j]] = prime[j];
            if(!(i % prime[j])) break;
        }
    }
}
inline void add_edge(int x,int y,int z) {
    e[++cnt].from = y;
    e[cnt].data = z;
    e[cnt].to = head[x];
    head[x] = cnt;
} 
inline int gcd(int a,int b) {
    return !b ? a :gcd(b,a % b);
}
inline void dfs(int v,int fa) {
    for(int i = 1 ; i <= M ; i++) {
        f[i][v] = f[i - 1][f[i - 1][v]];
        g[i][v] = gcd(g[i - 1][v],g[i - 1][f[i - 1][v]]);
    }
    for(int i = head[v] ; i ; i = e[i].to) {
        if(e[i].from == fa) continue;
        f[0][e[i].from] = v;
        g[0][e[i].from] = e[i].data;
        deep[e[i].from] = deep[v] + 1;
        dfs(e[i].from,v);
    }
}
inline int LCA(int x,int y) {
    if(deep[x] != deep[y]) {
        if(deep[x] < deep[y]) swap(x,y);
        int dis = deep[x] - deep[y];
        for(int i = 0 ; i <= M ; i++) {
            if((1 << i) & dis) x = f[i][x];
        }
    }
    if(x == y) return x;
    for(int i = M ; i >= 0 ; i--) {
        if(f[i][x] == f[i][y]) continue;
        x = f[i][x],y = f[i][y];
    }
    return f[0][x];
}
inline int cal(int x,int y) {
    if(deep[x] < deep[y]) swap(x,y);
    int d = deep[x] - deep[y],ans = 0;
    for(int i = 0 ; i <= M ; i++) {
        if((1<<i)&d) {
            ans = gcd(g[i][x],ans);
            x = f[i][x];
        }
    }
    return ans;
}
inline int calc(int x,int y) {
    int l = LCA(x,y);
    return gcd(cal(l,x),cal(l,y));
}
inline int dp(int v,int fa,int d) {
    int ans = 0;
    for(int i = head[v] ; i ; i = e[i].to) {
        if(e[i].from == fa) continue;
        if(e[i].data % d) continue;
        ans = max(ans,dp(e[i].from,v,d)+1);
    }
    return ans;
}
inline int kth(int v,int k) {
    for(int i = 0 ; i <= M ; i++) {
        if((1 << i) & k) v = f[i][v];
    }
    return v;
}
inline int dis(int x,int y) {
    return deep[x] + deep[y] - deep[LCA(x,y)] * 2;
}
inline int read() {
    int x = 0, f = 1; 
    char ch = getchar(); 
    while(ch < '0' || ch > '9') {if (ch == '-')f = -1; ch = getchar();}
    while(ch >= '0' && ch <= '9') {x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar();}
    return x * f; 
}

int main() {
    prework();
    n = read(),q = read();
    for(int i = 1 ; i < n ; i++) {
        int u = read(),v = read(),z = read();
        add_edge(u,v,z);
        add_edge(v,u,z);
    }
    dfs(1,0);
    while(q--) {
        int x = read(),y = read();
        int l = LCA(x,y),g = calc(x,y);
        if(g == 1) {
            printf("%d \n",dis(x,y));
            continue;
        }
        int fx = f[0][x],fy = f[0][y];
        if(l == x) fx = kth(y,deep[y] - deep[x] - 1);
        else if(l == y) fy = kth(x,deep[x] - deep[y] - 1); // 限制不走相同的子树
        int mxdis = dp(x,fx,1) + dp(y,fy,1);
        int ans = 0;
        while(g != 1) {
            int p = d[g];
            ans = max(ans,dp(x,fx,p) + dp(y,fy,p));
            if(ans == mxdis) break;
            while(!(g % p)) g /= p;
        }
        printf("%d\n",ans == mxdis ? -1 : ans + dis(x,y) + 1);
    }
    //system("pause");
    return 0;
}
posted @ 2019-09-09 21:26  西窗夜雨  阅读(198)  评论(0编辑  收藏  举报