[dp][递归] Jzoj P4211 送你一棵圣诞树
题解
- 题目大意:给定了n棵树,每次会奖两棵树中的x,y之间连一条len的边,也就是将两棵树合并,问每棵树上两两点对的最短路径和
- 设calc1(x,y)为x这棵树里,其他点到y的最短路径和,calc2(x,y,k)为x这棵树里,y到k的最短路径长度
- 设f[i]为第i棵数的答案,f[i]=f[i.left]+f[i.right]+calc1(i.left,i.c)*i.right.size+calc1(i.right,i.d)*i.left.size+i.len*i.left.size*i.right.size
- calc1(x,y)中可以分成两种情况,①在x的左子树②在x的右子树
- 当在x的左子树时calc1(x,y)=calc1(x.left,y)+(calc2(y,x.left,x.c)+x.len)*x.right.size+calc1(x.right,x.d),但在右子树里也是同理
- calc2(x,y,z)也可以想上面分成两种情况,①y和k在同一棵子树里②y和k不在同一棵子树里
- 当在它们在同一颗子树里,直接往下递归;当它们不在同一棵子树里时,calc2(x,y,z)=calc2(x.left,y,x.c)+calc2(x.right,z,x.d)+x.len
- 这题比较恶心,还要用map优化
代码
1 #include<cstdio> 2 #include<iostream> 3 #include<cstring> 4 #include<map> 5 #define ll long long 6 using namespace std; 7 const ll N=70,mo=1e9+7; 8 ll T,n,f[N]; 9 typedef pair<ll,ll> node; 10 struct edge{ ll a,b,size,c,d,v; }p[N]; 11 map <node,ll> Q[N]; 12 ll calc2(ll k,ll x,ll y) 13 { 14 if (x>y) swap(x,y); 15 if (k==0||x==y) return 0; 16 if (x<p[p[k].a].size&&y<p[p[k].a].size) return calc2(p[k].a,x,y); 17 else if (x>=p[p[k].a].size&&y>=p[p[k].a].size) return calc2(p[k].b,x-p[p[k].a].size,y-p[p[k].a].size); 18 else 19 { 20 node q=make_pair(x,y); 21 if (Q[k].find(q)!=Q[k].end()) return Q[k][q]; 22 return Q[k][q]=(calc2(p[k].a,p[k].c,x)+calc2(p[k].b,p[k].d,y-p[p[k].a].size)+p[k].v)%mo; 23 } 24 } 25 ll calc1(ll k,ll x) 26 { 27 ll r=0; node q=make_pair(x,0); 28 if (Q[k].find(q)!=Q[k].end()) return Q[k][q]; 29 if (k==0) return 0; 30 if (x<p[p[k].a].size) r=(calc1(p[k].a,x)+calc1(p[k].b,p[k].d))%mo+(((p[k].v+calc2(p[k].a,p[k].c,x))%mo)*(p[p[k].b].size%mo))%mo; 31 else r=(calc1(p[k].a,p[k].c)+calc1(p[k].b,x-p[p[k].a].size))%mo+(((p[k].v+calc2(p[k].b,p[k].d,x-p[p[k].a].size))%mo)*(p[p[k].a].size%mo)%mo); 32 return (Q[k][q]=r)%mo; 33 } 34 int main() 35 { 36 freopen("data.in","r",stdin); 37 scanf("%d",&T); 38 while (T--) 39 { 40 scanf("%d",&n),memset(p,0,sizeof(p)),p[0].size=1; 41 for (int i=1;i<=n;i++) scanf("%lld%lld%lld%lld%lld",&p[i].a,&p[i].b,&p[i].c,&p[i].d,&p[i].v),p[i].size=p[p[i].a].size+p[p[i].b].size,Q[i].clear(); 42 for (int i=1;i<=n;i++) 43 f[i]=((f[p[i].a]+f[p[i].b])%mo+(calc1(p[i].a,p[i].c)*(p[p[i].b].size%mo)+(calc1(p[i].b,p[i].d)*(p[p[i].a].size%mo))%mo 44 +(((p[i].v%mo)*(p[p[i].a].size%mo))%mo*(p[p[i].b].size%mo))%mo)%mo)%mo, 45 printf("%lld\n",f[i]); 46 } 47 }