树上跳棋
题目链接
`戳我
\(Solution\)
对于一个点如果能够被跳到当且仅当这个点的深度\(mod\)一次跳的长度等于起始节点\(mod\)一次跳的长度
假设能够被\(p1,p2\)两个点都能到达的点为\(z\)需要满足以下条件
\[dep[z]<=dep[lca]
\]
\[dep[z]\equiv dep[p1]\ (mod \ d1)
\]
\[dep[z]\equiv dep[p2]\ (mod \ d2)
\]
下面那两个同余的限制可以用\(excrt\)(扩展中国剩余定理)算出最小的\(dep[z]\)
然后满足的\(dep[z]\)就是\(dep[z]+k*lcm(d1,d2)\quad k \in N^*\)
通过这个可以算出满足\(dep[z]<=dep[lca]\)最大的\(dep[z]\)即两个棋子的最近的重合节点
对于\(-1\)的情况即需要判断一下是否有解和是否存在\(dep[x]<=dep[lca]\)即可
\(Code\)
#include <bits/stdc++.h>
#define int long long
using namespace std;
typedef long long ll;
int read() {
int x = 0, f = 1;
char c = getchar();
while(c < '0' || c > '9')
f = (c == '-') ? -1 : 1, c = getchar();
while(c >= '0' && c <= '9')
x = x * 10 + c - 48, c = getchar();
return f * x;
}
struct node {
int to, next;
} a[400010];
int head[200010], cnt;
int f[200010][21], dep[200010];
void add(int x, int y) {
a[++cnt].next = head[x];
head[x] = cnt;
a[cnt].to = y;
}
void dfs(int x, int fa) {
f[x][0] = fa;
dep[x] = dep[fa] + 1;
for(int i = 1; i <= 20; i++)
f[x][i] = f[f[x][i - 1]][i - 1];
for(int i = head[x]; i; i = a[i].next) {
int v = a[i].to;
if(v == fa)
continue;
dfs(v, x);
}
}
int lca(int x, int y) {
if(dep[x] > dep[y])
swap(x, y);
for(int i = 20; i >= 0; i--)
if(dep[f[y][i]] >= dep[x])
y = f[y][i];
if(x == y)
return x;
for(int i = 20; i >= 0; i--)
if(f[x][i] != f[y][i])
x = f[x][i], y = f[y][i];
return f[x][0];
}
int jump(int x, int depth) {
for(int i = 20; i >= 0; i--)
if(dep[f[x][i]] >= depth)
x = f[x][i];
return x;
}
int exgcd(int a, int b, int& x, int& y) {
if(!b) {
x = 1, y = 0;
return a;
}
int gcd = exgcd(b, a % b, x, y);
int X = x, Y = y;
x = Y, y = X - a / b * Y;
return gcd;
}
int ksm(int a, int b, int p) {
int ans = 0;
while(b) {
if(b & 1)
ans = (ans + a) % p;
a = (a + a) % p;
b >>= 1;
}
return ans;
}
int m[10], A[10];
int find(int X, int Y, int a, int b) {
A[1] = X, A[2] = Y;
m[1] = a, m[2] = b;
int lcm = m[1], ans = A[1] % m[1];
int c = ((A[2] - ans) % m[2] + m[2]) % m[2], x, y;
int gcd = exgcd(lcm, m[2], x, y);
int k = ksm(x, c / gcd, m[2] / gcd);
ans += lcm * k;
lcm *= m[2] / gcd;
ans = (ans % lcm + lcm) % lcm;
return ans;
}
signed main() {
int n = read(), x, y;
for(int i = 1; i < n; i++)
x = read(), y = read(), add(x, y), add(y, x);
int q = read();
dfs(1, 0);
while(q--) {
int x = read(), d1 = read(), y = read(), d2 = read();
int Lca = lca(x, y), lcm = d1 * d2 / __gcd(d1, d2);
int minx = find(dep[x] % d1, dep[y] % d2, d1, d2);
if(minx == 0)
minx += lcm;
if(minx % d1 != dep[x] % d1 || minx % d2 != dep[y] % d2) {
puts("-1");
continue;
}
if(minx > dep[Lca]) {
puts("-1");
continue;
}
int depth = (dep[Lca] - minx) / lcm * lcm + minx;
cout << jump(x, depth) << "\n";
}
}