HDU 4812 D Tree

HDU 4812 

思路:

点分治

先预处理好 1e6 + 3 以内所有数的逆元(模 1e6 + 3)

然后用 map(代码中是数组 mp)把以分治点为起点的链的乘积 a 映射到该链另一端点的编号 u(同一个乘积只保留最小的编号)

然后暴力跑出以分治点的儿子为起点的链的乘积 b,再在 map 里查找 inv[b] * k % MOD

代码:

#include<bits/stdc++.h>
using namespace std;
// Shorthand macros of the contest template. Most are unused by the solution
// below; they are kept so that code elsewhere relying on them still builds.
#define fi first
#define se second
#define pi acos(-1.0)
#define LL long long
//#define mp make_pair
#define pb push_back
// Segment-tree child argument lists (left / right half of [l, r]).
#define ls rt<<1, l, m
#define rs rt<<1|1, m+1, r
#define ULL unsigned LL
#define pll pair<LL, LL>
#define pii pair<int, int>
// Fill the whole array `a` with byte value `b`.
#define mem(a, b) memset(a, b, sizeof(a))
// Untie and unsync the C++ streams for faster iostream I/O.
#define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
// Redirect stdin/stdout to local files for offline testing.
// The second stream was misspelled `stout`, which made any use fail to compile.
#define fopen freopen("in.txt", "r", stdin);freopen("out.txt", "w", stdout);
//head

