CF1517F Reunion
一、题目
二、解法
直接统计难得很,很容易想到把问题转化成统计半径 \(\geq r\) 的圆的个数。
现在还是很不好做,因为限制是 存在一个点是的周围白点构成半径>=r的圆
,很容易算重。
做第二步转化:对于所有黑点周围<=r的点的并集不是所有点
,这是一个染色问题,可以 \(dp\) 了,考虑状态中记录黑点往上覆盖的最远距离,或者是子树内最深的还没有被覆盖的点。这两个信息只有一个有效,因为如果子树内有没有被覆盖的点,设点 \(x\) 能覆盖这个点,那么点 \(x\) 的覆盖范围一定是子树内黑点往上覆盖范围的超集。
可以压在一个 \(dp\) 数组里面写,设 \(dp[u][i]\) 表示在子树 \(u\) 内,\(i\geq 0\) 则子树内黑点往上覆盖的最远距离是 \(i\),\(i<0\) 则表示子树内最深没覆盖点的深度是 \(i+1\),转移讨论 \(4\) 种情况即可,因为 \(i\) 的有效取值 \(\leq siz[u]\),所以时间复杂度是 \(O(n^2)\) 的,在外面统计所有 \(i<0\) 的情况即可,总时间复杂度 \(O(n^3)\)
三、总结
转化的艺术,第一步转化是常见套路:差分。
第二步转化就很神奇了,我在这里断言计数问题中所有限制问题优于存在限制问题,转化完就不会算重了。
\(dp\) 的设计可以想到:小范围讨论除去无效状态,这个 \(\tt trick\) 在阿七那题里也有。
#include <cstdio>
#include <vector>
#include <cstring>
using namespace std;
#define int long long
#define pii pair<int,int>
#define make make_pair
const int M = 305;
const int sh = 300;
const int MOD = 998244353;
int read()
{
int x=0,f=1;char c;
while((c=getchar())<'0' || c>'9') {if(c=='-') f=-1;}
while(c>='0' && c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
return x*f;
}
int n,r,tot,ans,f[M],dp[M][M<<1];
struct edge
{
int v,next;
}e[2*M];
void add(int &x,int y)
{
x=(x+y)%MOD;
}
void adds(int u,int v)
{
vector<pii> v1,v2;
for(int i=0;i<=2*sh;i++)
{
if(dp[u][i])
v1.push_back(make(i-sh,dp[u][i])),dp[u][i]=0;
if(dp[v][i])
v2.push_back(make(i-sh,dp[v][i]));
}
for(int x=0;x<v1.size();x++)
for(int y=0;y<v2.size();y++)
{
int i=v1[x].first,j=v2[y].first;
int w=v1[x].second*v2[y].second%MOD;
if(i<0 && j<0) add(dp[u][min(i,j-1)+sh],w);
if(i>=0 && j<0) add(dp[u][(i+j>=0?i:j-1)+sh],w);
if(i<0 && j>=0) add(dp[u][(i+j>=0?j-1:i)+sh],w);
if(i>=0 && j>=0) add(dp[u][max(i,j-1)+sh],w);
}
}
void dfs(int u,int fa)
{
dp[u][r+sh]=dp[u][-1+sh]=1;
for(int t=f[u];t;t=e[t].next)
{
int v=e[t].v;
if(v==fa) continue;
dfs(v,u);
adds(u,v);
}
}
signed main()
{
n=read();
for(int i=1;i<n;i++)
{
int u=read(),v=read();
e[++tot]=edge{v,f[u]},f[u]=tot;
e[++tot]=edge{u,f[v]},f[v]=tot;
}
for(r=1;r<=n;r++)
{
memset(dp,0,sizeof dp);
dfs(1,0);
for(int i=0;i<sh;i++) add(ans,dp[1][i]);
}
int inv=(MOD+1)/2;
ans=(ans-1+MOD)%MOD;//all is busy
for(int i=1;i<=n;i++) ans=ans*inv%MOD;
printf("%lld\n",ans);
}