【牛客】路径计数机 (树形dp 前缀和)
题目描述
有一棵n个点的树和两个整数p, q,求满足以下条件的四元组(a, b, c, d)的个数:
1.$1\leq a,b,c,d \leq n$
2.点a到点b的经过的边数为p。
3.点c到点d的经过的边数为q。
4.不存在一个点,它既在点a到点b的路径上,又在点c到点d的路径上。
输入描述
第一行三个整数n,p,q。
接下来n - 1行,每行两个整数u, v,表示树上存在一个连接点u和点v的边。
输出描述
输出一个整数,表示答案。
示例1
输入
5 2 1
1 2
2 3
3 4
2 5
输出
4
说明
合法的四元组一共有:
(1, 5, 3, 4),
(1, 5, 4, 3),
(5, 1, 3 ,4),
(5, 1, 4, 3)。
示例2
输入
4 1 1
1 2
2 3
3 4
输出
8
备注:
对于前20%的数据,n,p,q≤50。
对于前40%的数据,n,p,q≤200。
对于另外10%的数据,p = 2, q = 2。
对于另外10%的数据,树是一条链。
对于另外10%的数据,树随机生成。
对于所有数据1≤n,p,q≤3000,1≤u,v≤n,保证给出的是一棵合法的树。
分析
我已经弱到连$n^2$枚举路径都不会了
再一次求助Master_Yi
这个题只要理顺了就挺好想的了(说得好像我想得出来似的。
由于不相交的情况不好求,所直接看相交的情况。
找规律可以发现,如果两条路径相交,其中必有一条路径两个端点的lca在另一条路径上
所有可以枚举长度为p的路径,减去在以这条路径上的点为端点lca的长度为q的路径
然后又枚举长度为q的路径,减去在以这条路径上的点为端点lca的长度为p的路径
发现当路径端点lca相同的情况被多算了一次,于是就加回来
那么如何实现呢?
设sq[x],sp[x]分别表示以x为端点lca,长度为q和长度为p的路径条数
枚举路径是$n^2$的,如果不能优化的话,那么我们现在需要的是快速求出一条路径上的sq或sp和
现在要求的是一条路径的和,一个一个找点肯定会T,所以可以预处理一些东西能让我们能够拼凑出答案
如果预处理从根到某个节点x上的路径的sq之和与sp之和,记为ssq[x]与spp[x]
那么以i,j为两端点的路径中sq和sp之和就为ssq[i]+ssq[j]-ssq[lca[i][j]]-ssq[fa[lca[i][j]]]与ssp[i]+ssp[j]-ssp[lca[i][j]]-ssp[fa[lca[i][j]]]
感觉有些与前缀和类似。。。。。。
这样就可以O(1)计算了,总的时间复杂度就为O(n^2)
跟Master_Yi几乎一样的Code
#include<cstdio>
#include<algorithm>
using namespace std;
const int maxn=3005;
int fa[maxn],dep[maxn],ori[maxn],f[maxn][maxn],lca[maxn][maxn];
int n,p,q,ecnt,Sp,Sq,v[maxn<<1],nx[maxn<<1],sq[maxn],sp[maxn],vis[maxn],info[maxn];
int find(int x){return !ori[x]?x:ori[x]=find(ori[x]);}
void add(int u1,int v1){nx[++ecnt]=info[u1];info[u1]=ecnt;v[ecnt]=v1;}
void dfs1(int x,int fa)
{
f[x][0]=1;vis[x]=1;for(int i=1;i<=n;i++)if(vis[i])lca[x][i]=lca[i][x]=find(i);
for(int e=info[x];e;e=nx[e])if(v[e]!=fa)
{
dfs1(v[e],x);
for(int i=0;i<q;i++)sq[x]+=f[x][i]*f[v[e]][q-i-1];
for(int i=0;i<p;i++)sp[x]+=f[x][i]*f[v[e]][p-i-1];
for(int i=0;i<max(p,q);i++)f[x][i+1]+=f[v[e]][i];
}
ori[x]=fa;Sq+=sq[x];Sp+=sp[x];
}
void dfs2(int x,int f){sq[x]+=sq[f];sp[x]+=sp[f];dep[x]=dep[fa[x]=f]+1;for(int e=info[x];e;e=nx[e])if(v[e]!=f)dfs2(v[e],x);}
int main()
{
scanf("%d%d%d",&n,&p,&q);
for(int i=1,u1,v1;i<n;i++)scanf("%d%d",&u1,&v1),add(u1,v1),add(v1,u1);
dfs1(1,0);dfs2(1,0);long long ans=p!=q?1ll*Sp*Sq:1ll*Sp*(Sq-1)/2;
for(int i=1;i<=n;i++)for(int j=i+1;j<=n;j++)
{
int len=dep[i]+dep[j]-2*dep[lca[i][j]];
if(len==p&&len==q){ans-=sq[i]+sq[j]-sq[lca[i][j]]-sq[fa[lca[i][j]]]-1;continue;}
if(len==p)ans-=sq[i]+sq[j]-sq[lca[i][j]]-sq[fa[lca[i][j]]];
if(len==q)ans-=sp[i]+sp[j]-sp[lca[i][j]]-sp[fa[lca[i][j]]];
}
for(int i=1;i<=n;i++)
if(p==q)ans+=1ll*(sp[i]-sp[fa[i]])*(sq[i]-sq[fa[i]]-1)/2;
else ans+=1ll*(sp[i]-sp[fa[i]])*(sq[i]-sq[fa[i]]);
printf("%lld\n",p==q?ans<<3:ans<<2);
}