BZOJ2599:[IOI2011]Race

浅谈树分治:https://www.cnblogs.com/AKMer/p/10014803.html

题目传送门:https://www.lydsy.com/JudgeOnline/problem.php?id=2599

我们设\(f_i\)为长度为\(i\)的路径边数最小可以是多少,依次遍历当前根的子树,先用\(cnt+f[k-dis]\)更新答案,再遍历第二遍当前子树更新\(f\)数组。\(cnt\)表示根到当前点一共经过了多少条边。因为\(k\)的范围是\(10^6\)级别的,每次处理当前联通块前把\(f\)数组全部赋成极大值会很慢,所以我们每次更新\(f\)数组的时候把被改动过的\(dis\)用栈记录下来,每次处理完当前联通块就弹栈并且把相应的\(f\)数组初始化,这样做就是\(O(n)\)级别的了。

如果用边分做的话,记得把新建的边权值赋成\(-1\),因为可能会有边权为\(0\)的边,然后统计\(cnt\)的时候只有在碰到权值不为\(-1\)的边才算。

时间复杂度:\(O(nlogn)\)

空间复杂度:\(O(n)\)

点分治版代码如下:

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

const int maxn=2e5+5,maxm=1e6+5;;

bool vis[maxn],insta[maxm];
int n,m,tot,mx,rt,N,ans,top;
int siz[maxn],f[maxm],sta[maxn];
int now[maxn],pre[maxn*2],son[maxn*2],val[maxn*2];

int read() {
    int x=0,f=1;char ch=getchar();
    for(;ch<'0'||ch>'9';ch=getchar())if(ch=='-')f=-1;
    for(;ch>='0'&&ch<='9';ch=getchar())x=x*10+ch-'0';
    return x*f;
}

void add(int a,int b,int c) {
    pre[++tot]=now[a];
    now[a]=tot,son[tot]=b,val[tot]=c;
}

void find_rt(int fa,int u) {
    int res=0;siz[u]=1;
    for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
        if(!vis[v]&&v!=fa)find_rt(u,v),siz[u]+=siz[v],res=max(res,siz[v]);
    res=max(res,N-siz[u]);
    if(res<mx)mx=res,rt=u;
}

void query(int fa,int u,int cnt,int dis) {
    if(dis<=m)ans=min(ans,cnt+f[m-dis]);siz[u]=1;
    for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
        if(!vis[v]&&v!=fa)query(u,v,cnt+1,dis+val[p]),siz[u]+=siz[v];
}

void solve(int fa,int u,int cnt,int dis) {
    if(dis>m)return;
    f[dis]=min(f[dis],cnt);
    if(!insta[dis])sta[++top]=dis,insta[dis]=1;
    for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
        if(!vis[v]&&v!=fa)solve(u,v,cnt+1,dis+val[p]);
}

void work(int u,int size) {
    N=size,mx=rt=n+1,find_rt(0,u),u=rt,vis[u]=1;
    for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
        if(!vis[v])query(u,v,1,val[p]),solve(u,v,1,val[p]);
    ans=min(ans,f[m]);
    while(top) {
        f[sta[top]]=f[m+1];
        insta[sta[top--]]=0;
    }
    for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
        if(!vis[v])work(v,siz[v]);
}

int main() {
    n=read(),m=read();
    for(int i=1;i<n;i++) {
        int a=read()+1,b=read()+1,c=read();
        add(a,b,c),add(b,a,c);
    }
    memset(f,127/3,sizeof(f));
    ans=f[0];work(1,n);
    if(ans==f[m+1])puts("-1");
    else printf("%d\n",ans);
    return 0;
}

边分治版代码如下:

#include <cmath>
#include <cstdio>
#include <vector>
#include <cstring>
#include <algorithm>
using namespace std;
typedef pair<int,int> pii;
#define fr first
#define sc second

const int maxn=4e5+5,maxm=1e6+5;

bool vis[maxn],insta[maxm];
int f[maxm],siz[maxn],sta[maxn];
int n,m,tot,cnt,mx,id,N,top,ans,tmp1,tmp2;
int now[maxn],pre[maxn*2],son[maxn*2],val[maxn*2];

