[ARC125F] Tree Degree Subset Sum
Solution
首先这个树的限制几乎没用,我们可以先把每个点度数 \(-1\),然后总的度数就是 \(n-2\) ,设 \(z\) 为度数为 \(0\) 的点的个数。
可以看出,这个问题的麻烦之处就在于对于一个度数和还要求出有多少个满足的大小,而这个似乎只能 \(\Theta(n^2\log n)\) dp。
不过,我们稍作观察之后发现一个性质,即是假设设 \(L(x),R(x)\) 表示能构成度数和为 \(x\) 的所需的最少点数和最多点数,那么 \([L(x),R(x)]\) 的点数都是合法的。
考虑如何证明。发现如果能够证明 \(R(x)-L(x)\le 2\times z-1\),那么这个东西就是合理的。因为 \(L(x)\) 一定是一个 \(0\) 度数都没有选的,所以它能覆盖 \([L(x),L(x)+z]\),\(R(x)\) 一定选了所有的 \(0\) 度数点,所以它能覆盖 \([R(x)-z,R(x)]\)。
注意到对于任意 \(k\) 个数构成了度数和为 \(c\) 的情况,存在 \(-z\le k-c\le z-2\) 。因为最小的时候一定是全选 \(0\),最大的时候一定是全选非 \(0\) 点,也即是:
\[\sum d-\sum [d>0]=(n-2)-(n-z)=z-2
\]
所以我们可以得到 \(-z\le L(x)-x\le z-2,-z\le R(x)-x\le z-2\),所以可以得到 \(R(x)-L(x)\le 2\times z-2\)。
然后因为不同的度数个数只有 \(\sqrt n\) 个,所以我们可以二进制分组,复杂度即为 \(\Theta(n\sqrt n\log n)\)。
Code
#include <bits/stdc++.h>
using namespace std;
#define Int register int
#define ll long long
#define MAXN 200005
template <typename T> inline void read (T &x){x = 0;int f = 1;char c = getchar ();while (!isdigit (c)) f *= (c == '-' ? -1 : 1),c = getchar ();while (isdigit (c)) x = x * 10 + c - '0',c = getchar ();x *= f;}
template <typename T,typename ... Args> inline void read (T &x,Args& ... args){read (x),read (args...);}
template <typename T> inline void write (T x){if (x < 0) x = -x;if (x > 9) write (x / 10);putchar (x % 10 + '0');}
template <typename T> inline void chkmin (T &a,T b){a = min (a,b);}
template <typename T> inline void chkmax (T &a,T b){a = max (a,b);}
int n,w[MAXN],deg[MAXN],cnt[MAXN],mxv[MAXN],miv[MAXN];
signed main(){
read (n);
for (Int i = 2,u,v;i <= n;++ i)
read (u,v),deg[u] ++,deg[v] ++;
int z = 0;
for (Int u = 1;u <= n;++ u) deg[u] --,cnt[deg[u]] ++,z += (deg[u] == 0);
cout << endl;
memset (mxv,0xcf,sizeof (mxv)),mxv[0] = 0;
for (Int i = 1;i <= n;++ i) if (cnt[i]){
int x = cnt[i];
int t = 0,rst = x;
for (Int d = 0;(1 << d) <= rst;++ d) w[++ t] = 1 << d,rst -= w[t];
if (rst) w[++ t] = rst;
for (Int k = 1;k <= t;++ k)
for (Int j = n;j >= i * w[k];-- j)
chkmax (mxv[j],mxv[j - i * w[k]] + w[k]);
}
memset (miv,0x3f,sizeof (miv)),miv[0] = 0;
for (Int i = 1;i <= n;++ i) if (cnt[i]){
int x = cnt[i];
int t = 0,rst = x;
for (Int d = 0;(1 << d) <= rst;++ d) w[++ t] = 1 << d,rst -= w[t];
if (rst) w[++ t] = rst;
for (Int k = 1;k <= t;++ k)
for (Int j = n;j >= i * w[k];-- j)
chkmin (miv[j],miv[j - i * w[k]] + w[k]);
}
ll ans = 0;
for (Int i = 0;i <= n;++ i) if (miv[i] <= mxv[i]) ans += mxv[i] - miv[i] + 1 + z;
write (ans),putchar ('\n');
return 0;
}