hdu 6091 Rikka with Match
题
OvO http://acm.hdu.edu.cn/showproblem.php?pid=6091
( 2017 Multi-University Training Contest - Team 6 - 1007)
解
记 f[i][j]表示,以i为根的子树的所有子图(包含子树所有节点,删掉一些边得到的子图)中,符合下列条件的子图的个数
1. 记子图最大匹配为h1,子图去掉与i节点相连的边后的最大匹配为h2,满足 h1-h2=0
2. h1%m=j
记g[i][j]表示,以i为根的子树的所有子图(包含子树所有节点,删掉一些边得到的子图)中,符合下列条件的子图的个数
1. 记子图最大匹配为h1,子图去掉与i节点相连的边后的最大匹配为h2,满足 h1-h2=1
2. h1%m=j
则对于每个根节点 p 遍历其后继,回溯计算答案,当遍历到后继 q 时,有下列递推式成立
这个可以由一个经典树形DP(求最大匹配)联想到(反正我是没联想到),
复杂度的话,因为如果要使一个节点k的size变成m的话,那么这个子树的大小也就有了m的大小,所以如果一个节点他的后继的size全为m的话,那么这个节点的后继值最多只有n/m个。
所以复杂度大概似乎好像貌似就O(mn)吧,
(思路来自题解)
#include <iostream> #include <cstring> #include <cstdio> #include <cmath> #include <algorithm> using namespace std; const int M=5e4+44; const int N=244; const int mod=998244353; struct node{ int u,v,d; int next; } edge[2*M]; int num; int head[M]; int n,m; int f[M][N],g[M][N]; int sz[M]; int F[N<<1],G[N<<1]; void addedge(int u,int v,int d) { edge[num].u=u; edge[num].v=v; edge[num].d=d; edge[num].next=head[u]; head[u]=num++; } void init() { num=0; memset (head,-1,sizeof(head)); } int inc(int x,int y) { x+=y; if(x>m) return m; return x; } void deal(int rt,int v) { int i,j,siz=inc(sz[rt],sz[v]); memset(F,0,sizeof(F)); memset(G,0,sizeof(G)); for(i=0;i<=sz[rt];i++) for(j=0;j<=sz[v];j++) { // printf("i: %d j: %d\n",i,j); F[i+j]=(0ll+F[i+j]+1ll*f[rt][i]*f[v][j]+2ll*f[rt][i]*g[v][j])%mod; G[i+j]=(0ll+G[i+j]+2ll*g[rt][i]*g[v][j]+2ll*g[rt][i]*f[v][j])%mod; G[i+j+1]=(0ll+G[i+j+1]+1ll*f[rt][i]*f[v][j])%mod; } // printf("rt: %d v: %d\n",rt,v); // printf(" F:\n"); // for(i=0;i<2*m+3;i++) // printf("%d ",F[i]); // printf("\n G:\n"); // for(i=0;i<2*m+3;i++) // printf("%d ",G[i]); // printf("\n"); sz[rt]=siz; for(i=0;i<m;i++) { f[rt][i]=(0ll+F[i]+F[i+m])%mod; g[rt][i]=(0ll+G[i]+G[i+m])%mod; } } void dfs(int rt,int pa) { int i,j,v; sz[rt]=0; memset(f[rt],0,sizeof(f[rt])); memset(g[rt],0,sizeof(g[rt])); f[rt][0]=1; for(i=head[rt];i!=-1;i=edge[i].next) { v=edge[i].v; if(v==pa) continue; dfs(v,rt); deal(rt,v); } sz[rt]=inc(sz[rt],1); } void solve() { dfs(1,-1); } int main() { int i,j,cas,u,v; scanf("%d",&cas); while(cas--) { init(); scanf("%d%d",&n,&m); for(i=1;i<n;i++) { scanf("%d%d",&u,&v); addedge(u,v,1); addedge(v,u,1); } solve(); printf("%d\n",(0ll+f[1][0]+g[1][0])%mod); } return 0; }