题解 Airplane Cliques (300iq的题)
题目大意
给定一棵\(n\)个节点的树。定义树上两点距离为它们之间边的数量。
我们说一对节点是友好的,当且仅当两点之间距离小于等于\(x\)。其中\(x\)是一个题目给定的常数。
我们说一个\(k\)个节点的集合是友好集合,当且仅当集合中任意两个节点都是友好的。
请对所有\(k=1\dots n\),求出恰有\(k\)个节点的友好集合数量。
数据范围:\(1\leq n\leq 3\times10^5\), \(0\leq x<n\)。
本题题解
考虑给你一个节点集合\(S\),如何判断它是否是友好集合?暴力做法是枚举集合里的每一对点,判断它们之间距离是否小于等于\(x\),如果用欧拉序LCA\(O(1)\)求距离,则这种判断方法时间复杂度\(O(|S|^2)\)。但这看上去太暴力了。有没有更好的方法呢?
考虑集合\(S\)里深度最大的节点\(u\)(这里的“深度”,可以是以选任意一个节点为根时的深度)。我们断言:集合\(S\)是好的,当且仅当\(S\)里所有节点到\(u\)的距离均小于等于\(x\)。必要性显然。充分性可以用反证法证明:如果存在两个节点间距离大于\(x\),那么两点中必有一个节点到\(u\)的距离也大于\(x\),在纸上画一画就看出来了。有了这个结论,判断集合\(S\)是否为友好集合的时间复杂度下降为\(O(|S|)\)。事实上,该结论还有更大的用处。
根据这个结论,我们可以想到一个粗略的,解决本题的方法。把所有节点,按深度从小到大排序,深度相同的节点顺序可以任意(这等价于,我们求出一个“bfs序”序列),记这个序列为\(b[1\dots n]\)。我们依次加入\(b\)序列里每个点,那么,只需要统计,在每个节点之前加入的节点中,距离它小于等于\(x\)的节点有多少个。相当于,我们认为,当前节点,就是集合里深度最大的节点,而在它之前加入的节点,就是【以当前点作为【深度最大节点】的集合】里的其他节点。具体来说,我们求出一个数组\(a[1\dots n]\):
那么,\(k\)的答案就是\(ans[k]=\sum_{i=1}^{n}{a[i]\choose k-1}\)。这种按\(b\)序列依次考虑每个节点的做法,有效地避免了对深度相同的节点重复计算的情况。因为当集合里存在好几个深度最大的节点时,我们默认选择\(b\)序列里排名最靠后的,作为该集合的“深度最大的节点”。
如果暴力求出\(a\)序列,再暴力计算\(ans\),时间复杂度\(O(n^2)\)。我们对这两部分分别考虑优化。
第一部分 求出\(a\)序列
具体来说,我们要支持两种操作:
- 标记一个节点。
- 给定节点\(u=a[i]\),查询所有已标记的节点中,到\(u\)距离小于等于\(x\)的节点数。
考虑使用点分树。点分树是基于原树,重构出来的树。它有几个特点:
- 点分树上也有\(n\)个节点,对应原树里每个点。
- 点分树上的每个子树,是原树上的一个连通块。也就是点分治时,以子树的根为分治中心的连通块。
- 点分树的深度不超过\(O(\log n)\)。
建点分树,其实就是做“点分治”的过程。每次从当前分治中心,向每个子树的分治中心连边,就得到点分树了。
考虑【查询所有已标记的节点中,到\(u\)距离小于等于\(x\)的节点数】。枚举这个“已标记的点”到点\(u\)的路径在点分树上经过的深度最浅的点,也就是点分树上两者的LCA。因为点分树深度不超过\(O(\log n)\),所以可以暴力枚举\(u\)在点分树上的所有祖先。记这个祖先为\(v\)。那么,问题转化为,查询点分树上点\(v\)的子树中,在原树上距离\(v\)小于等于\(x-\text{dis}(u,v)\)的节点数。当然,还要减去,点分树上\(v\)的【包含\(u\)的那个儿子】的子树里,满足条件的节点数。
我们可以对点分树上每个节点,开两个树状数组。第一个树状数组,记录点分树上它的子树内,在原树上距离它小于等于\(i\)的、已标记的节点数。第二个树状数组,记录点分树上它的子树内,在原树上距离【它在点分树上的父亲】小于等于\(i\)的、已标记的节点数。那么,标记一个节点时,我们只需要在它点分树上祖先的树状数组上单调修改。查询时,我们在祖先的树状数组上查询一个“前缀和”。
时间复杂度\(O(n\log^2n)\)。
还有一个注意点,每个节点的树状数组,大小显然不能直接开到\(n\),否则空间复杂度会变成\(O(n^2)\)。我们在建点分树时,统计出,以当前节点\(v\)为分治中心的这个连通块内,所有节点到\(v\)的最大距离。那么\(v\)的第一个树状数组,大小直接开到这个最大距离即可。这样所有“第一个树状数组”的大小之和,就是最大距离之和,显然小于等于所有连通块的大小之和。根据点分治的理论,这个和不超过\(O(n\log n)\)。对于第二个树状数组,其大小设置为,该连通块内,到点分树上父亲的最大距离。那么,对于一个节点\(v\),它所有儿子的“第二个树状数组”的大小之和,小于等于以\(v\)为分治中心的这整个连通块的大小。所以,所有节点的“第二个树状数组”大小之和也是\(O(n\log n)\)的。因此,空间复杂度\(O(n\log n)\)。
第二部分 根据\(a\)序列求出答案数组
我们前面说过,每个\(a[i]\),它对每个\(k\)的答案的贡献是:\({a[i]\choose k-1}\)。所以,\(ans[k]=\sum_{i=1}^{n}{a[i]\choose k-1}\)。不妨记:\(t[k-1]=ans[k]\),现在我们就是要求这个\(t\)数组,其中\(t[k]=\sum_{i=1}^{n}{a[i]\choose k}\)。按定义直接求,不管你是先枚举\(i\)再枚举\(k\),还是先枚举\(k\)再枚举\(i\),甚至你把\(a[i]\)序列排序搞一搞,时间复杂度都是\(O(n^2)\)的。
转化一下思路。发现,\(a[i]\)的值相同,对答案的贡献就是相同的。所以先求一个桶,\(c[i]=\sum_{j=1}^{n}[a[j]=i]\)。然后就可以得到:
设\(f[i]=c[i]i!\), \(g[i]=\frac{1}{i!}\)。那么,式子转化为\(t[k]=\sum_{i=k}^{n-1}f[i]g[i-k]\)。这其实也是一种卷积,我们很容易将它转化为常见的FFT的形式:
我们熟悉的卷积都是\(t_k=\sum_{i=0}^{k}a_ib_{k-i}\),而这里的形式是\(t_k=\sum_{i=k}^{n}f_ig_{i-k}\),考虑如何转化。注:这里为方便表述,我们令\(n=n-1\)。
仔细考虑卷积的本质,我们的目标是要让相乘的两个数下标之和为定值,而我们能做的事就是尝试翻转一个数列。不妨尝试翻转\(f\),即令\(f_i=f_{n-i}\)。
此时\(t_k=\sum_{i=k}^{n}f_{n-i}g_{i-k}\)。
把\(i\)换成\(i-k\),则\(t_k=\sum_{i=0}^{n-k}f_{n-k-i}g_{i}\)。发现\(f\)和\(g\)的下标和为\(n-k\),刚好是一个定值。此时如果把\(f\)和\(g\)卷起来,那么结果的第\(n-k\)项就是\(t_k\)了(也就是把结果序列翻转一下)。
因此,我们做一遍FFT就能求出答案。
本部分时间复杂度\(O(n\log n)\)。
结合第一部分的解法,我们在\(O(n\log^2n)\)的时间复杂度内解决了本题。
参考代码
//problem:ZR1277
#include <bits/stdc++.h>
using namespace std;
#define pb push_back
#define mk make_pair
#define lob lower_bound
#define upb upper_bound
#define fi first
#define se second
#define SZ(x) ((int)(x).size())
typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
const int MAXN=3e5,MAX_LOGN=19;
int n,X;
struct EDGE{int nxt,to;}edge[MAXN*2+5];
int head[MAXN+5],tot;
inline void add_edge(int u,int v){edge[++tot].nxt=head[u],edge[tot].to=v,head[u]=tot;}
bool vis[MAXN+5];
int f[MAXN+5],sz[MAXN+5],root,totsize;
void get_root(int u,int fa){
f[u]=0;
sz[u]=1;
for(int i=head[u];i;i=edge[i].nxt){
int v=edge[i].to;
if(v==fa||vis[v])continue;
get_root(v,u);
f[u]=max(f[u],sz[v]);
sz[u]+=sz[v];
}
f[u]=max(f[u],totsize-sz[u]);
if(!root||f[u]<f[root])root=u;
}
int div_fa[MAXN+5],div_dep[MAXN+5],maxdep[MAXN+5];
void div_dfs(int u,int fa){
sz[u]=1;
div_dep[u]=div_dep[fa]+1;
maxdep[u]=div_dep[u];
for(int i=head[u];i;i=edge[i].nxt){
int v=edge[i].to;
if(v==fa||vis[v])continue;
div_dfs(v,u);
sz[u]+=sz[v];
maxdep[u]=max(maxdep[u],maxdep[v]);
}
}
struct FenwickTree{
vector<int>c;
int size;
void ins(int p){
++p;//距离从0开始,而树状数组的下标需要从1开始
assert(p>=1 && p<=size);
for(;p<=size;p+=(p&(-p)))c[p]++;
}
int query(int p){
if(p<0)return 0;
++p;
p=min(p,size);
assert(p>=1 && p<=size);
int res=0;
for(;p;p-=(p&(-p)))res+=c[p];
return res;
}
void init(int _size){
c.resize(_size+5);
size=_size;
}
FenwickTree(){}
}fwk[MAXN+5],fwk_fa[MAXN+5];
void build_divtree(int u){
//建点分树
vis[u]=1;
div_dfs(u,0);
fwk[u].init(maxdep[u]);
//cout<<"build "<<u<<" "<<sz[u]<<endl;
for(int i=head[u];i;i=edge[i].nxt){
int v=edge[i].to;
if(vis[v])continue;
root=0;totsize=sz[v];
get_root(v,0);
div_fa[root]=u;
fwk_fa[root].init(maxdep[v]+5);
build_divtree(root);
}
}
int bfs_array[MAXN+5],cnt_bfn;
void bfs(){
queue<int>q;
q.push(1);
memset(vis,0,sizeof(bool)*(n+1));
while(!q.empty()){
int u=q.front();
q.pop();
vis[u]=1;
bfs_array[++cnt_bfn]=u;
for(int i=head[u];i;i=edge[i].nxt){
int v=edge[i].to;
if(vis[v])continue;
q.push(v);
}
}
//for(int i=1;i<=n;++i)cout<<bfs_array[i]<<" ";cout<<endl;
}
int dep[MAXN+5],dfn[MAXN+5],arr[MAXN*2+5],st[MAXN*2+5][MAX_LOGN+1],cnt,_log2[MAXN*2+5];
void eulerLCA_dfs(int u,int fa){
dep[u]=dep[fa]+1;
dfn[u]=++cnt;
arr[cnt]=u;
for(int i=head[u];i;i=edge[i].nxt){
int v=edge[i].to;
if(v==fa)continue;
eulerLCA_dfs(v,u);
arr[++cnt]=u;
}
}
int get_lca(int u,int v){
int l=dfn[u],r=dfn[v];
if(l>r)swap(l,r);
int k=_log2[r-l+1];
return dep[st[l][k]]<dep[st[r-(1<<k)+1][k]]?st[l][k]:st[r-(1<<k)+1][k];
}
int get_dist(int u,int v){
if(u==v)return 0;
int lca=get_lca(u,v);
return dep[u]+dep[v]-dep[lca]*2;
}
void eulerLCA_init(){
eulerLCA_dfs(1,0);
for(int i=1;i<=cnt;++i)st[i][0]=arr[i];
for(int j=1;j<=MAX_LOGN;++j){
for(int i=1;i+(1<<(j-1))<=cnt;++i){
st[i][j]=dep[st[i][j-1]]<dep[st[i+(1<<(j-1))][j-1]]?st[i][j-1]:st[i+(1<<(j-1))][j-1];
}
}
_log2[0]=-1;
for(int i=1;i<=n*2;++i)_log2[i]=_log2[i>>1]+1;
assert(_log2[n*2]<=MAX_LOGN);
}
int a[MAXN+5],c[MAXN+5];
const int MOD=998244353;
inline int mod1(int x){return x<MOD?x:x-MOD;}
inline int mod2(int x){return x<0?x+MOD:x;}
inline void add(int& x,int y){x=mod1(x+y);}
inline void sub(int& x,int y){x=mod2(x-y);}
inline int pow_mod(int x,int i){int y=1;while(i){if(i&1)y=(ll)y*x%MOD;x=(ll)x*x%MOD;i>>=1;}return y;}
int fac[MAXN+5],ifac[MAXN+5];
inline int comb(int n,int k){
if(n<k)return 0;
return (ll)fac[n]*ifac[k]%MOD*ifac[n-k]%MOD;
}
void facinit(int lim=MAXN){
fac[0]=1;
for(int i=1;i<=lim;++i)fac[i]=(ll)fac[i-1]*i%MOD;
ifac[lim]=pow_mod(fac[lim],MOD-2);
for(int i=lim-1;i>=0;--i)ifac[i]=(ll)ifac[i+1]*(i+1)%MOD;
}
namespace Poly{
int rev[MAXN*4+5];
int f[MAXN*4+5],g[MAXN*4+5];
void NTT(int *a,int n,int flag){
for(int i=0;i<n;++i)if(i<rev[i])swap(a[i],a[rev[i]]);
for(int i=1;i<n;i<<=1){
int T=pow_mod(3,(MOD-1)/(i<<1));
if(flag==-1) T=pow_mod(T,MOD-2);
for(int j=0;j<n;j+=(i<<1)){
for(int k=0,t=1;k<i;++k,t=(ll)t*T%MOD){
int Nx=a[j+k],Ny=(ll)a[i+j+k]*t%MOD;
a[j+k]=mod1(Nx+Ny);
a[i+j+k]=mod2(Nx-Ny);
}
}
}
if(flag==-1){
int invn=pow_mod(n,MOD-2);
for(int i=0;i<n;++i)a[i]=(ll)a[i]*invn%MOD;
}
}
void mul(int n){
int lim=1,ct=0;
while(lim<=n+n)lim<<=1,ct++;
for(int i=0;i<lim;++i)rev[i]=(rev[i>>1]>>1)|((i&1)<<(ct-1));
NTT(f,lim,1);
NTT(g,lim,1);
for(int i=0;i<lim;++i)f[i]=(ll)f[i]*g[i]%MOD;
NTT(f,lim,-1);
}
}//namespace Poly
int main() {
freopen("airplane.in","r",stdin);
freopen("airplane.out","w",stdout);
cin>>n>>X;
for(int i=1;i<n;++i){
int u,v;cin>>u>>v;
add_edge(u,v);
add_edge(v,u);
}
root=0;totsize=n;get_root(1,0);
build_divtree(root);
bfs();
eulerLCA_init();
for(int i=1;i<=n;++i){
int u=bfs_array[i];
a[u]=fwk[u].query(X);
for(int curu=div_fa[u],son=u;
curu!=0;
son=curu,curu=div_fa[curu]){
int d=get_dist(curu,u);
a[u]+=fwk[curu].query(X-d)-fwk_fa[son].query(X-d);
}
for(int curu=u,son=0;
curu!=0;
son=curu,curu=div_fa[curu]){
int d=get_dist(curu,u);
fwk[curu].ins(d);
if(curu!=u)fwk_fa[son].ins(d);
}
}
//for(int i=1;i<=n;++i)cout<<a[i]<<" ";cout<<endl;
for(int i=1;i<=n;++i)c[a[i]]++;
facinit();
for(int i=0;i<n;++i)Poly::f[i]=(ll)c[i]*fac[i]%MOD;
reverse(Poly::f,Poly::f+n);
for(int i=0;i<n;++i)Poly::g[i]=ifac[i];
Poly::mul(n);
reverse(Poly::f,Poly::f+n);
for(int i=1;i<=n;++i){
int ansi=(ll)Poly::f[i-1]*ifac[i-1]%MOD;
cout<<ansi<<" \n"[i==n];
}
return 0;
}