题目大意:有一棵n个节点的树,给每个节点分配一个非负整数,使得权值和为m,求出所有方案的 标号最小的带权重心的 标号之和。一个点是带权重心当且仅当以它为根的子树中,所有子树的权值和小于等于m除以2下取整。n=200000,m=5000000,对质数取模。

思考:
首先比较容易发现的一点是 带权重心一定组成一条链。对于m是奇数的情况,带权重心仅有一个,因为从某个带权重心移动必然会导致某一个子树的权值和大于m/2。我们枚举成为带权重心的点,并用容斥的方法求出这个点在多少种方案中出现。令$get(x,y)$表示将x个无区别的球放入y个有区别的盒子并且没有数量限制的方案数,那么$get(x,y)=\tbinom{x+y-1}{x}$。注意到对于u号点不合法的情况,当且仅当某一个子树中的权值和大于等于$\frac{m+1}{2}$,于是我们在u号点的方案数,等于$get(m,n)-\sum_{v\in son}{\sum_{i=\frac{m+1}{2}}^{m}{get(i,s_{v})*get(m-i,n-s_{v})}}$,其中$s_{v}$是以u为根时,子树v的大小。

我们换个角度思考后面的和式(这一步直接抛弃了前面的式子):如果我们在子树中先放入$\frac{m+1}{2}$个球,那么剩下的球任意放都能被统计到答案里。因此我们可以先确定第$\frac{m+1}{2}$个球放在了哪个盒子(相当于节点)里,那么方案数就为$get(m,n)-\sum_{v\in son}{\sum_{i=1}^{s_{v}}{get(\frac{m+1}{2}-1,i)*get(\frac{m+1}{2}-1,n-i+1)}}$。这个式子把s拿了出来!因此可以算出它的前缀和,在O(m)的时间内完成。

