树上的等差数列 [树形dp]

树上的等差数列

题目描述

给定一棵包含 \(N\) 个节点的无根树,节点编号 \(1\to N\) 。其中每个节点都具有一个权值,第 \(i\) 个节点的权值是 \(A_i\)

\(Hi\) 希望你能找到树上的一条最长路径,满足沿着路径经过的节点的权值序列恰好构成等差数列。

输入格式

第一行包含一个整数 \(N\)

第二行包含 \(N\) 个整数 \(A_1, A_2, ... A_N\)

以下 \(N-1\) 行,每行包含两个整数 \(U\)\(V\) ,代表节点 \(U\)\(V\) 之间有一条边相连。

输出格式

最长等差数列路径的长度

样例

样例输入

7  
3 2 4 5 6 7 5  
1 2  
1 3  
2 7  
3 4  
3 5  
3 6

样例输出

4

数据范围与提示

对于 \(50\%\) 的数据,\(1 \leqslant N \leqslant 1000\)

对于 \(100\%\) 的数据,\(1 \leqslant N \leqslant 100000, 0 \leqslant A_i \leqslant 100000, 1 \leqslant U, V \leqslant N\)

分析

树形 \(dp\) 好题。

因为要求的是最长的等差序列,根节点不同,答案也可能不同,所以 \(dp\) 的状态转移就定义为 \(f[i][j]\) 表示 \(i\) 节点为根,公差为 \(j\) 时的最长的等差数列,不包括自己。那么我们就可以愉快的 \(dfs\) 来进行转移了。

我们记录一下他自己和他的父亲,避免出现死循环,每一次先 \(dfs\) 到儿子,递归上来,然后就处理出来了公差为 \(\Delta\) 的以儿子为根的所有长度,这时候我们只需要判断一下此时的 \(\Delta\) 值是否为 \(0\)。如果是,那么 \(ans\) 的转移应该是:

\[ans = max(ans,f[x][0] + f[son[x]][0] + 2) \]

因为此时 \(f[x][0]\) 存储的是其他儿子上最长链,所以需要加上当前儿子的最长链,因为我们的数组不保存自己,所以要加 \(2\)

其他情况就是直接更新 \(ans\) ,他的答案应该是 \(f[x][d] + f[x][-d] + 1\) ,因为他的父亲那里也可能会有链,公差为 \(-d\) 就是那个链,由于负数下标的问题,我们利用 \(map\) 来存储,然后轻松解决此题。

代码

#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<map>
#define re register
using namespace std;
const int maxn = 1e5+10;
map <int,int> mp[maxn];
struct Node{
	int v,next;
}e[maxn<<1];
int w[maxn];
int ans = 0;
int head[maxn],tot;
void Add(int x,int y){//建边
	e[++tot].v = y;
	e[tot].next = head[x];
	head[x] = tot;
}
inline int read(){//快读
	int s = 0,f = 1;
	char ch = getchar();
	while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
	while(isdigit(ch)){s=s*10+ch-'0';ch=getchar();}
	return s * f;
}
inline void DP(int x,int fa){
	for(int i=head[x];i;i=e[i].next){
		int v = e[i].v;
		if(v == fa)continue;//避免死循环
		int d = w[v] - w[x];//计算公差
		DP(v,x);
		if(!d){//公差为0的情况
			ans = max(ans,mp[x][0] + mp[v][0] + 2);
			mp[x][0] = max(mp[x][0],mp[v][0] + 1);
		}
		else{//公差不为0
			mp[x][d] = max(mp[x][d],mp[v][d] + 1);
			ans = max(ans,mp[x][d] + mp[x][-d] + 1);
		}
	}
}

int main(){
	freopen("C.in","r",stdin);
	freopen("C.out","w",stdout);
	int n =read();
	for(re int i = 1;i<=n;++i){w[i]=read();}
	for(re int i = 1;i< n;++i){
		int x = read(),y = read();
		Add(x,y);
		Add(y,x);
	}
	DP(1,0);
	printf("%d\n",ans);
}
posted @ 2020-08-14 21:43  Vocanda  阅读(229)  评论(0编辑  收藏  举报