【BZOJ 2599】【IOI 2011】Race 点分治

这是一道裸的点分治题。然而我在循环初始化 $s$ 时把 $i \le k$ 误写成了 $i \le n$(而查询时会访问到 $s[k - dist]$,下标最大为 $k$,所以 $s[0..k]$ 都必须初始化),结果 WA 了好长时间。

#include<cstdio>
#include<cstring>
#include<algorithm>
#define N 200100
#define inf 2147483647
#define max(a,b) (a)>(b)?(a):(b)
#define min(a,b) (a)<(b)?(a):(b)
#define read(x) x=getint()
using namespace std;
// Reads the next decimal integer from stdin.
// Any non-digit characters before the first digit are skipped; if a '-' occurs
// anywhere in that skipped prefix the result is negated.
// Returns 0 (instead of looping forever) when end of input is reached before
// any digit is seen.
inline int getint() {
    int sign = 1, val = 0;
    int c = getchar();  // int, not char: EOF must stay distinguishable from a byte
    for (; c != EOF && (c < '0' || c > '9'); c = getchar())
        if (c == '-') sign = -1;
    for (; c >= '0' && c <= '9'; c = getchar())
        val = val * 10 + c - '0';
    return val * sign;
}
// One directed half of an undirected edge, kept in a singly linked adjacency list.
struct node {
    int nxt;  // index in E of the next edge leaving the same vertex; 0 ends the list
    int to;   // vertex this edge leads to
    int w;    // edge weight
};
// Edge pool: every undirected edge occupies two slots. Slot 0 is never used
// because ins() pre-increments the edge counter.
node E[N << 1];
bool vis[N];      // vertex has already been used as a centroid (removed from further work)
int cnt = 0;      // number of directed edges stored in E
int s[1000100];   // s[d]: fewest edges from the current centroid to a vertex at weighted distance d,
                  // over the child subtrees processed so far (n + 1 means "none")