对于m为偶数的情况,我们先将m-1带入奇数算法,这样算出的是没有标号最小限制的方案和。对于一条链(长度大于等于2),它上面的标号不是最小的点会被记重若干次,这个值只和链的两个端点的子树大小有关,因此考虑点分治来解决。每次算贡献时要计算路径上点权最小的点,把当前分治到的联通块的点权从小到大保存下来就不需要用数据结构了。复杂度O(nlogn+m)。

 

  1 #include<bits/stdc++.h>
  2 #define mod 998244353
  3 using namespace std;
  4 typedef long long int ll;
  5 const int maxn=2E5+5;
  6 const int limit=6E6+5;
  7 int n,m;
  8 ll ans,fac[limit+5],inv[limit+5],preF[maxn];
  9 int size,head[maxn];
 10 int root,sum[maxn],maxp[maxn],fa[maxn],preSum[maxn];
 11 struct edge
 12 {
 13     int to,next;
 14 }E[maxn*2];
 15 inline void add(int u,int v)
 16 {
 17     E[++size].to=v;
 18     E[size].next=head[u];
 19     head[u]=size;
 20 }
 21 inline ll qpow(ll x,ll y)
 22 {
 23     ll ans=1,base=x;
 24     while(y)
 25     {
 26         if(y&1)
 27             ans=ans*base%mod;
 28         base=base*base%mod;
 29         y>>=1;
 30     }
 31     return ans;
 32 }
 33 inline ll C(int x,int y)
 34 {
 35     if(x<y||x<0||y<0)
 36         return 0;
 37     return fac[x]*inv[y]%mod*inv[x-y]%mod;
 38 }
 39 void dfs(int u,int F)
 40 {
 41     fa[u]=F;
 42     preSum[u]=1;
 43     for(int i=head[u];i;i=E[i].next)
 44     {
 45         int v=E[i].to;
 46         if(v==F)
 47             continue;
 48         dfs(v,u);
 49         preSum[u]+=preSum[v];
 50         preF[u]=(preF[u]-C(m/2+preSum[v]-1,m/2))%mod;
 51     }
 52     preF[u]=(preF[u]-C(m/2+n-preSum[u]-1,m/2))%mod;
 53 }
 54 inline void init()
 55 {
 56     fac[0]=1;
 57     for(int i=1;i<=limit;++i)
 58         fac[i]=fac[i-1]*i%mod;
 59     inv[limit]=qpow(fac[limit],mod-2);
 60     for(int i=limit-1;i>=0;--i)
 61         inv[i]=inv[i+1]*(i+1)%mod;
 62 }
 63 inline ll getF(int u,int v)
 64 {
 65     if(v==fa[u])
 66         return (preF[u]+C(m/2+n-preSum[u]-1,m/2)+C(m/2+preSum[u]-1,m/2))%mod;
 67     return (preF[u]+C(m/2+preSum[v]-1,m/2)+C(m/2+n-preSum[v]-1,m/2))%mod;
 68 }
 69 namespace work1
 70 {
 71     ll f[maxn];
 72     inline ll get(int x,int y)
 73     {
 74         return C(x+y-1,x);
 75     }
 76     void init(int u,int F)
 77     {
 78         sum[u]=1;
 79         for(int i=head[u];i;i=E[i].next)
 80         {
 81             int v=E[i].to;
 82             if(v==F)
 83                 continue;
 84             init(v,u);
 85             sum[u]+=sum[v];
 86         }
 87     }
 88     void dfs(int u,int F,int tot)
 89     {
 90         vector<int>wait;
 91         int g=0;
 92         for(int i=head[u];i;i=E[i].next)
 93         {
 94             int v=E[i].to;
 95             if(v==F)
 96                 continue;
 97             g+=sum[v];
 98             wait.push_back(sum[v]);
 99         }
100         wait.push_back(tot);
101         ll s=get(m,n);
102         for(int i=0;i<wait.size();++i)
103             s-=f[wait[i]];
104         s%=mod;
105         ans=(ans+s*u)%mod;
106         for(int i=head[u];i;i=E[i].next)
107         {
108             int v=E[i].to;
109             if(v==F)
110                 continue;
111             dfs(v,u,tot+g-sum[v]+1);
112         }
113     }
114     inline void main(int l)
115     {
116         for(int i=1;i<=n;++i)
117             f[i]=(f[i-1]+get(l-1,i)*get(m-l,n-i+1))%mod;// !!!!!!!
118         init(1,0);
119         dfs(1,0,0);
120     }
121 }
122 namespace work2
123 {
124     bool vis[maxn];
125     int TI,visT[maxn],fa[maxn];
126     ll f[maxn];
127     void get(int u,int F)
128     {
129         fa[u]=F;
130         sum[u]=1;
131         for(int i=head[u];i;i=E[i].next)
132         {
133             int v=E[i].to;
134             if(vis[v]||v==F)
135                 continue;
136             get(v,u);
137             sum[u]+=sum[v];
138         }
139     }
140     void getRoot(int u,int F,int tot)
141     {
142         maxp[u]=0;
143         for(int i=head[u];i;i=E[i].next)
144         {
145             int v=E[i].to;
146             if(v==F||vis[v])
147                 continue;
148             getRoot(v,u,tot);
149             maxp[u]=max(maxp[u],sum[v]);
150         }
151         maxp[u]=max(maxp[u],tot-sum[u]);
152         if(maxp[root]>maxp[u])
153             root=u;
154     }
155     int totC,bel[maxn];
156     ll sumF[maxn];
157     void getFF(int u,int F,int c)
158     {
159         bel[u]=c;
160         f[u]=getF(u,F);
161         for(int i=head[u];i;i=E[i].next)
162         {
163             int v=E[i].to;
164             if(v==F||vis[v])
165                 continue;
166             getFF(v,u,c);
167         }
168         sumF[c]=(sumF[c]+f[u])%mod;
169     }
170     void cut(int u,int F,ll base,ll now)
171     {
172         ans=(ans-base*f[u]%mod*now)%mod;
173         for(int i=head[u];i;i=E[i].next)
174         {
175             int v=E[i].to;
176             if(vis[v]||v==F)
177                 continue;
178             cut(v,u,base,now+v);
179         }
180     }
181     ll fill(int u,int F,int c)
182     {
183         if(visT[u]==c)
184             return 0;
185         visT[u]=c;
186         ll s=f[u];
187         for(int i=head[u];i;i=E[i].next)
188         {
189             int v=E[i].to;
190             if(v==F||vis[v]||visT[v]==c)
191                 continue;
192             s+=fill(v,u,c);
193         }
194         sumF[bel[u]]=(sumF[bel[u]]-f[u])%mod;
195         return s%mod;
196     }
197     vector<int>wait[maxn];
198     int what[maxn];
199     void solve(int u,vector<int>D)
200     {
201         vis[u]=1;
202         ++TI;
203         ll totF=0;
204         sum[u]=0;
205         for(int i=head[u];i;i=E[i].next)
206         {
207             int v=E[i].to;
208             if(vis[v])
209                 continue;
210             ++totC;
211             get(v,u);
212             getFF(v,u,totC);
213             totF=(totF+sumF[totC])%mod;
214             what[totC]=v;
215             sum[u]+=sum[v];
216         }
217         ll s=0;
218         for(int i=head[u];i;i=E[i].next)
219         {
220             int v=E[i].to;
221             if(vis[v])
222                 continue;
223             ll now=getF(u,v);
224             cut(v,u,(totF-sumF[bel[v]]+now)%mod,v);
225             ans=(ans-s*sumF[bel[v]]%mod*u)%mod;
226             ans=(ans-now*u%mod*sumF[bel[v]])%mod;
227             s=(s+sumF[bel[v]])%mod;
228         }
229         int now=n+1;
230         for(int i=0;i<D.size();++i)
231         {
232             int pos=D[i];
233             if(pos==u)
234             {
235                 now=u;
236                 continue;
237             }
238             totF-=sumF[bel[pos]];
239             ll x=fill(pos,fa[pos],TI);
240             ans=(ans+totF*x%mod*min(pos,now))%mod;
241             
242             ans=(ans+x*min(pos,now)%mod*getF(u,what[bel[pos]]))%mod;
243             totF+=sumF[bel[pos]];
244             totF%=mod;
245             wait[bel[pos]].push_back(pos);
246         }
247         for(int i=head[u];i;i=E[i].next)
248         {
249             int v=E[i].to;
250             if(vis[v])
251                 continue;
252             root=0;
253             get(v,u);
254             getRoot(v,u,sum[v]);
255             solve(root,wait[bel[v]]);
256         }
257     }
258     inline void main()
259     {
260         dfs(1,0);
261         maxp[0]=n+1;
262         get(1,0);
263         getRoot(1,0,n);
264         vector<int>D;
265         for(int i=1;i<=n;++i)
266             D.push_back(i);
267         solve(root,D);
268     }
269 }
270 inline void solve()
271 {
272     work1::main(m/2+1);
273     if(m%2==0)
274         work2::main();
275     ans=(ans%mod+mod)%mod;
276     cout<<ans<<endl;
277 }
278 int main()
279 {
280     ios::sync_with_stdio(false);
281     cin>>n>>m;
282     init();
283     for(int i=2;i<=n;++i)
284     {
285         int x,y;
286         cin>>x>>y;
287         add(x,y);
288         add(y,x);
289     }
290     solve();
291     return 0;
292 }
View Code

 

 posted on 2021-03-02 23:01  GreenDuck  阅读(90)  评论(0编辑  收藏  举报