题解[P7815自傷無色]
不明白各位 dalao 为什么能推式子这么简洁。
来补一个 \(\text{dsu on tree+splay}\) 的大常数解法。
容易想到对每个点计算其为另外两点 \(\text{LCA}\) 的贡献。
最朴素的方法就是对每个点暴力统计其每个儿子的子树与其之前搜过的儿子的子树中每个点之间的贡献。
但这显然可以通过 \(\text{dsu on tree}\) 优化这个暴力过程。
剩下的几乎只用推式子了。
设当前处理 \(x\) 节点,定义集合 \(S\) 为目前之前搜过子树的点集。
初始时 \(S\) 即是重儿子子树内全部点。
设 \(x\) 的子树内节点 \(i\) 到 \(x\) 的距离为 \(d_i\) ,当前希望与之前计算贡献的点与 \(x\) 的距离为 \(dis\)
再定义两个互不相交的集合 \(A=\left\{i|i\in S,d_i\leq dis\right\}\) \(B=\left\{i|i\in S,d_i>dis\right\}\)
先算 \(A\) 集合与当前点的贡献:
设 \(\left[l_i,r_i\right]\) 为 \(d_i\) 与 \(dis\) 为两边时另一条边的可取范围。
则 \(l_i=dis-d_i+1,r_i=dis+d_i-1\)
则树三角边权和的总贡献为:
\(\sum\limits_{i\in A}^{}\sum\limits_{j=l_i}^{r_i}dis+d_i+j\)
\(=\sum\limits_{i\in A}^{}(r_i-l_i+1)(dis+d_i)+\frac{1}{2}(r_i-l_i+1)(r_i+l_i)\)
\(=\sum\limits_{i\in A}^{}(r_i-l_1+1)(2dis+d_i)\)
\(=\sum\limits_{i\in A}^{}(2d_i-1)(2dis+d_i)\)
\(=(4dis-1)\sum\limits_{i\in A}d_i-2dis\sum\limits_{i\in A}1+2\sum\limits_{i\in A}{d_i}^2\)
树三角数量的总贡献为:
\(\sum\limits_{i\in A}r_i-l_i+1=\sum\limits_{i\in A}2d_i-1=2\sum\limits_{i\in A}d_i-\sum\limits_{i\in A}1\)
而对于 \(B\) 集合来说
\(l_i=d_i-dis+1,r_i=d_i+dis-1\)
树三角边权和总贡献为:
\(\sum\limits_{i\in B}\sum\limits_{j=l_i}^{r_i}dis+d_i+j\)
\(=\sum\limits_{i\in B}^{}(r_i-l_i+1)(dis+d_i)+\frac{1}{2}(r_i-l_i+1)(r_i+l_i)\)
\(=\sum\limits_{i\in B}^{}(r_i-l_1+1)(dis+2d_i)\)
\(=\sum\limits_{i\in B}^{}(2dis-1)(dis+2d_i)\)
\(=2dis^2\sum\limits_{i\in B}1+(2dis-1)\sum\limits_{i\in B}d_i\)
树三角数量的总贡献为:
\(\sum\limits_{i\in B}r_i-l_i+1=\sum\limits_{i\in B}2dis-1=(2dis-1)\sum\limits_{i\in B}1\)
所以只需能确定 \(A,B\) 集合各自的大小、和、平方和就可以顺利计算了。
又由于边权 \(\leq10^9\) 点到 \(x\) 的距离 \(\leq 10^{14}\) 爆不了 \(\text{long long}\)
所以写一个平衡树便可方便维护。
而为了确定 \(A,B\) 集合只需知道 \(x\) 子树内点到 \(x\) 相对大小关系,
这与这些点到根的相对大小关系完全相同,这很方便预处理出来。
但还有一个问题:
计算时为了利用重儿子的信息,而重儿子中的距离并不是到 \(x\) 的。
于是还要打个标记进行适当变换:
设之前维护的大小、和、平方和分别为 \(c,s,s2\) ,真正需要的是 \(\acute c,\acute s,\acute{s2}\) ,重儿子到 \(x\) 的边权为 \(v\)
于是有:
\(\begin{cases}\acute c=c\\\acute s=\sum d_i+v=s+vc\\\acute{s2}=\sum(d_i+w)^2=s2+2sv+v^2c=s2+2v\acute s-v^2c\end{cases}\)
而这个转换的标记显然有可加性,于是这题就做完了。
还有一个细节:
为了保证不会让同一儿子的子树内的点又算一次贡献,需要在每个点查询完后统一插入平衡树中。
代码:
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=1e5+10,mod=1e9+7;ll inf=1e15;
int n,m,x,y,v,tot;ll ans,ans0,d[N],q[N],w[N<<1],sw[N];char ch;
int to[N<<1],nextn[N<<1],h[N],edg,son[N],id[N],sz[N];
inline void add(int x,int y,int v){to[++edg]=y;nextn[edg]=h[x];h[x]=edg;w[edg]=v;}
inline void read(int &x){
x=0;ch=getchar();
while(ch<'0'||ch>'9')ch=getchar();
while(ch>='0'&&ch<='9')x=x*10+ch-48,ch=getchar();
}
void write(int x){if(x>=10)write(x/10);putchar(48+x%10);}
ll qpow(ll x,int k){ll res=1;while(k){if(k&1)res=res*x%mod;x=x*x%mod;k>>=1;}return res;}
struct node{
int c;ll s,s2;
node()=default;
node(int _c,ll _s,ll _s2):c(_c),s(_s),s2(_s2){}
void operator +=(const node &tmp){c+=tmp.c,s=(s+tmp.s)%mod,s2=(s2+tmp.s2)%mod;}
void trans(ll v){ll _s=s;s=(v*c+_s)%mod;s2=(v*v%mod*c%mod+2*v*_s%mod+s2)%mod;}
}tmp;
struct Splay{
int son[N][2],anc[N],rt;ll tag[N];node w[N],s[N];
void init(){q[n+1]=0;q[n+2]=inf;rt=n+1;son[n+1][1]=n+2;anc[n+2]=n+1;}
bool p(int x){return son[anc[x]][1]==x;}
void tag_(int x,ll v){
tag[x]=(tag[x]+v)%mod;if(x<=n)w[x].trans(v);s[x].trans(v);
}
void fix(int x){if(!x)return;s[x]=w[x];s[x]+=s[son[x][0]];s[x]+=s[son[x][1]];}
void rotate(int x){
int y=anc[x],xx=anc[y];bool b=p(x),bb=p(y);
anc[x]=xx;if(xx)son[xx][bb]=x;
son[y][b]=son[x][!b];anc[son[x][!b]]=y;
son[x][!b]=y;anc[y]=x;fix(y);fix(x);fix(xx);
}
void pushdown(int x){
if(tag[x]){
if(son[x][0])tag_(son[x][0],tag[x]);
if(son[x][1])tag_(son[x][1],tag[x]);
tag[x]=0;
}
}
void pushall(int x){if(anc[x])pushdown(anc[x]);pushdown(x);}
void splay(int x,int y=0){
if(!y)rt=x;pushall(x);
for(int i;i=anc[x],anc[x]^y;rotate(x))
if(anc[i]^y)rotate(p(x)==p(i)?i:x);
}
void insert(int x,node v){
s[x]=w[x]=v;int y=rt,f=0;
for(;;){
if(!y){anc[x]=f;son[f][q[f]<=q[x]]=x;return splay(x);}
pushdown(y);f=y;y=son[y][q[y]<=q[x]];
}
}
int pre(int x){
splay(x);
x=son[x][0];
while(son[x][1])x=son[x][1];
return splay(x),x;
}
void del(int x){
int y=pre(x);splay(x,y);
anc[son[y][1]=son[x][1]]=y,fix(y);
son[x][0]=son[x][1]=anc[x]=0;
}
void inquiry(int x,ll v,node vv){
insert(x,vv);v%=mod;
tmp=s[son[x][0]];
ans=(ans+(4*v-1)%mod*tmp.s%mod-2*v*tmp.c%mod+2*tmp.s2%mod)%mod;
ans0=(ans0+2*tmp.s-tmp.c)%mod;
tmp=s[son[x][1]];
ans=(ans+(4*v-2)%mod*tmp.s%mod+(2*v*v%mod-v)*tmp.c%mod)%mod;
ans0=(ans0+(2*v-1)*tmp.c%mod)%mod;del(x);
}
}S;
void init(int x,int anc){
int i,y;sz[x]=1;
for(i=h[x];y=to[i],i;i=nextn[i])if(y^anc){
q[y]=q[x]+w[i];init(y,x);sz[x]+=sz[y];
if(sz[y]>sz[son[x]])son[x]=y,sw[x]=w[i];
}
}
void dfs(int x,int anc){
int i,y;id[++tot]=x;
for(i=h[x];y=to[i],i;i=nextn[i])if(y^anc)d[y]=(d[x]+w[i])%mod,dfs(y,x);
}
void clear(int x,int anc){
int i,y;S.del(x);
for(i=h[x];i;i=nextn[i])if(to[i]^anc)clear(to[i],x);
}
void solve(int x,int anc){
int i,j,y;
for(i=h[x];y=to[i],i;i=nextn[i])if(y^anc&&y^son[x])solve(y,x),clear(y,x);
if(son[x])solve(son[x],x),S.tag_(S.rt,sw[x]);
for(i=h[x];y=to[i],i;i=nextn[i])if(y^anc&&y^son[x]){
tot=0;d[y]=w[i];dfs(y,x);
for(j=1;y=id[j],j<=tot;++j)
S.inquiry(y,d[y],node(1,d[y],d[y]*d[y]%mod));
for(j=1;y=id[j],j<=tot;++j)
S.insert(y,node(1,d[y],d[y]*d[y]%mod));
}
S.insert(x,node(1,0,0));
}
main(){
read(n);register int i;S.init();
for(i=1;i^n;++i)read(x),read(y),read(v),add(x,y,v),add(y,x,v);
init(1,0);solve(1,0);ans=(ans+mod)%mod;ans0=(ans0+mod)%mod;
ans0=qpow(ans0,mod-2);write(ans*ans0%mod);
}