CF1260F Colored Tree
题意:
题解:考虑两个点\(i\)和\(j\),它们对答案造成的贡献是\(dis(i,j) \times (min(r_i,r_j)-max(l_i,l_j)+1) \times \prod_{k \neq i,j}(r_k-l_k+1)\)。
那么我们可以考虑枚举所有颜色,计算包含这个颜色的所有点两两之间的贡献。
设包含当前颜色的点集为\(S\),令\(P=\prod_{i=1}^n (r_i-l_i+1)\),\(g_i=r_i-l_i+1\),那么我们要求的就是:\(\sum_{i,j \in S}dis_{i,j}\times \frac{P}{g_i \times g_j}\)。提出\(P\),将\(dis_{i,j}\)化为我们能求的形式:\(\sum_{i,j\in S}(dep_i+dep_j-2\times dep_{lca(i,j)})\times \frac{1}{g_i\times g_j}\)。我们将括号展开,发现前面两项可以直接求,于是式子变成了:\(\sum_{i\in S}\frac{dep_i}{g_i} \times (\sum_{i\in S}\frac{1}{g_i}) -\sum_{i \in S}\frac{dep_i}{g_i^2}-2\times \sum_{i,j\in S} \frac{dep_{lca(i,j)}}{g_i\times g_j}\)。前面三个\(\sum\)我们都可以直接求,对于最后一个\(\sum\),我们用线段树维护\(\sum \frac{dep_{lca(i,j)}}{g_i}\),只需要对\(i\)到1路径上的节点进行区间加\(\frac{1}{g_i}\)即可。查询时直接查\(j\)到1路径上的权值和即可。注意如果\(dep_1=0\)时,需要减去1号节点的贡献。然后乘个\(\frac{1}{g_j}\)即可。
时间复杂度:\(O(nlog^2n)\)。
代码:
#include<bits/stdc++.h>
using namespace std;
#define re register int
#define F(x,y,z) for(re x=y;x<=z;x++)
#define FOR(x,y,z) for(re x=y;x>=z;x--)
typedef long long ll;
#define I inline void
#define IN inline int
#define STS system("pause")
template<class D>I read(D &res){
res=0;register D g=1;register char ch=getchar();
while(!isdigit(ch)){
if(ch=='-')g=-1;
ch=getchar();
}
while(isdigit(ch)){
res=(res<<3)+(res<<1)+(ch^48);
ch=getchar();
}
res*=g;
}
const int Mod=1e9+7,inv2=500000004;
typedef pair<int,int>pii;
vector<int>e[101000];
int n,m,cnt,ans,a[101000],b[101000],g[101000],S,A,B,C,D,posi[101000],top[101000],dep[101000],son[101000],siz[101000],id[101000],fa[101000];
int tr[404000],laz[404000];
pii p[202000];
I add(int &x,int y){(x+=y)>=Mod?x-=Mod:0;}
IN Plus(int x,int y){(x+=y)>=Mod?x-=Mod:0;return x;}
IN Pow(int x,int y=Mod-2){
re res=1;
while(y){
if(y&1)res=(ll)res*x%Mod;
x=(ll)x*x%Mod;
y>>=1;
}
return res;
}
IN count(int x,int y){
return (ll)(y-x+1)*(x+y)%Mod*inv2%Mod;
}
I D_1(int x,int fat,int depth){
dep[x]=depth;son[x]=-1;siz[x]=1;re maxi=-1;fa[x]=fat;
for(auto d:e[x]){
if(d==fat)continue;
D_1(d,x,depth+1);
siz[x]+=siz[d];if(maxi<siz[d])maxi=siz[d],son[x]=d;
}
}
I D_2(int x,int fat,int topi){
top[x]=topi;id[x]=++cnt;posi[cnt]=x;
if(son[x]!=-1)D_2(son[x],x,topi);
for(auto d:e[x]){
if(d==fat||d==son[x])continue;
D_2(d,x,d);
}
}
#define all 1,1,n
#define lt k<<1,l,mid
#define rt k<<1|1,mid+1,r
I Add(int k,int l,int r,int w){
// assert((r-l)==(dep[posi[r]]-dep[posi[l]]));
add(tr[k],(ll)(r-l+1)*w%Mod);
add(laz[k],w);
}
I push_down(int k,int l,int r){
re mid=(l+r)>>1;
Add(k<<1,l,mid,laz[k]);Add(k<<1|1,mid+1,r,laz[k]);laz[k]=0;
}
I modi(int k,int l,int r,int x,int y,int w){
if(x>r||y<l)return;
if(x<=l&&r<=y)return Add(k,l,r,w),void();
if(laz[k])push_down(k,l,r);
re mid=(l+r)>>1;
modi(lt,x,y,w);modi(rt,x,y,w);
tr[k]=Plus(tr[k<<1],tr[k<<1|1]);
}
IN ques(int k,int l,int r,int x,int y){
if(x>r||y<l)return 0;
if(x<=l&&r<=y)return tr[k];
if(laz[k])push_down(k,l,r);
re mid=(l+r)>>1;
return Plus(ques(lt,x,y),ques(rt,x,y));
}
I modify(int x,int sn){
re w=g[x];
while(x){
if(sn==1){
add(D,(ll)ques(all,id[top[x]],id[x])*w%Mod);
if(top[x]==1)add(D,Mod-(ll)w*ques(all,1,1)%Mod);
// cout<<"!"<<id[top[x]]<<" "<<id[x]<<endl;
modi(all,id[top[x]],id[x],w);
}
else{
modi(all,id[top[x]],id[x],Mod-w);
add(D,Mod-(ll)ques(all,id[top[x]],id[x])*w%Mod);
if(top[x]==1)add(D,(ll)w*ques(all,1,1)%Mod);
}
x=fa[top[x]];
}
}
I addin(int x){
// cout<<"A"<<x<<":";
add(A,(ll)dep[x]*g[x]%Mod);add(B,g[x]);add(C,(ll)dep[x]*g[x]%Mod*g[x]%Mod);
modify(x,1);
// cout<<A<<" "<<B<<" "<<C<<" "<<D<<endl;
}
I delet(int x){
// cout<<"B"<<x<<":";
add(A,Mod-(ll)dep[x]*g[x]%Mod);add(B,Mod-g[x]);add(C,Mod-(ll)dep[x]*g[x]%Mod*g[x]%Mod);
modify(x,-1);
// cout<<A<<" "<<B<<" "<<C<<" "<<D<<endl;
}
int main(){
read(n);S=m=1;
F(i,1,n)read(a[i]),read(b[i]),m=max(m,b[i]),g[i]=b[i]-a[i]+1,S=(ll)S*g[i]%Mod;
F(i,1,n)g[i]=Pow(g[i]);
re u,v;
F(i,1,n-1)read(u),read(v),e[u].emplace_back(v),e[v].emplace_back(u);
D_1(1,0,0);D_2(1,0,1);
// F(i,1,n)cout<<id[i]<<" ";cout<<endl;F(i,1,n)cout<<dep[i]<<" ";cout<<endl;
F(i,1,n)p[i]=make_pair(a[i],i),p[i+n]=make_pair(b[i]+1,-i);
sort(p+1,p+1+(n<<1));re now=1;
F(i,1,m){
while(now<=(n<<1)&&p[now].first==i){
if(p[now].second>0)addin(p[now].second);
else delet(-p[now].second);
now++;
}
add(ans,(ll)S*Plus((ll)A*B%Mod,Mod-Plus(C,Plus(D,D)))%Mod);
// cout<<ans<<" "<<A<<" "<<B<<" "<<C<<" "<<D<<endl;
}
printf("%d",ans);
return 0;
}
/*
4
1 1
1 2
1 1
1 2
1 2
1 3
3 4
*/