线段树学习笔记
(未完待续)
推荐参考: notonlysuccess神犇的线段树总结
http://blog.csdn.net/kzzhr/article/details/10813301
(1)单点更新
HDU 1166 区间和
HDU 1754 区间最值
HDU 1394 区间和
HDU 2795 区间最值
常用模板:
Struct treetype { int l,int r,int dat; }t[3000]; //如果要对长度为n的区间建线段树,那么t数组至少要打到3*n
建树:
void build(int l,int r,int o) { if (o>num) num=o; //num:计数用,记录线段树上已有多少节点 t[o].l=l; t[o].r=r; if (l==r) t[o].dat=a[l]; //更新只包含一个点的区间(即树上的叶子节点) else { int mid=(l+r)/2; build(l,mid,2*o); build(mid+1,r,2*o+1); t[o].dat=max/min/sum(t[2*o].dat,t[2*o+1].dat); //根据要求而定 } }
更新节点:
- 对于求区间和的问题:
更新:令a[x]=m;
void update(int l,int r,int o) //a[x]=m; { t[o].l=l; t[o].r=r; if ((l==x)&&(r==x)) //找到点节点:更新点 t[o].dat=m; else { int mid=(l+r)/2; if (x<=mid) //否则:对左or右区间进行处理 { int tmp=t[2*o].dat; update(l,mid,2*o); t[o].dat+=t[2*o].dat-tmp; } else { int tmp=t[2*o+1].dat; update(mid+1,r,2*o+1); t[o].dat+=t[2*o+1].dat-tmp; } } }
2. 对于求区间最值的问题:
更新:令a[tx]=ty;
void update(int l,int r,int o) { if (o>num) return; //防止出现o不停增加的死循环 (其实加不加也无所谓-_-||) if ((l==tx)&&(r==tx)) //找到点节点:更新点 t[o].dat=ty; else { int mid=(l+r)/2; if (tx<=mid) //否则:对左or右区间进行处理 { update(l,mid,2*o); t[o].dat=max(t[o].dat,t[2*o].dat); } else { update(mid+1,r,2*o+1); t[o].dat=max(t[o].dat,t[2*o+1].dat); } } }
求区间上的某种性质:
- 求区间和
int query_sum(int l,int r,int o) { if (l>r) //不可行的情况 return 0; else if (l==r) return a[l]; else { int tl=t[o].l,tr=t[o].r; if ((l==tl)&&(r==tr)) //正好命中线段树上的一整块区间,直接返回值即可 return t[o].dat; else { int mid=(tl+tr)/2; //否则拆开:l...mid和mid+1...r if (r<=mid) //注意红字部分! return (query_sum(l,r,2*o)); else if (l>mid) return (query_sum(l,r,2*o+1)); else return (query_sum(l,mid,2*o)+query_sum(mid+1,r,2*o+1)); } } }
2.求区间最值(以最大值为例):
int query_max(int l,int r,int o) { if (o>num) return 0; //不可行的情况 int tl=t[o].l,tr=t[o].r; if ((l==tl)&&(r==tr)) //正好命中线段树上的一整块区间,直接返回值即可 return t[o].dat; else { int mid=(tl+tr)/2; //否则拆开:l...mid和mid+1...r if (r<=mid) return (query_max(l,r,2*o)); else if (l>mid) return (query_max(l,r,2*o+1)); else if ((l<=mid)&&(mid<=r)) { int tm1=query_max(l,mid,2*o); int tm2=query_max(mid+1,r,2*o+1); return max(tm1,tm2); } } }
主程序:
cin>>n; for i:=1 to n do cin>>a[i]; build(1,n,1); //建树 for i:=1 to m do { cin>>tx>>ty; a[tx]=ty; update(1,n,1); //单点更新 } cin>>ml>>mr; ans1=query_sum(ml,mr,1); //求区间和 ans2=query_max(ml,mr,1); //求区间最值
备注:模板中的除法操作还可以用位运算优化:int mid=(l+r)>>1;
附录:网上的一段模板(hdu 1166)
1 #include<stdio.h> 2 #define MAX 50000+10 3 int head[MAX]; 4 struct node 5 { 6 int l,r; 7 int value; 8 }tree[3*MAX]; 9 10 void build(int l,int r,int v) //建树 11 { 12 tree[v].l=l; 13 tree[v].r=r; 14 if(l==r) 15 { 16 tree[v].value=head[l]; 17 return ; 18 } 19 int mid=(l+r)>>1; 20 build(l,mid,v*2); 21 build(mid+1,r,v*2+1); 22 tree[v].value=tree[v+v].value+tree[v+v+1].value; 23 } 24 25 int queue(int a,int b,int v) //求和 26 { 27 if(tree[v].l==a&&tree[v].r==b) 28 { 29 return tree[v].value; 30 } 31 int mid=(tree[v].l+tree[v].r)>>1; 32 if(b<=mid) return queue(a,b,v+v); 33 else if(a>mid) return queue(a,b,v+v+1); 34 else return queue(a,mid,v*2)+queue(mid+1,b,v*2+1); 35 } 36 37 void update(int a,int b,int v) //更新: head[a]+=b; 38 { 39 tree[v].value+=b; 40 if(tree[v].l==tree[v].r) 41 { 42 return ; 43 } 44 int mid=(tree[v].l+tree[v].r)>>1; 45 if(a<=mid) update(a,b,v*2); 46 else update(a,b,v*2+1); 47 } 48 49 int main() 50 { 51 int T,N; 52 char str[10]; 53 int a,b; 54 while(scanf("%d",&T)>0) 55 { 56 for(int h=0;h<T;h++) 57 { 58 scanf("%d",&N); 59 for(int i=1;i<=N;i++) 60 { 61 scanf("%d",&head[i]); 62 } 63 build(1,N,1); 64 printf("Case %d:\n",h+1); 65 while(scanf("%s",str)>0) 66 { 67 if(str[0]=='A') 68 { 69 scanf("%d %d",&a,&b); 70 update(a,b,1); 71 } 72 else if(str[0]=='S') 73 { 74 scanf("%d %d",&a,&b); 75 update(a,-b,1); 76 } 77 else if(str[0]=='Q') 78 { 79 scanf("%d %d",&a,&b); 80 int ans=queue(a,b,1); 81 printf("%d\n",ans); 82 } 83 else 84 { 85 break; 86 } 87 } 88 89 } 90 } 91 }
---------------------------------------------------
2.成段更新
POJ3468 更新+求和
HDU1698 模板题-_-||
AHOI 2009 变形:加法+乘法混合
POJ 2528 +离散化
struct { int l,r; __int64 dat; }t[300010]; char ch; __int64 ml,mr,md; __int64 add[300010]; 建树: void build(long l,long r,long long o) { if (o>num) num=o; t[o].l=l; t[o].r=r; if (l==r) t[o].dat=a[l]; else { long mid=(l+r)/2; build(l,mid,2*o); build(mid+1,r,2*o+1); t[o].dat=t[2*o].dat+t[2*o+1].dat; } } 更新:cin>>ml>>mr>>md; update(ml,mr,1); ->令a[ml..mr]+=md void update(long l,long r,long long o) { if (o>num) return; long tl=t[o].l,tr=t[o].r; if ((tl==l)&&(tr==r)) { t[o].dat+=(tr-tl+1)*md; add[o]+=md; return; } else { long mid=(tl+tr)/2; //if (tl==tr) return; push_down(o,tl,mid,tr); if (r<=mid) update(l,r,2*o); else if (l>mid) update(l,r,2*o+1); else { update(l,mid,2*o); update(mid+1,r,2*o+1); } t[o].dat=t[2*o].dat+t[2*o+1].dat; } } push_down:(向下更新一层add数组) void push_down(long long o,long l,long mid,long r) { if (add[o]!=0) { long rl=2*o,rr=2*o+1; __int64 tm=add[o]; add[rl]+=tm; add[rr]+=tm; t[rl].dat+=tm*(mid-l+1); t[rr].dat+=tm*(r-mid); add[o]=0; } } 求和:cin>>ml>>mr>>md; query_sum(ml,mr,1); ->求和:a[ml...mr] __int64 query_sum(long l,long r,long long o) { if (o>num) return 0; long tl=t[o].l,tr=t[o].r; if ((tl==l)&&(tr==r)) return t[o].dat; else { long mid=(tl+tr)/2; push_down(o,tl,mid,tr); if (r<=mid) return query_sum(l,r,2*o); else if (l>mid) return query_sum(l,r,2*o+1); else return (query_sum(l,mid,2*o)+query_sum(mid+1,r,2*o+1)); } }
加离散化:(POJ 2528为例)
我们要重新给压缩后的线段标记起点和终点。
按照通用的离散化方法。。。。
首先依次读入线段端点坐标,存于post[MAXN][2]中,post[i][0]存第一条线段的起点,post[i][1]存第一条线段的终点,然后用一个结构题数组line[MAXN]记录信息,line[i].li记录端点坐标,line[i].num记录这个点属于哪条线段(可以用正负数表示,负数表示起点,正数表示终点)。假如有N条线段,就有2*N个端点。然后将line数组排序,按照端点的坐标,从小到大排。接着要把线段赋予新的端点坐标了。从左到右按照递增的次序,依次更新端点,假如2*N个点中,共有M个不同坐标的点,那么线段树的范围就是[1,M]。
memset(v,false,sizeof(v)); memset(t,0,sizeof(t)); num=0; cin>>n; for (int i=1;i<=n;i++) { cin>>ml>>mr; sm[2*i-1].nm=ml; sm[2*i].nm=mr; sm[2*i-1].ky=i; //start: >0 sm[2*i].ky=-i; //end: <0 } isort(1,2*n); //sort: based on sm[i].nm tmp=0; for (int i=1;i<=2*n;i++) { if (sm[i].nm!=sm[i-1].nm) tmp++; if ((sm[i].nm-sm[i-1].nm==2)&&(i>1)) tmp++; int tkey=sm[i].ky; if (tkey>0) sgm[tkey].st=tmp; else if (tkey<0) sgm[-tkey].ed=tmp; } build(1,tmp,1); for (int i=1;i<=n;i++) { ml=sgm[i].st; mr=sgm[i].ed; md=i; update(ml,mr,1); //a[ml..mr]=md; } ans=0; sum(1); cout<<ans<<endl;
AHOI2009 Seq
1 AHOI 2009 seq: 2 #include <iostream> 3 #include <cstdio> 4 #include <cstring> 5 using namespace std; 6 7 struct 8 { 9 int l,r; 10 long long dat; 11 }t[300010]; 12 13 long long a[100010],add[300010],mul[300010]; 14 int n,m,ml,mr,opr; 15 long long md,p,tmp,num,ans; 16 17 void modp(long long *nm) 18 { 19 *nm=((*nm)%p); 20 } 21 22 void build(int l,int r,long long o) 23 { 24 if (o>num) num=o; 25 mul[o]=1; add[o]=0; 26 t[o].l=l; t[o].r=r; 27 if (l==r) 28 t[o].dat=a[l]; 29 else 30 { 31 int mid=(l+r)/2; 32 build(l,mid,2*o); 33 build(mid+1,r,2*o+1); 34 t[o].dat=t[2*o].dat+t[2*o+1].dat; 35 modp(&t[o].dat); 36 } 37 } 38 /* 39 void push_down1(long long o,int l,int mid,int r) //a[i]=a[i]*c 40 { 41 if (mul[o]!=1) 42 { 43 long long tmp1=mul[o]; 44 mul[2*o]=mul[2*o]*tmp1; 45 mul[2*o+1]=mul[2*o+1]*tmp1; 46 t[2*o].dat=t[2*o].dat*tmp1; 47 t[2*o+1].dat=t[2*o+1].dat*tmp1; 48 modp(&mul[2*o]); 49 modp(&mul[2*o+1]); 50 modp(&t[2*o].dat); 51 modp(&t[2*o+1].dat); 52 mul[o]=1; 53 } 54 } 55 56 void push_down2(long long o,int l,int mid,int r) //a[i]=a[i]+c; 57 { 58 if (add[o]!=0) 59 { 60 long long tmp2=add[o]; 61 add[2*o]+=tmp2; 62 add[2*o+1]+=tmp2; 63 t[2*o].dat+=tmp2*(mid-l+1); 64 t[2*o+1].dat+=tmp2*(r-mid); 65 modp(&add[2*o]); 66 modp(&add[2*o+1]); 67 modp(&t[2*o].dat); 68 modp(&t[2*o+1].dat); 69 add[o]=0; 70 } 71 } 72 */ 73 void push_down(long long o,int l,int mid,int r) 74 { 75 t[2*o].dat=(t[2*o].dat*mul[o]+(mid-l+1)*add[o])%p; 76 t[2*o+1].dat=(t[2*o+1].dat*mul[o]+(r-mid)*add[o])%p; 77 78 mul[2*o]=(mul[2*o]*mul[o])%p; 79 mul[2*o+1]=(mul[2*o+1]*mul[o])%p; 80 81 add[2*o]=(add[2*o]*mul[o]+add[o])%p; 82 add[2*o+1]=(add[2*o+1]*mul[o]+add[o])%p; 83 84 mul[o]=1; add[o]=0; 85 } 86 87 long long query_sum(int l,int r,long long o) 88 { 89 if (o>num) return 0; 90 int tl=t[o].l,tr=t[o].r; 91 if ((tl==l)&&(tr==r)) 92 return t[o].dat; 93 else 94 { 95 int mid=(tl+tr)/2; 96 97 //if (opr==1) 98 // push_down1(o,tl,mid,tr); 99 //else if (opr==2) 100 // push_down2(o,tl,mid,tr); 101 push_down(o,tl,mid,tr); 102 103 if (r<=mid) 104 { 105 tmp=query_sum(l,r,2*o); 106 modp(&tmp); 107 return tmp; 108 } 109 else if (l>mid) 110 { 111 tmp=query_sum(l,r,2*o+1); 112 modp(&tmp); 113 return tmp; 114 } 115 else 116 { 117 long long t1=query_sum(l,mid,2*o); 118 long long t2=query_sum(mid+1,r,2*o+1); 119 tmp=(t1%p)+(t2%p); 120 modp(&tmp); 121 return tmp; 122 } 123 } 124 } 125 126 void update(int l,int r,long long o) 127 { 128 if (o>num) return; 129 int tl=t[o].l,tr=t[o].r; 130 if ((tl==l)&&(tr==r)) 131 { 132 /* 133 if (opr==2) 134 { 135 t[o].dat+=(tr-tl+1)*md; 136 add[o]+=md; 137 modp(&t[o].dat); 138 modp(&add[o]); 139 } 140 else if (opr==1) 141 { 142 t[o].dat=t[o].dat*md; 143 mul[o]=mul[o]*md; 144 modp(&t[o].dat); 145 modp(&mul[o]); 146 } 147 */ 148 if (opr==1) 149 { 150 t[o].dat=(t[o].dat*md)%p; 151 mul[o]=(mul[o]*md)%p; 152 add[o]=(add[o]*md)%p; 153 } 154 else if (opr==2) 155 { 156 t[o].dat=(t[o].dat+(tr-tl+1)*md)%p; 157 add[o]=(add[o]+md)%p; 158 } 159 return; 160 } 161 else 162 { 163 int mid=(tl+tr)/2; 164 //if (tl==tr) return; 165 166 push_down(o,tl,mid,tr); 167 168 if (r<=mid) 169 update(l,r,2*o); 170 else if (l>mid) 171 update(l,r,2*o+1); 172 else 173 { 174 update(l,mid,2*o); 175 update(mid+1,r,2*o+1); 176 } 177 t[o].dat=(t[2*o].dat+t[2*o+1].dat)%p; 178 } 179 } 180 181 void debug() 182 { 183 cout<<"This is the beginning"<<endl; 184 for (int i=1;i<=num;i++) 185 cout<<i<<" | "<<t[i].l<<" "<<t[i].r<<" "<<t[i].dat<<" |-and-| add="<<add[i]<<" mul="<<mul[i]<<endl; 186 cout<<"This is the ending"<<endl; 187 } 188 189 void debug1() 190 { 191 if (opr==3) 192 { 193 cout<<"The correct ans should be : "; 194 long long aw=0; 195 for (int i=ml;i<=mr;i++) 196 aw+=a[i]; 197 cout<<aw<<endl; 198 } 199 else 200 { 201 if (opr==1) 202 { 203 for (int i=ml;i<=mr;i++) 204 a[i]=a[i]*md; 205 } 206 else if (opr==2) 207 { 208 for (int i=ml;i<=mr;i++) 209 a[i]=a[i]+md; 210 } 211 cout<<"debug sequence: "; 212 for (int i=1;i<=n;i++) 213 cout<<a[i]<<" "; 214 cout<<endl; 215 } 216 } 217 218 int main() 219 { 220 freopen("seq.in","r",stdin); 221 freopen("seq.out","w",stdout); 222 223 scanf("%d %lld",&n,&p); 224 for (int i=1;i<=n;i++) 225 cin>>a[i]; 226 227 memset(t,0,sizeof(t)); 228 build(1,n,1); 229 //memset(add,0,sizeof(add)); 230 //memset(mul,0,sizeof(mul)); 231 // debug(); 232 233 cin>>m; 234 for (int i=1;i<=m;i++) 235 { 236 cin>>opr; 237 if (opr==3) 238 { 239 cin>>ml>>mr; 240 ans=query_sum(ml,mr,1); 241 //debug1(); 242 ans=ans%p; 243 cout<<ans<<endl; 244 } 245 else 246 { 247 cin>>ml>>mr>>md; 248 update(ml,mr,1); 249 //debug1(); 250 } 251 } 252 }
posted on 2014-09-04 09:55 Pentium.Labs 阅读(266) 评论(0) 编辑 收藏 举报