[USACO18JAN]Cow at Large P
XIII.[USACO18JAN]Cow at Large P
这题我做的时候时限1s,然后卡不过去……之后不得不发帖请求把时限调大到题面中的4s
假设当前询问了点\(rt\),那么我们把这棵树变成以\(rt\)为根,设\(dep_i\)为此刻\(i\)节点的深度。
我们再令\(f_i\)表示\(i\)节点距离最近的叶子的距离(这个可以通过二次扫描与换根法通过DP求出,不是重点,就跳过了)。则如果一个节点有\(dep_i\geq f_i\),则这个点便可以被那个叶子看住。
显然,一个叶子如果放了农民,被看住的一定是一棵子树,而子树的树根即为最高的那个具有\(dep_i\geq f_i\)的点。这也意味着所有最高的符合条件的点,都代表了一棵放了农民的子树,也即农民的数量。
我们再次审视这个条件,发现它即为所有满足\(dep_i\geq f_i\ \land\ dep_{fa_i}<f_{fa_i}\)的点的数量。
因为前一半一定是对一棵子树全部成立,因此我们考虑推出某种式子,使得一整棵子树的该式之和为一。我们考虑从度数下手。
若设\(deg_x\)为\(x\)节点的度数的话,则对于一棵常规的树,\(\sum deg_x=2(n-1)\)应是成立的,因为每条边会被统计两次。
但是我们现在子树的根节点的度数还包含连向它父亲的那条边,因此有\(\sum deg_x=2n-1\)。当然,这个特例在树根处无效——但是如果树根都符合条件,则树根只有可能是叶子。因此叶子节点特判答案为\(1\)即可。
我们可以将\(\sum deg_x=2n-1\)变化成\(\sum2-deg_x=1\),这就是我们之前想要的那个和为一的式子。
我们还要满足\(dep_i\geq f_i\)这一条件。因此现在答案即为\(\sum\limits_{dep_i\geq f_i}2-deg_i\)。
这个式子就可以用点分治求解了。我们将\(dep_i\)拆成两部分,即\(\operatorname{dis}(lca,i)\)与\(\operatorname{dis}(lca,rt)\)。则我们现在即可将式子表示成\(\operatorname{dis}(lca,i)+\operatorname{dis}(lca,rt)\geq f_i\)。将可以预处理的\(\operatorname{dis}(lca,i)\)移过去,便得到\(\operatorname{dis}(lca,rt)\geq f_i-\operatorname{dis}(lca,i)\)。将式子右边与\(2-deg_x\)make_pair
压入vector<pair<int,int> >
并排序即可通过前缀和+二分求出所有满足条件的点的\(2-deg_x\)的和。
复杂度\(O(n\log^2n)\)。
建出点分治树然后处理的代码:
#pragma GCC optimize(3)
#include<bits/stdc++.h>
using namespace std;
int n,fa[200100],dep[200100],in[200100],tot,mn[400100][20],LG[400100],f[200100];
namespace Tree{
vector<int>v[200100];
int sz[200100],SZ,msz[200100],ROOT;
bool vis[200100];
void getsz(int x,int fa){
sz[x]=1;
for(auto y:v[x])if(!vis[y]&&y!=fa)getsz(y,x),sz[x]+=sz[y];
}
void getroot(int x,int fa){
sz[x]=1,msz[x]=0;
for(auto y:v[x])if(!vis[y]&&y!=fa)getroot(y,x),sz[x]+=sz[y],msz[x]=max(msz[x],sz[y]);
msz[x]=max(msz[x],SZ-sz[x]);
if(msz[x]<msz[ROOT])ROOT=x;
}
void solve(int x){
getsz(x,0);
vis[x]=true;
for(auto y:v[x]){
if(vis[y])continue;
ROOT=0,SZ=sz[y],getroot(y,0),fa[ROOT]=x,solve(ROOT);
}
}
void getural(int x,int fa){
mn[++tot][0]=x,in[x]=tot;
for(auto y:v[x])if(y!=fa)dep[y]=dep[x]+1,getural(y,x),mn[++tot][0]=x;
}
void getdfsI(int x,int fa){
f[x]=(v[x].size()==1?0:0x3f3f3f3f);
for(auto y:v[x])if(y!=fa)getdfsI(y,x),f[x]=min(f[x],f[y]+1);
}
void getdfsII(int x,int fa){
for(auto y:v[x])if(y!=fa)f[y]=min(f[y],f[x]+1),getdfsII(y,x);
}
}
int MIN(int i,int j){
return dep[i]<dep[j]?i:j;
}
int LCA(int i,int j){
if(i>j)swap(i,j);
int k=LG[j-i+1];
return MIN(mn[i][k],mn[j-(1<<k)+1][k]);
}
int DIS(int i,int j){
return dep[i]+dep[j]-dep[LCA(in[i],in[j])]*2;
}
namespace cdt{
vector<pair<int,int> >sf[200100],pa[200100];
void init(int x){
int val=2-Tree::v[x].size();
for(int u=x;u;u=fa[u]){
sf[u].push_back(make_pair(f[x]-DIS(x,u),val));
if(fa[u])pa[u].push_back(make_pair(f[x]-DIS(x,fa[u]),val));
}
}
void prep(){
for(int i=1;i<=n;i++)init(i);
for(int i=1;i<=n;i++){
sf[i].push_back(make_pair(0x80808080,0));
pa[i].push_back(make_pair(0x80808080,0));
sort(sf[i].begin(),sf[i].end()),sort(pa[i].begin(),pa[i].end());
// printf("%d:\n",i);
// for(int j=1;j<sf[i].size();j++)printf("(%d,%d)",sf[i][j].first,sf[i][j].second);puts("");
// for(int j=1;j<pa[i].size();j++)printf("(%d,%d)",pa[i][j].first,pa[i][j].second);puts("");
for(int j=1;j<sf[i].size();j++)sf[i][j].second+=sf[i][j-1].second;
for(int j=1;j<pa[i].size();j++)pa[i][j].second+=pa[i][j-1].second;
}
}
int ask(int x){
if(f[x]==0)return 1;
int res=0;
for(int u=x;u;u=fa[u]){
int p=upper_bound(sf[u].begin(),sf[u].end(),make_pair(DIS(x,u),(int)0x3f3f3f3f))-sf[u].begin()-1;
res+=sf[u][p].second;
if(!fa[u])break;
p=upper_bound(pa[u].begin(),pa[u].end(),make_pair(DIS(x,fa[u]),(int)0x3f3f3f3f))-pa[u].begin()-1;
res-=pa[u][p].second;
}
return res;
}
}
void read(int &x){
x=0;
char c=getchar();
while(c>'9'||c<'0')c=getchar();
while(c>='0'&&c<='9')x=(x<<3)+(x<<1)+(c^48),c=getchar();
}
int main(){
read(n);
for(int i=1,x,y;i<n;i++)read(x),read(y),Tree::v[x].push_back(y),Tree::v[y].push_back(x);
Tree::msz[0]=0x3f3f3f3f,Tree::SZ=n,Tree::getroot(1,0),Tree::solve(Tree::ROOT);
Tree::getdfsI(1,0),Tree::getdfsII(1,0);
Tree::getural(1,0);
for(int i=2;i<=tot;i++)LG[i]=LG[i>>1]+1;
for(int j=1;j<=LG[tot];j++)for(int i=1;i+(1<<j)-1<=tot;i++)mn[i][j]=MIN(mn[i][j-1],mn[i+(1<<(j-1))][j-1]);
cdt::prep();
for(int i=1;i<=n;i++)printf("%d\n",cdt::ask(i));
return 0;
}
没建点分治树直接点分治的代码:
#pragma GCC optimize(3)
#include<bits/stdc++.h>
using namespace std;
int n,dep[200100],f[200100],res[200100];
vector<int>v[200100];
int sz[200100],SZ,msz[200100],ROOT;
bool vis[200100];
void getdfsI(int x,int fa){
f[x]=(v[x].size()==1?0:0x3f3f3f3f);
for(auto y:v[x])if(y!=fa)getdfsI(y,x),f[x]=min(f[x],f[y]+1);
}
void getdfsII(int x,int fa){for(auto y:v[x])if(y!=fa)f[y]=min(f[y],f[x]+1),getdfsII(y,x);}
void getszdep(int x,int fa){
sz[x]=1;
for(auto y:v[x])if(!vis[y]&&y!=fa)dep[y]=dep[x]+1,getszdep(y,x),sz[x]+=sz[y];
}
void getroot(int x,int fa){
sz[x]=1,msz[x]=0;
for(auto y:v[x])if(!vis[y]&&y!=fa)getroot(y,x),sz[x]+=sz[y],msz[x]=max(msz[x],sz[y]);
msz[x]=max(msz[x],SZ-sz[x]);
if(msz[x]<msz[ROOT])ROOT=x;
}
vector<pair<int,int> >vp;
void getwrite(int x,int fa){
vp.push_back(make_pair(f[x]-dep[x],2-v[x].size()));
for(auto y:v[x])if(!vis[y]&&y!=fa)getwrite(y,x);
}
void getread(int x,int fa,int k){
int p=upper_bound(vp.begin(),vp.end(),make_pair(dep[x],(int)0x3f3f3f3f))-vp.begin()-1;
res[x]+=k*vp[p].second;
for(auto y:v[x])if(!vis[y]&&y!=fa)getread(y,x,k);
}
void calc(int x){
vp.clear(),vp.push_back(make_pair(0x80808080,0)),getwrite(x,0),sort(vp.begin(),vp.end());
for(int i=1;i<vp.size();i++)vp[i].second+=vp[i-1].second;
getread(x,0,1);
for(auto y:v[x]){
if(vis[y])continue;
vp.clear(),vp.push_back(make_pair(0x80808080,0)),getwrite(y,x),sort(vp.begin(),vp.end());
for(int i=1;i<vp.size();i++)vp[i].second+=vp[i-1].second;
getread(y,x,-1);
}
}
void solve(int x){
dep[x]=0,getszdep(x,0);
calc(x);
vis[x]=true;
for(auto y:v[x]){
if(vis[y])continue;
ROOT=0,SZ=sz[y],getroot(y,0),solve(ROOT);
}
}
void read(int &x){
x=0;
char c=getchar();
while(c>'9'||c<'0')c=getchar();
while(c>='0'&&c<='9')x=(x<<3)+(x<<1)+(c^48),c=getchar();
}
int main(){
read(n);
for(int i=1,x,y;i<n;i++)read(x),read(y),v[x].push_back(y),v[y].push_back(x);
getdfsI(1,0),getdfsII(1,0);
msz[0]=0x3f3f3f3f,SZ=n,getroot(1,0),solve(ROOT);
for(int i=1;i<=n;i++)printf("%d\n",v[i].size()==1?1:res[i]);
return 0;
}