BZOJ 2752 [HAOI2012]高速公路(road):线段树【维护区间内子串和】
题目链接:http://www.lydsy.com/JudgeOnline/problem.php?id=2752
题意:
有一个初始全为0的,长度为n的序列a。
有两种操作:
(1)C l r v: 将[l,r)内的数全部加v。
(2)Q l r: 在[l,r)内随机选两个数x,y(x < y),问你∑(a[x to y])的期望,用最简分数形式输出。
题解:
首先,题中要求的期望 = 区间内所有子串之和 / 区间内子串个数。
如果一个区间的长度为len,显然区间内的子串个数为len*(len+1)/2。
所以题目就变成了怎样维护区间内所有子串之和。
dat表示某个区间的子串和。
假设有两个相邻区间l,r,合并起来的区间叫x。
那么dat[x] = dat[x] + dat[y] + 跨两个区间的子串和
所以接下来考虑如何求跨区间的子串和。
sum表示某个区间的所有元素之和。
ln表示区间l的长度,rn表示区间r的长度。
ls表示某个区间的所有所有前缀之和,rs表示某个区间的所有后缀之和。
则跨区间的子串之和 = rs[l]*rn + ls[r]*ln
即dat[x] = dat[x] + dat[y] + rs[l]*rn + ls[r]*ln
ls,rs和sum的合并就很好求了:
ls[x] = ls[l] + rn*sum[l] + ls[r]
rs[x] = rs[r] + ln*sum[r] + rs[l]
sum[x] = sum[l] + sum[r]
这样线段树的pushup函数就写完了。
然后考虑如何pushdown传标记。
tag表示某个区间被同时加了多少。
现在只考虑当前节点x的某一个儿子y,儿子y的区间长度为len。
首先考虑tag[x]对dat[y]的贡献。
贡献 = 枚举子串的长度 * 这种长度的子串个数 * tag[x]
即:dat[y] += ∑ i*(len-i+1)*tag[x],其中i∈[1,len]。
化简得:dat[y] += ( len*(len+1)/2*(len+1) + ∑(i^2) ) * tag[x]
对于其中的∑(i^2),事先O(n)预处理出来一个平方前缀和数组sqr即可。
然后易得tag[x]对ls,rs,sum的贡献:
ls[y] += len*(len+1)/2*tag[x]
rs[y] += len*(len+1)/2*tag[x]
sum[y] += len*tag[x]
这样pushdown也就写好了。
然后大力线段树即可QAQ……
AC Code:
1 #include <iostream> 2 #include <stdio.h> 3 #include <string.h> 4 #include <algorithm> 5 #define MAX_N 100005 6 #define MAX_T 400005 7 #define int ll 8 9 using namespace std; 10 11 typedef long long ll; 12 13 struct Node 14 { 15 int dt,ls,rs,s,ln; 16 Node(int _dt,int _ls,int _rs,int _s,int _ln) 17 { 18 dt=_dt; ls=_ls; rs=_rs; s=_s; ln=_ln; 19 } 20 Node(){} 21 friend Node mix(const Node &a,const Node &b) 22 { 23 int _dt=a.dt+b.dt+a.rs*b.ln+b.ls*a.ln; 24 int _ls=a.ls+b.ln*a.s+b.ls; 25 int _rs=b.rs+a.ln*b.s+a.rs; 26 int _s=a.s+b.s; 27 int _ln=a.ln+b.ln; 28 return Node(_dt,_ls,_rs,_s,_ln); 29 } 30 }; 31 32 int n,m; 33 int ls[MAX_T]; 34 int rs[MAX_T]; 35 int dat[MAX_T]; 36 int sum[MAX_T]; 37 int tag[MAX_T]; 38 int sqr[MAX_N]; 39 40 void cal_sqr() 41 { 42 for(int i=1;i<=n;i++) sqr[i]=sqr[i-1]+i*i; 43 } 44 45 void push_up(int x,int len) 46 { 47 int l=x*2+1,r=x*2+2; 48 Node L(dat[l],ls[l],rs[l],sum[l],len-(len>>1)); 49 Node R(dat[r],ls[r],rs[r],sum[r],(len>>1)); 50 Node tmp=mix(L,R); 51 dat[x]=tmp.dt; 52 ls[x]=tmp.ls; 53 rs[x]=tmp.rs; 54 sum[x]=tmp.s; 55 } 56 57 void push_down(int x,int len) 58 { 59 if(tag[x]) 60 { 61 int l=x*2+1,r=x*2+2; 62 int ln=(len-(len>>1)),rn=(len>>1); 63 dat[l]+=(ln*(ln+1)/2*(ln+1)-sqr[ln])*tag[x]; 64 dat[r]+=(rn*(rn+1)/2*(rn+1)-sqr[rn])*tag[x]; 65 ls[l]+=ln*(ln+1)/2*tag[x]; 66 ls[r]+=rn*(rn+1)/2*tag[x]; 67 rs[l]+=ln*(ln+1)/2*tag[x]; 68 rs[r]+=rn*(rn+1)/2*tag[x]; 69 sum[l]+=ln*tag[x]; 70 sum[r]+=rn*tag[x]; 71 tag[l]+=tag[x]; 72 tag[r]+=tag[x]; 73 tag[x]=0; 74 } 75 } 76 77 void update(int a,int b,int k,int l,int r,int x) 78 { 79 if(a<=l && r<=b) 80 { 81 int len=r-l+1; 82 tag[k]+=x; 83 sum[k]+=len*x; 84 ls[k]+=len*(len+1)/2*x; 85 rs[k]+=len*(len+1)/2*x; 86 dat[k]+=(len*(len+1)/2*(len+1)-sqr[len])*x; 87 return; 88 } 89 if(r<a || b<l) return; 90 push_down(k,r-l+1); 91 int mid=(l+r)>>1; 92 update(a,b,k*2+1,l,mid,x); 93 update(a,b,k*2+2,mid+1,r,x); 94 push_up(k,r-l+1); 95 } 96 97 Node query(int a,int b,int k,int l,int r) 98 { 99 if(a<=l && r<=b) return Node(dat[k],ls[k],rs[k],sum[k],r-l+1); 100 if(r<a || b<l) return Node(0,0,0,0,0); 101 push_down(k,r-l+1); 102 int mid=(l+r)>>1; 103 Node v1=query(a,b,k*2+1,l,mid); 104 Node v2=query(a,b,k*2+2,mid+1,r); 105 return mix(v1,v2); 106 } 107 108 signed main() 109 { 110 scanf("%lld%lld",&n,&m); 111 n--; 112 cal_sqr(); 113 char opt[16]; 114 int l,r,v; 115 while(m--) 116 { 117 scanf("%s%lld%lld",opt,&l,&r); 118 if(opt[0]=='C') 119 { 120 scanf("%lld",&v); 121 update(l,r-1,0,1,n,v); 122 } 123 else 124 { 125 int dt=query(l,r-1,0,1,n).dt; 126 int len=r-l; 127 int tot=len*(len+1)/2; 128 int g=__gcd(dt,tot); 129 printf("%lld/%lld\n",dt/g,tot/g); 130 } 131 } 132 }