【牛客】路径计数机 (树形dp 前缀和)

题目描述

  有一棵n个点的树和两个整数p, q,求满足以下条件的四元组(a, b, c, d)的个数:
  1.$1\leq a,b,c,d \leq n$
  2.点a到点b的经过的边数为p。
  3.点c到点d的经过的边数为q。
  4.不存在一个点,它既在点a到点b的路径上,又在点c到点d的路径上。

输入描述

  第一行三个整数n,p,q。
  接下来n - 1行,每行两个整数u, v,表示树上存在一个连接点u和点v的边。

输出描述

  输出一个整数,表示答案。
  示例1
  输入
  5 2 1
  1 2
  2 3
  3 4
  2 5
输出
  4
说明
  合法的四元组一共有:
  (1, 5, 3, 4),
  (1, 5, 4, 3),
  (5, 1, 3 ,4),
  (5, 1, 4, 3)。
示例2
  输入
  4 1 1
  1 2
  2 3
  3 4
输出
  8
备注:
  对于前20%的数据,n,p,q≤50。
  对于前40%的数据,n,p,q≤200。
  对于另外10%的数据,p = 2, q = 2。
  对于另外10%的数据,树是一条链。
  对于另外10%的数据,树随机生成。
  对于所有数据1≤n,p,q≤3000,1≤u,v≤n,保证给出的是一棵合法的树。

分析

  我已经弱到连$n^2$枚举路径都不会了

  再一次求助Master_Yi

  这个题只要理顺了就挺好想的了(说得好像我想得出来似的。

  由于不相交的情况不好求,所直接看相交的情况。

  找规律可以发现,如果两条路径相交,其中必有一条路径两个端点的lca在另一条路径上

  所有可以枚举长度为p的路径,减去在以这条路径上的点为端点lca的长度为q的路径

  然后又枚举长度为q的路径,减去在以这条路径上的点为端点lca的长度为p的路径

  发现当路径端点lca相同的情况被多算了一次,于是就加回来

  那么如何实现呢?

  设sq[x],sp[x]分别表示以x为端点lca,长度为q和长度为p的路径条数

  枚举路径是$n^2$的,如果不能优化的话,那么我们现在需要的是快速求出一条路径上的sq或sp和

  现在要求的是一条路径的和,一个一个找点肯定会T,所以可以预处理一些东西能让我们能够拼凑出答案

  如果预处理从根到某个节点x上的路径的sq之和与sp之和,记为ssq[x]与spp[x]

  那么以i,j为两端点的路径中sq和sp之和就为ssq[i]+ssq[j]-ssq[lca[i][j]]-ssq[fa[lca[i][j]]]与ssp[i]+ssp[j]-ssp[lca[i][j]]-ssp[fa[lca[i][j]]]

  感觉有些与前缀和类似。。。。。。

  这样就可以O(1)计算了,总的时间复杂度就为O(n^2)

  跟Master_Yi几乎一样的Code

#include<cstdio>
#include<algorithm>
using namespace std;
const int maxn=3005;
int fa[maxn],dep[maxn],ori[maxn],f[maxn][maxn],lca[maxn][maxn];
int n,p,q,ecnt,Sp,Sq,v[maxn<<1],nx[maxn<<1],sq[maxn],sp[maxn],vis[maxn],info[maxn];
int find(int x){return !ori[x]?x:ori[x]=find(ori[x]);}
void add(int u1,int v1){nx[++ecnt]=info[u1];info[u1]=ecnt;v[ecnt]=v1;}
void dfs1(int x,int fa)
{
    f[x][0]=1;vis[x]=1;for(int i=1;i<=n;i++)if(vis[i])lca[x][i]=lca[i][x]=find(i);
    for(int e=info[x];e;e=nx[e])if(v[e]!=fa)
    {
        dfs1(v[e],x);
        for(int i=0;i<q;i++)sq[x]+=f[x][i]*f[v[e]][q-i-1];
        for(int i=0;i<p;i++)sp[x]+=f[x][i]*f[v[e]][p-i-1];
        for(int i=0;i<max(p,q);i++)f[x][i+1]+=f[v[e]][i];
    }
    ori[x]=fa;Sq+=sq[x];Sp+=sp[x];
}
void dfs2(int x,int f){sq[x]+=sq[f];sp[x]+=sp[f];dep[x]=dep[fa[x]=f]+1;for(int e=info[x];e;e=nx[e])if(v[e]!=f)dfs2(v[e],x);}
int main()
{
    scanf("%d%d%d",&n,&p,&q);
    for(int i=1,u1,v1;i<n;i++)scanf("%d%d",&u1,&v1),add(u1,v1),add(v1,u1);
    dfs1(1,0);dfs2(1,0);long long ans=p!=q?1ll*Sp*Sq:1ll*Sp*(Sq-1)/2;
    for(int i=1;i<=n;i++)for(int j=i+1;j<=n;j++)
    {
        int len=dep[i]+dep[j]-2*dep[lca[i][j]];
        if(len==p&&len==q){ans-=sq[i]+sq[j]-sq[lca[i][j]]-sq[fa[lca[i][j]]]-1;continue;}
        if(len==p)ans-=sq[i]+sq[j]-sq[lca[i][j]]-sq[fa[lca[i][j]]];
        if(len==q)ans-=sp[i]+sp[j]-sp[lca[i][j]]-sp[fa[lca[i][j]]];
    }
    for(int i=1;i<=n;i++)
    if(p==q)ans+=1ll*(sp[i]-sp[fa[i]])*(sq[i]-sq[fa[i]]-1)/2;
    else ans+=1ll*(sp[i]-sp[fa[i]])*(sq[i]-sq[fa[i]]);
    printf("%lld\n",p==q?ans<<3:ans<<2);
}

 

 

 

posted @ 2019-11-06 10:57  散樗  阅读(468)  评论(0编辑  收藏  举报