[bzoj 4543] Hotel加强版
题意
给定一棵有n个点的树,有三元组$(a,b,c)$满足a,b,c两两距离相等,求这样的三元组的个数。
n<=100000
题解
初步思考
这种题目有个经典解法就是固定两点求中点。但很遗憾这种做法是$O(n^2)$的。
所以我们用树形dp解决这个问题。
考虑从三点lca处统计贡献,这样比较方便。
状态设计
我们考虑逐步满足法,先统计一个点,再用一个点的数据计算两个点都满足的数据,再用一个点和两个点的数据计算三个点都满足的数据。
于是,我们用$f[i][j]$表示在i的子树内与i距离为j的点的个数,
用$g[i][j]$表示在i的子树内有两个点,且他们再加上一个i子树外的与i距离为j的点就满足题目要求。
之所以这样设置是为了方便将g和f拼在一起,且不重不漏地统计答案。
转移方程
那么怎样转移计算g[i][j]的值呢?
因为第二项是j,由定义得现在还需要一条长度为j的边。
考虑一个个枚举i的儿子,统计贡献。
第一种情况,这两条边,一条在已经枚举过的儿子里找,一条在新加进来的找。
那么i就是这两条边的交点,可得两条边长度为i.
下图中,红色表示现在连接的边,黄色表示需要的边,绿色表示将要配对成一个三元组的边。
这种情况可以用$f[i][j]*f[son][j-1]$描述,其中son表示新加进来的儿子。
另外一种情况,两条边都在son里找。
那么就是在son原有的二元组的基础上再加上i--son这一条,因为现在还需要一条长度为j的边,那么在连接之前son则需要一条长度为j+1的边。
这部分可以用$g[son][j+1]$描述
那么,整个g的方程就呼之欲出了
$g[i][j]+=g[son][j+1]+f[i][j]*f[son][j-1]$
f的方程也很简单
$f[id][j]+=f[to][j-1]$
答案的统计也可以列出来了:
$ans+=g[to][j+1]*f[id][j]+g[id][j]*f[to][j-1]$
注意下图的情况要单独计算
可以用$ans+=g[id][0]$来统计。
注意先转移g,再f,再ans。
优化转移
但是,这样仍然是$O(n^2)$
注意到第一个儿子对i的贡献是
$f[x][i]+=f[u][i-1]$
$ g[x][i]+=g[u][i+1]$
也就是说第一个儿子可以$O(1)$转移。
那么我们可以长链剖分,把最深的儿子放在第一位(每个儿子转移复杂度为他的链的长度)
这样,每条链只会在顶部被计算一次,则复杂度为$O(n)$
剩下的就是空间复杂度的问题,我们可以像重链剖分一样把每条链的节点放在一起,因为f和g数组是直接在深儿子的数组上位移而来。所以每条链可以共用一段空间,只不过每个点的起始位不同。对于
存储如下
这样,空间也是$O(n)$的。这样就可以开始敲代码了。
代码
#include <iostream> #include <cstdio> #include <vector> using namespace std; #define N 1000001 #define int long long vector<int> vec[N]; int maxd[N],dson[N],dep[N]; int g[N*5],f[N*5],dfn[N],cnt,top[N],pos[N],cnt2; void get_maxd(int id,int from) { maxd[id]=1; dep[id]=dep[from]+1; for(int i=0;i<vec[id].size();i++) { int to=vec[id][i]; if(to==from) continue; get_maxd(to,id); if(maxd[to]+1>maxd[id]) dson[id]=to; maxd[id]=max(maxd[to]+1,maxd[id]); } } void get_dfn(int id,int from,int root) { dfn[id]=++cnt; top[id]=root; if(dson[id]) get_dfn(dson[id],id,root); for(int i=0;i<vec[id].size();i++) { int to=vec[id][i]; if(to==from||to==dson[id]) continue; pos[to]=cnt2+maxd[to]; cnt2+=maxd[to]*2; get_dfn(to,id,to); } } int ans; #define f(i,j) f[dfn[i]+j] #define g(i,j) g[pos[top[i]]-dep[i]+dep[top[i]]+j] void solve(int id,int from) { f(id,0)=1; int tot=0; if(dson[id]) solve(dson[id],id); for(int i=0;i<vec[id].size();i++) { int to=vec[id][i]; if(to==from||to==dson[id]) continue; solve(to,id); g(id,0)+=g(to,1); for(int j=1;j<=maxd[to];j++) { tot+=(j<maxd[to])*g(to,j+1)*f(id,j)+g(id,j)*f(to,j-1); g(id,j)+=(j<maxd[to])*g(to,j+1)+f(id,j)*f(to,j-1); f(id,j)+=f(to,j-1); } } tot+=g(id,0); //cout<<id<<" find: "<<tot<<endl; ans+=tot; } signed main() { int n; //freopen("data.txt","r",stdin); cin>>n; for(int i=1;i<n;i++) { int a,b; scanf("%lld%lld",&a,&b); vec[a].push_back(b); vec[b].push_back(a); } get_maxd(1,0); pos[1]=maxd[1]; cnt2=maxd[1]*2; get_dfn(1,0,1); solve(1,0); cout<<ans; }