// Modulus applied to every path product (1000003).
const int MOD = 1e6 + 3;
// Sentinel stored in ans1/ans2 while no valid pair has been found.
const int INF = 0x7f7f7f7f;
// Capacity of the per-node arrays; nodes are numbered from 1.
const int N = 1e5 + 5;
// inv[x]  : modular inverse of x, filled once by init().
// mp[p]   : node index whose path product from the current centroid is p (0 = none).
// head[u] : index of the first edge in u's adjacency list, -1 when empty.
// mxsz[u] : size of the largest component left after removing u (see get_rt).
// sz[u]   : subtree size of u in the most recent get_rt traversal.
// v[u]    : value attached to node u.
// cnt     : number of edge slots used so far.
// rt      : best root candidate found by get_rt (0 is a dummy node).
// n, k    : node count and target product read in main; n is later reused by
//           solve() as the size of the component being processed.
// ans1, ans2 : best (smaller, larger) node pair found so far.
int inv[MOD + 5], mp[MOD + 5], head[N], mxsz[N], sz[N], v[N], cnt = 0, rt = 0, n, k, ans1, ans2;
// deep[1..top] / id[1..top]: path products and the nodes they end at, as
// collected by get_d; dis[u] is the product accumulated on the way to u.
int deep[N], dis[N], id[N], top = 0;
// vis[u] is set once u has been processed as a division point by solve().
bool vis[N];
// One directed arc of the tree, stored in a linked adjacency list.
struct edge {
    int to;   // node this arc points at
    int nxt;  // index of the next arc leaving the same node, -1 at the end
} edge[2 * N];  // every undirected edge is stored as two arcs
// Prepends the arc u -> v to u's adjacency list, using the next free edge slot.
void add_edge(int u, int v) {
    const int slot = cnt++;
    edge[slot].nxt = head[u];
    edge[slot].to = v;
    head[u] = slot;
}
// Fills inv[1 .. MOD-1] with modular inverses via the linear recurrence
//   inv[x] = (MOD - MOD / x) * inv[MOD % x]  (mod MOD),
// which only needs entries with a smaller index than x.
void init() {
    inv[1] = 1;
    for (int x = 2; x < MOD; ++x) {
        const long long quotient = MOD / x;
        const int remainder = MOD % x;
        inv[x] = static_cast<int>((MOD - quotient) * inv[remainder] % MOD);
    }
}
// Tries to improve the answer with a path ending at node y whose product is x.
// The partner must have product k / x (mod MOD); mp holds the node recorded for
// that product, or 0 when there is none. The pair is ordered (smaller, larger)
// and replaces (ans1, ans2) when it is lexicographically smaller.
void update(int x, int y) {
    const int wanted = static_cast<int>(1LL * inv[x] * k % MOD);
    const int partner = mp[wanted];
    if (partner == 0) return;

    const int lo = partner > y ? y : partner;
    const int hi = partner > y ? partner : y;
    if (lo < ans1 || (lo == ans1 && hi < ans2)) {
        ans1 = lo;
        ans2 = hi;
    }
}
// DFS from u (entered from o) over nodes not yet marked in vis.
// Computes sz[u] and mxsz[u], where mxsz[u] is the largest piece left when u is
// removed from a component of n nodes, and keeps in rt the node with the
// smallest mxsz seen so far. The caller seeds rt = 0 and mxsz[0] beforehand.
void get_rt(int o, int u) {
    sz[u] = 1;
    mxsz[u] = 0;
    for (int e = head[u]; e != -1; e = edge[e].nxt) {
        const int child = edge[e].to;
        if (child == o || vis[child]) continue;
        get_rt(u, child);
        sz[u] += sz[child];
        if (sz[child] > mxsz[u]) mxsz[u] = sz[child];
    }
    // The part of the component that lies on the parent side of u.
    const int upward = n - sz[u];
    if (upward > mxsz[u]) mxsz[u] = upward;
    if (mxsz[u] < mxsz[rt]) rt = u;
}
// DFS from u (entered from o) over nodes not yet marked in vis.
// Appends (dis[u], u) to the deep[] / id[] lists and extends the running
// product to each child as dis[child] = dis[u] * v[child] mod MOD.
// The caller resets top and seeds dis[u] before the first call.
void get_d(int o, int u) {
    ++top;
    deep[top] = dis[u];
    id[top] = u;
    for (int e = head[u]; e != -1; e = edge[e].nxt) {
        const int child = edge[e].to;
        if (vis[child] || child == o) continue;
        dis[child] = static_cast<int>(1LL * dis[u] * v[child] % MOD);
        get_d(u, child);
    }
}
void solve(int u) {
    vis[u] = true;
    mp[v[u]] = u;
    for (int i = head[u]; ~i; i = edge[i].nxt) {
        if(!vis[edge[i].to]) {
            top = 0, dis[edge[i].to] = v[edge[i].to];
            get_d(u, edge[i].to);
            for (int j = 1; j <= top; j++) update(deep[j], id[j]);
            top = 0, dis[edge[i].to] = (1LL * v[u] * v[edge[i].to])%MOD;
            get_d(u, edge[i].to);
            for (int j = 1; j <= top; j++) {
                int t = deep[j];
                if(!mp[t] || id[j] < mp[t]) mp[t] = id[j];
            }
        }
    }
    mp[v[u]] = 0;
    for (int i = head[u]; ~i; i = edge[i].nxt) {
        if(!vis[edge[i].to]) {
            top = 0, dis[edge[i].to] = (1LL * v[u] * v[edge[i].to])%MOD;
            get_d(u, edge[i].to);
            for (int j = 1; j <= top; j++) mp[deep[j]] = 0;
        }
    }
    for (int i = head[u]; ~i; i = edge[i].nxt) {
        if(!vis[edge[i].to]) {
            mxsz[0] = n =  sz[edge[i].to];
            get_rt(rt = 0, edge[i].to);
            solve(rt);
        }
    }
}
// Reads test cases until input ends. Each case gives n and k, then n node
// values, then n-1 tree edges. Prints the lexicographically smallest node pair
// whose path product is k modulo MOD, or "No solution".
int main() {
    init();
    int u, V;
    // Require both numbers: a partial read returns 0 or 1, which the former
    // `~scanf(...)` test accepted and then reused stale n / k.
    while (scanf("%d%d", &n, &k) == 2) {
        mem(head, -1);
        mem(vis, false);
        mem(mp, 0);
        cnt = 0;
        ans1 = ans2 = INF;
        for (int i = 1; i <= n; i++) scanf("%d", &v[i]);
        for (int i = 1; i < n; i++) scanf("%d%d", &u, &V), add_edge(u, V), add_edge(V, u);
        // Node 0 is a dummy root candidate that any real node beats.
        mxsz[0] = n;
        get_rt(rt = 0, 1);
        solve(rt);
        if (ans1 == INF) printf("No solution\n");
        else printf("%d %d\n", ans1, ans2);
    }
    return 0;
}

 

posted @ 2018-05-29 18:16  Wisdom+.+  阅读(161)  评论(0)  编辑  收藏  举报