CF1039D 题解
题目大意
给出一棵\(n\)个节点的树,对于\(1\)~\(n\)间的每一个数\(k\),你需要求出:最多能选出多少条互不相交的路径,每条路径的长度都为\(k\)。
\(Solution:\)
\(1\)~\(n\)不好算,先考虑单个的\(k\)。
令\(ans_i\)表示最多能选出多少条互不相交的路径,每条路径的长度都为\(i\),\(f_u\)表示回溯到\(u\)时最多能选出多少条互不相交的路径,每条路径的长度都为\(i\)(\(i\)就是前面的\(i\));
考虑贪心地\(dp\):对于一个点\(u\),我们考虑最大化以\(u\)为根的子树中完整的长度为\(k\)的路径条数,其次最大化未完成的链的长度;
正确性显然:由于树的特性,\(u\)节点只会在一个路径中被计数,那么我们用类似点分治的看法,设\(u\)是该路径上\(dep\)最浅的点,\(max1\)是\(u\)节点向下挂出的最长链的长度,\(max2\)是\(u\)向下挂出的次长链的长度,那么:
\((1)\)若\(max1+max2+1 \geq k\),那么我们显然要将\(max1,max2,u\)组成的路径计入答案,否则会用到\(u\)的祖先节点,这是不优的;
\((2)Otherwise\),我们就用\(max1,u\)组成的链继承上去
代码大概是这个样子的(看了\(lzy\)大佬的\(blog\),发现其实可以先\(dfs\)一遍跑出\(dfs\)序,然后直接在\(dfs\)序上操作,新技能\(get\))
void dfs(int u,int fa)
{
fat[u]=fa;
go(u)
{
int v=e[i].to;
if(v!=fa) dfs(v,u);
}
dfn[++idx]=u;
}
inline int solve(int k)
{
int ans=0;
fr(i,1,n) f[i]=1;
fr(i,1,n)
{
int u=dfn[i];
if(fat[u]&&f[fat[u]]&&f[u])
{
if(f[u]+f[fat[u]]>=k)
{
++ans,f[fat[u]]=0;
}
else f[fat[u]]=max(f[fat[u]],f[u]+1);
}
}
return ans;
}
然鹅这是\(O(n^2)\)的,考虑优化。可以发现有:\(ans_i \leq \frac{n}{i}\)(原因是路径长度为\(i\),那么最好情况即将所有点都用上,就会有\(\frac{n}{i}\)条路径)。
结合数据范围\(n \leq 10^5\),考虑根号分治:对于一个确定的\(k\),我们设一个阀值\(B\):
\((1)\)若\(k \leq B\),直接暴力\(dp\)(也就是上面的代码),复杂度\(O(nB)\);
\((2)\)若\(k > B\),此时必然有:\(ans_k \in [0,\frac{n}{B}]\),也就是只有\(\frac{n}{B}\)这么多个取值,然后\(f\)(即\(dp\)数组)显然具有单调不增的性质,也就是说中间有一段一段的dp值是一样的,那么我们考虑二分出这些段的边界,每次二分用solve()
来\(check\),复杂度\(O(\frac{n}{B} n \log_2 n)\)。
那么我们现在来分析阀值\(B\)的取值,由上面的分析可知,总复杂度\(O(nB+\frac{n}{B} n \log_2 n)=O(n(B+\frac{n}{B} \log n))\),由均值不等式:\(min=n \sqrt{n \log n}\),当且仅\(B=\frac{n}{B} \log n\)即\(B=\sqrt{n \log n}\)时取得。
上代码:
\(Code:\)
#include<bits/stdc++.h>
using namespace std;
namespace my_std
{
typedef long long ll;
typedef double db;
#define pf printf
#define pc putchar
#define fr(i,x,y) for(register int i=(x);i<=(y);++i)
#define pfr(i,x,y) for(register int i=(x);i>=(y);--i)
#define go(x) for(int i=head[u];i;i=e[i].nxt)
#define enter pc('\n')
#define space pc(' ')
#define fir first
#define sec second
#define MP make_pair
const int inf=0x3f3f3f3f;
const ll inff=1e15;
inline int read()
{
int sum=0,f=1;
char ch=0;
while(!isdigit(ch))
{
if(ch=='-') f=-1;
ch=getchar();
}
while(isdigit(ch))
{
sum=sum*10+(ch^48);
ch=getchar();
}
return sum*f;
}
inline void write(int x)
{
if(x<0)
{
x=-x;
pc('-');
}
if(x>9) write(x/10);
pc(x%10+'0');
}
inline void writeln(int x)
{
write(x);
enter;
}
inline void writesp(int x)
{
write(x);
space;
}
}
using namespace my_std;
const int N=1e5+50;
int n,B,idx,f[N],fat[N],dfn[N],head[N],cnt,ans[N];
struct edge
{
int to,nxt;
}e[N<<1];
inline void add(int u,int v)
{
e[++cnt].to=v;
e[cnt].nxt=head[u];
head[u]=cnt;
}
void dfs(int u,int fa)
{
fat[u]=fa;
go(u)
{
int v=e[i].to;
if(v!=fa) dfs(v,u);
}
dfn[++idx]=u;
}
inline int solve(int k)
{
int ans=0;
fr(i,1,n) f[i]=1;
fr(i,1,n)
{
int u=dfn[i];
if(fat[u]&&f[fat[u]]&&f[u])
{
if(f[u]+f[fat[u]]>=k)
{
++ans,f[fat[u]]=0;
}
else f[fat[u]]=max(f[fat[u]],f[u]+1);
}
}
return ans;
}
int main(void)
{
n=read();
B=sqrt(n*log(n)/log(2));
fr(i,1,n-1)
{
int u=read(),v=read();
add(u,v),add(v,u);
}
dfs(1,0);
//fr(i,1,n) writesp(dfn[i]);
ans[1]=n;
fr(i,2,B) ans[i]=solve(i);
for(int i=B+1,l,r;i<=n;i=l+1)
{
l=i,r=n;
int tmp=solve(i);
while(r-l>1)
{
int mid=(l+r)>>1;
if(solve(mid)==tmp) l=mid;
else r=mid;
}
fr(j,i,l) ans[j]=tmp;
}
fr(i,1,n) writeln(ans[i]);
return 0;
}