vector<pii>to[maxn];
vector<pii>::iterator it;

int read() {
    int x=0,f=1;char ch=getchar();
    for(;ch<'0'||ch>'9';ch=getchar())if(ch=='-')f=-1;
    for(;ch>='0'&&ch<='9';ch=getchar())x=x*10+ch-'0';
    return x*f;
}

void add(int a,int b,int c) {
    pre[++tot]=now[a];
    now[a]=tot,son[tot]=b,val[tot]=c;
}

void find_son(int fa,int u) {
    for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
        if(v!=fa)to[u].push_back(make_pair(v,val[p])),find_son(u,v);
}

void rebuild() {
    tot=1;memset(now,0,sizeof(now));
    for(int i=1;i<=cnt;i++) {
        int size=to[i].size();
        if(size<=2) {
            for(it=to[i].begin();it!=to[i].end();it++) {
                pii tmp=*it;
                add(i,tmp.fr,tmp.sc),add(tmp.fr,i,tmp.sc);
            }
        }
        else {
            pii u1=make_pair(++cnt,-1),u2;
            if(size==3)u2=to[i].front();
            else u2=make_pair(++cnt,-1);
            add(i,u1.fr,u1.sc),add(u1.fr,i,u1.sc);
            add(i,u2.fr,u2.sc),add(u2.fr,i,u2.sc);
            if(size==3) {
                for(int j=1;j<=2;j++)
                    to[cnt].push_back(to[i].back()),to[i].pop_back();
            }
            else {
                int p=0;
                for(it=to[i].begin();it!=to[i].end();it++) {
                    if(!p)to[cnt-1].push_back(*it);
                    else to[cnt].push_back(*it);p^=1;
                }
            }
        }
    }
}

void find_edge(int fa,int u) {
    siz[u]=1;
    for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
        if(!vis[p>>1]&&v!=fa) {
            find_edge(u,v),siz[u]+=siz[v];
            if(abs(N-2*siz[v])<mx)mx=abs(N-2*siz[v]),id=p>>1;
        }
}


void solve(int fa,int u,int num,int dis) {
    if(dis<=m) {
        f[dis]=min(f[dis],num);
        if(!insta[dis])sta[++top]=dis,insta[dis]=1;
    }siz[u]=1;
    for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
        if(!vis[p>>1]&&v!=fa)
			solve(u,v,num+(val[p]!=-1),dis+(val[p]!=-1)*val[p]),siz[u]+=siz[v];
}

void query(int fa,int u,int num,int dis) {
    if(dis+tmp1<=m)ans=min(ans,num+f[m-dis-tmp1]+tmp2);siz[u]=1;
    for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
        if(!vis[p>>1]&&v!=fa)
			query(u,v,num+(val[p]!=-1),dis+(val[p]!=-1)*val[p]),siz[u]+=siz[v];
}

void work(int u,int size) {
    N=size,mx=id=cnt+1,find_edge(0,u);
    if(id==cnt+1)return;vis[id]=1;
    int u1=son[id<<1],u2=son[id<<1|1];
    tmp1=(val[id<<1]!=-1)*val[id<<1],tmp2=(val[id<<1]!=-1);
    solve(0,u1,0,0),query(0,u2,0,0);
    while(top) {
        f[sta[top]]=f[m+1];
        insta[sta[top--]]=0;
    }
    work(u1,siz[u1]),work(u2,siz[u2]);
}

int main() {
    cnt=n=read(),m=read();
    for(int i=1;i<n;i++) {
        int a=read()+1,b=read()+1,c=read();
        add(a,b,c),add(b,a,c);
    }
    find_son(0,1),rebuild();
    memset(f,127/3,sizeof(f));
    ans=f[0];work(1,cnt);
    if(ans==f[m+1])puts("-1");
    else printf("%d\n",ans);
    return 0;
}
posted @ 2018-12-02 10:05  AKMer  阅读(147)  评论(0编辑  收藏  举报