Codeforces 771C
我的树形dp果然是渣。。。
题意:给一棵树,共n(0<n<=15e4)个节点,可在树上进行跳跃,每次跳的最大距离为k(0<k<=5),定义f(s,t)为(dis(s,t)+k)/k,问Σf(s,t),s<t。
解题思路:
显然是树形dp,问题在于怎么构建状态。
最简单想到的就是,每到一个节点u,记录其子树中与其距离为d的的节点的数目,即(dis,cnt)对,则答案分两种情况,u到其子节点和以及子节点经过u到子节点,问题变得很简单,计算也不难——但问题在于极端情况下——比如树退化成链,时空复杂度都将高到无法忍受,于是卡在了这里……
然后看了别人的题解,发现不需要构建(dis,cnt)对,而是构建(dis%k,cnt)对——这样复杂度枚举的最高复杂度也就只是k^2而不是dis^2,,,啊感觉自己宛若一个zz。
定义,sz(u,i)表示与节点u的距离%k为 i 的节点数目,dp(u,i)表示从u到sz(u,i)中节点的跳数和。
状态转移如下:
(1)sz(u,(i+1)%k)+=sz(v,i)
(2)dp(u,(i+1)%k)+=dp(v,i) 0<i<k
(3)dp(u,1%k)+=dp(v,0)+sz(v,0)
显然,与u的子节点v的距离%k大于0的,到v所需跳数与到u所需跳数是相同的(式子(2))。否则跳数需要+1,有sz(v,0)个点,故再加上sz(v,0)(式子(3))。
接下来是统计答案,分两种情况,一个是u到其子节点,另一个是u到子节点到其它子节点(子节点间的最近公共祖先节点为u)。
对于第一种,直接res+=Σdp[u][i]即可。
对于第二种,将之前已经遍历过的子树节点都合并到dp[u][i]与sz[u][i]中,则与新的子树节点v合并时,枚举k1、k2,依次为已遍历过的子树节点与u的距离%k=k1,以及与v的子树中与u的距离%k为k2的节点(语文不好这段贼拗口……k^2的枚举就是这里了),分情况讨论:
如果k1为0,则dp[u][k1]中则为恰好到达u的跳数(即每一跳距离都是k),则从sz[u][k1]中的节点到达 v的子树中的节点 所需跳数 恰好为二者相加之和,不需额外处理;如果k2为0,同理;当k1与k2都不为0时,考虑k1+k2<=k,则此时说明多计算了一跳,因此需要减去;当k1+k2>k时,恰好符合。
综上,需要特判的也只有(k1&&k2&&k1+k2<=k)这种情况。
代码如下:
#include<iostream> #include<cstdio> #include<algorithm> #include<cstring> using namespace std; typedef long long ll; #define sqr(x) ((x)*(x)) const int N=2e5+10; int head[N],nxt[N<<1],to[N<<1],cnt; int n,k,a,b; ll sz[N][6],dp[N][6],res; void init(){ memset(head,-1,sizeof(head)); res=cnt=0; memset(sz,0,sizeof(sz)); memset(dp,0,sizeof(dp)); } void addEdge(int u,int v){ nxt[cnt]=head[u]; to[cnt]=v; head[u]=cnt++; } void dfs(int u,int pre){ ll tsz[6],tdp[6]; //用以暂时保存从u到新的v子树节点的数据 for(int e=head[u];~e;e=nxt[e]){ int v=to[e]; if(v==pre) continue; dfs(v,u); memset(tsz,0,sizeof(tsz)); memset(tdp,0,sizeof(tdp)); for(int i=0;i<k;i++) tsz[(i+1)%k]+=sz[v][i]; for(int i=1;i<k;i++) tdp[(i+1)%k]+=dp[v][i]; tdp[1%k]+=dp[v][0]+sz[v][0]; for(int k1=0;k1<k;k1++){ for(int k2=0;k2<k;k2++){ res+=dp[u][k1]*tsz[k2]+sz[u][k1]*tdp[k2]; if(k1&&k2&&k1+k2<=k) res-=sz[u][k1]*tsz[k2]; } } //将v的子树节点情况合并到u下 for(int i=0;i<k;i++) dp[u][i]+=tdp[i],sz[u][i]+=tsz[i]; } for(int i=0;i<k;i++) res+=dp[u][i]; sz[u][0]++; } int main(){ //freopen("in.txt","r",stdin); while(~scanf("%d%d",&n,&k)){ init(); for(int i=1;i<n;i++){ scanf("%d%d",&a,&b); addEdge(a,b); addEdge(b,a); } dfs(1,0); printf("%I64d\n",res); } return 0; }
参考题解:http://www.cnblogs.com/AOQNRMGYXLMV/p/6579771.html