[十二省联考 2019] 希望
链接_
你知道,代码非常长。在长度 10k 的代码间大眼观察,能过样例的程序浩如烟海,可以AC的代码却寥若晨星。但你没时间绝望,甚至没时间调出能过样例的程序。
你只能直接输出样例的结果。
那就是希望。
即便只能爆0,也是光明。
首先“存在**”很不好计数,因为这样同种情况中出现多个可行点只会记录一次。这意味着式子很难被展开。
但是我们发现最后的可行点一定是一个联通块。这意味着这些可行点一定会构成一颗树。
也就是答案 \(\displaystyle ans=\sum_{\text{所有合法方案}p}\left(\sum_{v\in V}[v\text{被所有集合包含}]-\sum_{in E}[e\text{被所有集合包含}]\right)\)
这样就可以拆开计算。由于队伍之间没有任何区别,所以一个点被包含的次数就是被一个队伍包含次数的 \(k\) 次方,即
\(d(x)\) 表示点/边 \(x\) 被直径不超过 \(L\) 的联通块包含的方案数。
然后我们就可以大力dp。令 \(f_{u,i}\) 表示 \(u\) 子树中包含 \(u\) 且距离 \(u\) 不超过 \(i\) 的联通块数量。
令 \(g_{u,i}\) 表示 \(u\) 子树外(不包括 \(u\) 的子树除了 \(u\) 部分)中包含 \(u\) 且距离 \(u\) 不超过 \(i\) 的联通块数量。
转移有:
这样对于一个非根节点我们再顺便统计其到父亲边的贡献。可以在 \(O(nL)\) 时间内处理出结果。
但是这个复杂度不够优秀啊。
一般看着简单的式子都是可以优化的,可以发现对于上面的部分,只要维护整体平移,区间乘即可。这个直接可持久化线段树上搞一搞就好了。
时间复杂度 \(O(n\log n)\)。但是这样常数有点大,而这题很恶心,这样做会T掉。
我们需要一个线性的东西:长链剖分。
首先考虑处理 \(f\)。这部分是自底向上转移的,所以我们令每个点继承它长儿子的信息。
至于如何方便地继承有一个很套路的做法,就是先开一个很大的数组,然后每个点的 \(f\) 对应一个指向该数组某个位置的指针,然后模拟申请内存的方式,在数组末尾记录一个指针表示“当前已经申请的位置”。这样既可以优化常数(比new不知道快到哪里去了),也可以有效防止操作不当导致的 MLE 和 RE。
这样区间平移也很好 \(O(1)\) 处理,让长儿子的 \(f\) 指向当前节点 \(f\) 的第2位即可。同时当前节点的权值也直接继承长儿子的就好了。
然后是打标记。我们需要维护:后缀赋0,整体+1,后缀乘。
这个看似很难 \(O(1)\) 处理。但是我们发现在同一时刻一种标记只会有一个,所以我们完全可以离线处理标记,在合并的时候再处理上去。
为什么要把赋0和乘分开呢?因为0是没有逆元的,没法直接优化,所以要单独打tag。
由于每个节点只会被插入一次,所以总时间复杂度是正确的。具体实现其实细节很多。
然后考虑处理 \(g\)。这部分是自顶向下的。所以我们令每个点先处理完轻儿子的信息,再把所有信息一起传给长儿子。
首先可以发现:\([0,L-dep_u]\) 这部分的距离是没有本质区别的,因为无论子树中怎么取,这段链都不会成为瓶颈。那么我们可以直接把它变成 \(L-dep_u\)。
这样每次继承的数组就变成了 \(O(dep_u)\),可以保证时间复杂度。
考虑怎么处理贡献。首先可以发现我们需要求出去掉 \(u\) 的某个子节点 \(v\) 后所有 \(f\) 数组合并的结果。
对于后半段我们可以依次合并过去保证复杂度,但是对于前半段我们没办法删除。
所以一种方式就是记录每次合并后的所有修改,然后依次回退到上一次修改。由于总修改次数的复杂度是正确的,所以总回退的时空复杂度也是正确的。
然后我们再暴力合并上重儿子的结果,下传。
最后我们把合并好的数组继承给重儿子。同样每个点只会被继承 \(O(1)\) 次,总复杂度 \(O(n)\)。
同样实现细节还是很多。
最后还有一个瓶颈在于求逆元。由于 \(n\) 太大了,我们无法暴力 \(O(n\log a)\) 求出所有数字的逆元。
但是我们发现,我们只需要预处理求出所有点内部联通块数量时顺便求出这些值的逆元即可,所以考虑使用 \(O(n+\log a)\) 的离线求逆元。
总复杂度 \(O(n)\)。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
#include<algorithm>
#define N 1000010
#define ll long long
#define mod 998244353
using namespace std;
int ksm(int a,int b=mod-2)
{
int r=1;
for(;b;b>>=1)
{
if(b&1) r=1ll*r*a%mod;
a=1ll*a*a%mod;
}
return r;
}
int len[N],son[N],f1[N],n,L,k;
int nxt[N<<1],to[N<<1],head[N],cnt;
void add(int u,int v)
{
nxt[++cnt]=head[u];
to[cnt]=v;
head[u]=cnt;
}
int id[N],idn,val[N];
void dfs(int u,int p)
{
f1[u]=len[u]=1;
for(int i=head[u];i;i=nxt[i])
{
int v=to[i];
if(v==p) continue;
dfs(v,u);
if(len[v]>len[son[u]]) son[u]=v,len[u]=len[v]+1;
f1[u]=1ll*f1[u]*f1[v]%mod;
}
f1[u]=(f1[u]+1)%mod;
if(f1[u]) val[id[u]=++idn]=f1[u];
}
struct node{
int a,d,iv,l,w;
};
node o;
int ar[N<<3],*f[N*2],*g[N*2],par;
int* _new(int k){par+=k;return ar+(par-k);}
int inv[N];
#define P pair
#define Pr pair<int,int>
#define MP make_pair
#define fi first
#define se second
int ra[N],rb[N];
#define times(x,v1,v2) x.a=1ll*x.a*v1%mod,x.d=1ll*x.d*v1%mod,x.iv=1ll*x.iv*v2%mod
#define addf(f,p) f[p.l++]=p.w
namespace F{
node tag[N<<1];
int ans[N];
vector<P<node,vector<Pr > > >bk[N];
inline int get(int u,int k)
{
if(k<tag[u].l) return (1ll*f[u][k]*tag[u].a+tag[u].d)%mod;
return (1ll*tag[u].w*tag[u].a+tag[u].d)%mod;
}
void put(int u,int k,int v){f[u][k]=1ll*(v-tag[u].d+mod)*tag[u].iv%mod;}
void merge(int u,int v,int l)
{
node tmp=tag[u];
vector<Pr >s;
for(int i=1;i<=l;i++)
{
s.push_back(MP(i,f[u][i]));
int t=get(v,i-1);
if(i==tag[u].l) addf(f[u],tag[u]);
put(u,i,1ll*get(u,i)*t%mod);
}
if(l<L)
{
int t=get(v,l);
if(!t) tag[u].l=l+1,tag[u].w=mod-1ll*tag[u].d*tag[u].iv%mod;
else
{
int iv=inv[id[v]];
s.push_back(MP(0,f[u][0]));
for(int i=0;i<=l;i++)
put(u,i,1ll*get(u,i)*iv%mod);
times(tag[u],t,iv);
}
}
if(u<=n) bk[u].push_back(MP(tmp,s));
}
void dfs(int u,int p)
{
if(son[u])
{
f[son[u]]=f[u]+1;
dfs(son[u],u);
tag[u]=tag[son[u]];
}
else tag[u]=o;
put(u,0,1);
for(int i=head[u];i;i=nxt[i])
{
int v=to[i];
if(v==p || v==son[u]) continue;
f[v]=_new(len[v]);
dfs(v,u);
merge(u,v,min(len[v]-1,L));
}
ra[u]=get(u,min(len[u]-1,L-1));rb[u]=get(u,min(len[u]-1,L));
++tag[u].d;
}
void popbk(int u){
tag[u]=bk[u].back().fi;
for(auto x:bk[u].back().se) f[u][x.fi]=x.se;
bk[u].pop_back();
}
}
ll ans=0;
namespace G{
node tag[N];
inline int get(int u,int k)
{
if(k<tag[u].l) return (1ll*g[u][k]*tag[u].a+tag[u].d)%mod;
return (1ll*tag[u].w*tag[u].a+tag[u].d)%mod;
}
void put(int u,int k,int v){g[u][k]=1ll*(v-tag[u].d+mod)*tag[u].iv%mod;}
void dfs(int u,int p)
{
if(len[u]>=L+1) put(u,len[u]-L-1,1);
ans=(ans+ksm(1ll*rb[u]*get(u,len[u]-1)%mod,k))%mod;
if(p) ans=(ans-ksm(1ll*ra[u]*(get(u,len[u]-1)+mod-1)%mod,k)+mod)%mod;
if(!son[u]) return;
vector<int>ch;
int mx=0;
for(int i=head[u];i;i=nxt[i])
{
int v=to[i];
if(v==p || v==son[u]) continue;
ch.push_back(v);mx=max(mx,min(len[v],L));
}
reverse(ch.begin(),ch.end());
f[u+n]=_new(mx+1);
F::tag[u+n]=o;
F::put(u+n,0,1);
for(int v:ch)
{
F::popbk(u);
g[v]=_new(len[v]);
for(int i=max(len[v]-L-1,0);i<len[v];i++)
{
int p=len[son[u]]-len[v]+i,rp=L-len[v]+i;
if(len[v]==L+i+1) g[v][i]=get(u,p);
else g[v][i]=1ll*get(u,p)*F::get(u,min(len[u]-1,rp))%mod*F::get(u+n,min(mx,rp))%mod;
}
tag[v]=o;
F::merge(u+n,v,min(len[v]-1,L));
dfs(v,u);
}
int v=son[u]; g[v]=g[u],tag[v]=tag[u];
for(int i=max(len[v]-L,0);i<=len[v]+mx-L-1;i++)
{
if(i==tag[v].l) addf(g[v],tag[v]);
put(v,i,1ll*get(v,i)*F::get(u+n,L-len[v]+i)%mod);
}
if(mx<L)
{
int sv=1,iv=1;
for(int v:ch)
{
sv=1ll*sv*val[id[v]]%mod;
iv=1ll*iv*inv[id[v]]%mod;
}
if(!sv)
{
tag[v].l=len[v]+mx-L;
tag[v].w=mod-1ll*tag[v].d*tag[v].iv%mod;
}
else
{
for(int i=max(len[v]-L-1,0);i<=len[v]+mx-L-1;i++)
put(v,i,1ll*get(v,i)*iv%mod);
times(tag[v],sv,iv);
}
}
++tag[v].d;
dfs(v,u);
}
}
int sum[N];
void for_inv()
{
sum[0]=1;
for(int i=1;i<=idn;i++) sum[i]=1ll*sum[i-1]*val[i]%mod;
int invs=ksm(sum[idn]);
for(int i=idn;i;invs=1ll*invs*val[i--]%mod) inv[i]=1ll*sum[i-1]*invs%mod;
}
int main()
{
// freopen("hope4.in","r",stdin);
scanf("%d%d%d",&n,&L,&k);
o=(node){1,1,1,n,0};
for(int i=1;i<n;i++)
{
int u,v;
scanf("%d%d",&u,&v);
add(u,v),add(v,u);
}
dfs(1,0);
for_inv();
f[1]=_new(len[1]);F::dfs(1,0);
g[1]=_new(len[1]);G::tag[1]=o;
G::dfs(1,0);
printf("%d\n",ans);
return 0;
}