int rtm = inf;    // smallest "largest remaining piece" found so far by fdrt
int root;         // vertex that achieved rtm in the last fdrt pass
int sz[N];        // subtree sizes computed by the last fdrt pass
int dist[N];      // weighted distance from the current centroid
int deep[N];      // number of edges from the current centroid
int n, k, ans;    // vertex count, required path length, best edge count found
int point[N];     // first edge index of each vertex's adjacency list (0 = no edges)
inline void ins(int x, int y, int z) {++cnt; E[cnt].nxt = point[x]; E[cnt].to = y; E[cnt].w = z; point[x] = cnt;}
// Finds a centroid of the unvisited component containing x.
//   x  : current vertex of the DFS
//   fa : vertex we came from (the caller passes a non-vertex such as -1 at the start)
//   sh : number of vertices in the whole component
// Fills sz[] for the DFS rooted at the start vertex. For every vertex it computes
// the size of the largest piece that would remain if that vertex were removed, and
// keeps the minimiser in `root` (its piece size in `rtm`). The caller must reset
// rtm to a large value before the top-level call.
inline void fdrt(int x, int fa, int sh) {
    sz[x] = 1;
    int heaviest = 0;  // largest piece seen so far when x is removed
    for (int tmp = point[x]; tmp; tmp = E[tmp].nxt) {
        int v = E[tmp].to;
        if (vis[v] || v == fa)
            continue;
        fdrt(v, x, sh);
        sz[x] += sz[v];
        if (sz[v] > heaviest)
            heaviest = sz[v];
    }
    // The piece "above" x (everything outside x's DFS subtree) has sh - sz[x]
    // vertices. Using sh - heaviest here does not measure any real piece and can
    // select a vertex far from the centroid (e.g. a short path leading into a
    // star), which breaks the O(log n) depth of the decomposition.
    if (sh - sz[x] > heaviest)
        heaviest = sh - sz[x];
    if (heaviest < rtm) {
        rtm = heaviest;
        root = x;
    }
}
// Answer pass over one child subtree of the current centroid.
// Expects dist[x]/deep[x] to be set by the caller (fa = vertex we came from).
// Each vertex at distance <= k is paired with the best partner recorded so far,
// s[k - dist[x]]: the fewest edges to a vertex at the complementary distance in
// an earlier subtree, or s[0] = 0 for the centroid itself. ans is updated with
// the result. The walk also assigns dist/deep to every descendant; sfill() and
// emp() read those values afterwards.
// NOTE(review): assumes non-negative edge weights; the original indexing of s
// already relies on this.
inline void work(int x, int fa) {
    if (dist[x] <= k)
        ans = min(ans, deep[x] + s[k - dist[x]]);
    for (int tmp = point[x]; tmp; tmp = E[tmp].nxt) {
        int v = E[tmp].to;
        if (vis[v] || v == fa)
            continue;
        // Saturate at k + 1. Any distance above k is only ever compared against k,
        // so the cap changes no decision. Without it, summing weights along a long
        // path can overflow int, and a wrapped value could fake a match or index s
        // out of bounds.
        int nd = dist[x] + E[tmp].w;
        dist[v] = nd > k ? k + 1 : nd;
        deep[v] = deep[x] + 1;
        work(v, x);
    }
}
// Records the subtree below x (fa = vertex we came from) into s.
// For every vertex whose distance d from the centroid is below k, s[d] keeps
// the smallest edge count seen so far. Uses the dist/deep values set by work().
inline void sfill(int x, int fa) {
    const int d = dist[x];
    if (d < k && deep[x] < s[d])
        s[d] = deep[x];
    for (int e = point[x]; e; e = E[e].nxt) {
        const int y = E[e].to;
        if (!vis[y] && y != fa)
            sfill(y, x);
    }
}
// Undoes sfill() for the subtree below x (fa = vertex we came from).
// Every slot s[d] with d < k that this subtree could have written is reset to
// n + 1 ("no vertex"). Only touched slots are cleared, not the whole array.
inline void emp(int x, int fa) {
    const int d = dist[x];
    if (d < k)
        s[d] = n + 1;
    for (int e = point[x]; e; e = E[e].nxt) {
        const int y = E[e].to;
        if (!vis[y] && y != fa)
            emp(y, x);
    }
}
inline void dfs(int x, int sh) {
    vis[x] = 1;
    s[0] = 0; //不能落下这个点!!因为后面会更新不到,而且有可能会更改s[0]的值 
    for(int tmp = point[x]; tmp; tmp = E[tmp].nxt) {
        int v = E[tmp].to;
        if (vis[v])
            continue;
        dist[v] = E[tmp].w;
        deep[v] = 1;
        work(v, x);
        sfill(v, x);
    }
    for(int tmp = point[x]; tmp; tmp = E[tmp].nxt) {
    	int v = E[tmp].to;
    	if (vis[v])
    		continue;
		emp(v, x);
    }
    for(int tmp = point[x]; tmp; tmp = E[tmp].nxt) {
        int v = E[tmp].to;
        if (vis[v])
            continue;
        int ss = sz[v] < sz[x] ? sz[v]: sh - sz[x];
        rtm = inf;
        fdrt(v, x, ss);
        dfs(root, ss);
    }
}
int main() {
    read(n); read(k);
    int a,b,c;
    for(int i = 1; i < n; ++i) {
        read(a); read(b); read(c); ++a; ++b;
        ins(a, b, c);
        ins(b, a, c);
    }
    ans = n;
    memset(vis, 0, sizeof(vis));
    fdrt(1, -1, n);
    for(int i = 0; i <= k; ++i)
		s[i] = n + 1;
    dfs(1, n);
    printf("%d\n", ans == n ? -1 : ans);
    return 0;
}

然后就可以了

posted @ 2016-03-30 14:31  abclzr  阅读(275)  评论(0编辑  收藏  举报