bzoj 3091: 城市旅行 LCT
题目:
题解:
首先前三个操作就是裸的LCT模板
只考虑第四个操作.
要求我们计算期望,所以我们考虑计算出所有情况的和然后在除以情况的数目.
这样我们就找到分子分母了.
我们很容易发现分母即为\(\frac{n*(n+1)}{2}\)
对应到我们的Splay树上即\(\frac{siz*(siz+1)}{2}\)
所以我们现在考虑维护分子:
对于首先我们考虑在一个长为n的序列上统计这些东西
我们知道总和即为每一项乘以这一项出现的次数(又在废话)
出现的次数又是多少呢?
所以其实对于每一个元素,出现的次数都是\((\text{左边的}siz+1)*(\text{右边的}siz+1)\)
那么我们考虑合并:
假设这个区间作为合并的左区间,我们设\(w = \text{右区间的}siz+1\)(即这个区间合并后右侧新出现的节点数)
那么按照刚才的思路,所有的数字的后一项都会同时增大即变为:
于是我们发现实际上这段区间的贡献增加了:
所以我们记录一个和表示\(1*a_1 + 2*a_2 + 3*a_3 + ... + n*a_n\)即可
利用这个我们就可以维护分子了。啥?? 怎么维护 ??
\(val = ch[0]->val + ch[1]->val + ch[0]->lsum*(ch[1]->siz + 1) + ch[1]->rsum*(ch[0]->siz + 1) + w*(ch[0]->siz + 1)*(ch[1]->siz + 1);\)
其中\(w\)为节点本身的权,\(lsum = \sum_{i=1}^{n}a_i*i\),\(rsum = \sum_{i=1}^{n}a_i*(n-i+1)\)
至于维护\(lsum\)和\(rsum\)的过程.
我们有
据说是小学数学难度.
夭折啊 !我想不出来 !
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef unsigned long long ll;
inline void read(ll &x){
x=0;char ch;bool flag = false;
while(ch=getchar(),ch<'!');if(ch == '-') ch=getchar(),flag = true;
while(x=10*x+ch-'0',ch=getchar(),ch>'!');if(flag) x=-x;
}
const ll maxn = 50010;
struct Node{
Node *ch[2],*fa;
ll w,lsum,rsum,lazy;
ll val,siz,tag,sum;
void update();
void pushdown();
void rev();
void inc(ll x);
}*null;
void Node::rev(){
if(this == null) return;
swap(lsum,rsum);swap(ch[0],ch[1]);
tag ^= 1;
}
void Node::inc(ll x){
if(this == null) return;
w += x;sum += x*siz;
lsum += x*siz*(siz+1)/2;
rsum += x*siz*(siz+1)/2;
val += x*siz*(siz+1)*(siz+2)/6;
lazy += x;
}
void Node::pushdown(){
if(this == null) return;
if(lazy){
if(ch[0] != null) ch[0]->inc(lazy);
if(ch[1] != null) ch[1]->inc(lazy);
lazy = 0;
}
if(tag){
if(ch[0] != null) ch[0]->rev();
if(ch[1] != null) ch[1]->rev();
tag = 0;
}
}
void Node::update(){
if(this == null) return;
siz = ch[0]->siz + ch[1]->siz + 1;
sum = ch[0]->sum + ch[1]->sum + w;
lsum = ch[0]->lsum + w*(ch[0]->siz + 1) + ch[1]->lsum + ch[1]->sum*(ch[0]->siz + 1);
rsum = ch[1]->rsum + w*(ch[1]->siz + 1) + ch[0]->rsum + ch[0]->sum*(ch[1]->siz + 1);
val = ch[0]->val + ch[1]->val + ch[0]->lsum*(ch[1]->siz + 1) + ch[1]->rsum*(ch[0]->siz + 1) + w*(ch[0]->siz + 1)*(ch[1]->siz + 1);
}
Node mem[maxn],*it;
inline void init(){
it = mem;null = it++;
null->ch[0] = null->ch[1] = null->fa = null;
null->w = null->lsum = null->rsum = null->sum =
null->val = null->siz = null->tag = 0;
}
inline Node* newNode(ll x){
Node *p = it++;p->ch[0] = p->ch[1] = p->fa = null;
p->w = p->val = p->lsum = p->rsum = p->sum = x;p->siz = 1;
p->tag = p->lazy = 0;
return p;
}
inline void rotate(Node *p,Node *x){
ll k = p == x->ch[1];
Node *y = p->ch[k^1],*z = x->fa;
if(z->ch[0] == x) z->ch[0] = p;
if(z->ch[1] == x) z->ch[1] = p;
if(y != null) y->fa = x;
p->fa = z;p->ch[k^1] = x;
x->fa = p;x->ch[k] = y;
x->update();p->update();
}
inline bool isRoot(Node *p){
return (p == null) || (p->fa->ch[0] != p && p->fa->ch[1] != p);
}
inline void Splay(Node *p){
p->pushdown();
while(!isRoot(p)){
Node *x = p->fa,*y = x->fa;
y->pushdown();x->pushdown();p->pushdown();
if(isRoot(x)) rotate(p,x);
else if((p == x->ch[0])^(x == y->ch[0])) rotate(p,x),rotate(p,y);
else rotate(x,y),rotate(p,x);
}p->update();
}
inline Node* Access(Node *x){
for(Node *y = null;x != null;y = x,x = x->fa)
Splay(x),x->ch[1] = y,x->update();
return x;
}
inline void makeRoot(Node *x){
Access(x);Splay(x);x->rev();
}
inline void link(Node *x,Node *y){
makeRoot(x);x->fa = y;
}
inline void cut(Node *x,Node *y){
makeRoot(x);Access(y);Splay(y);
if(y->ch[0] == x && x->ch[1] == null){
y->ch[0] = y->ch[0]->fa = null;
y->update();
}
}
inline void inc(Node *x,Node *y,ll w){
makeRoot(x);Access(y);Splay(y);
y->inc(w);
}
inline ll gcd(const ll &a,const ll &b){return b == 0 ? a : gcd(b,a%b);}
inline ll query(Node *x,Node *y){
makeRoot(x);Access(y);Splay(y);
ll upside = y->val;
ll dnside = y->siz*(y->siz + 1)/2;
ll g = gcd(upside,dnside);
printf("%llu/%llu\n",upside/g,dnside/g);
}
inline Node* findRoot(Node *x){
Access(x);Splay(x);
while(x->ch[0] != null) x = x->ch[0];
Splay(x);return x;
}
int main(){
init();
ll n,m;read(n);read(m);
for(ll i=1,x;i<=n;++i){
read(x);newNode(x);
}
ll u,v;
for(ll i=1;i<n;++i){
read(u);read(v);
link(mem+u,mem+v);
}
ll op;
while(m--){
read(op);
if(op == 1){
read(u);read(v);
if(u != v && findRoot(mem+u) == findRoot(mem+v)) cut(mem+u,mem+v);
}else if(op == 2){
read(u);read(v);
if(findRoot(mem+u) != findRoot(mem+v)) link(mem+u,mem+v);
}else if(op == 3){
read(u);read(v);read(op);
if(findRoot(mem+u) == findRoot(mem+v)) inc(mem+u,mem+v,op);
}else if(op == 4){
read(u);read(v);
if(findRoot(mem+u) == findRoot(mem+v)) query(mem+u,mem+v);
else puts("-1");
}
}
getchar();getchar();
return 0;
}