数据结构模板整理
树状数组
单点修改,区间询问
#include<iostream> #include<cstdio> using namespace std; int read() //读入优化 { char ch=getchar(); int a=0,x=1; while(ch<'0'||ch>'9') { if(ch=='-') x=-x; ch=getchar(); } while(ch>='0'&&ch<='9') { a=(a<<3)+(a<<1)+(ch-'0'); ch=getchar(); } return a*x; } int n,m,k,x,y; //n个数,m次操作 //k=1,第x个数加上y //k=0,问区间[x,y]的和
int a[500001],c[500001];
int lowbit(int x) //求lowbit { return x&(-x); }
void update(int x,int y) { for(;x<=n;x+=lowbit(x)) c[x]+=y; //第x个数加上y }
int sum(int x) //求区间[1,x]的和 { int ans=0; for(;x;x-=lowbit(x)) ans+=c[x]; return ans; }
int main() { n=read();m=read(); for(int i=1;i<=n;i++) { a[i]=read(); update(i,a[i]); //先建好树 } for(int i=1;i<=m;i++) { k=read();x=read();y=read(); if(k==1) update(x,y); else printf("%d\n",sum(y)-sum(x-1)); //前缀和做差 } return 0; }
区间修改,单点询问
#include<iostream> #include<cstdio> using namespace std; int read() //读入优化 { char ch=getchar(); int a=0,x=1; while(ch<'0'||ch>'9') { if(ch=='-') x=-x; ch=getchar(); } while(ch>='0'&&ch<='9') { a=(a<<3)+(a<<1)+(ch-'0'); ch=getchar(); } return a*x; } int n,m,k,x,y,v; //n个数,m次操作 //k=1,区间[x,y]加上v //k=2,询问第x个数的值 int a[500001],c[500001],d[500001]; //a存原数列,c是树状数组,d是差分数组 int lowbit(int x) //求lowbit { return x&(-x); } void update(int x,int y) { for(;x<=n;x+=lowbit(x)) c[x]+=y; //第x个数加上y } int sum(int x) //求区间[1,x]的和 { int ans=0; for(;x;x-=lowbit(x)) ans+=c[x]; return ans; } int main() { n=read();m=read(); for(int i=1;i<=n;i++) { a[i]=read(); d[i]=a[i]-a[i-1]; //求差分数组 update(i,d[i]); //建树 } for(int i=1;i<=m;i++) { k=read(); if(k==1) { x=read();y=read();v=read(); //区间[x,y]加上v update(x,v); //差分数组的变化 update(y+1,-v); } else { x=read(); printf("%d\n",sum(x)); } } return 0; }
线段树
单点修改,区间询问
#include<iostream> #include<cstdio> using namespace std; int read() //读入优化 { char ch=getchar(); int a=0,x=1; while(ch<'0'||ch>'9') { if(ch=='-') x=-x; ch=getchar(); } while(ch>='0'&&ch<='9') { a=(a<<3)+(a<<1)+(ch-'0'); ch=getchar(); } return a*x; } int n,m,k,x,y; //n个数,m次操作 //k=1,第x个数加上y //k=2,询问区间[x,y]的和 int a[500001],sum[500001]; //a是原数列,sum是维护的区间和 void build(int node,int l,int r) //建树 { if(l==r) //叶子结点 { sum[node]=a[l]; //直接赋值 return ; } int mid=(l+r)>>1; build(node*2,l,mid); //分别建好左右子树 build(node*2+1,mid+1,r); sum[node]=sum[node*2]+sum[node*2+1]; //加起来就是根结点的区间和 } void add(int node,int l,int r,int x,int k) //给第x个数加上k { if(l==r&&l==x) //找到了叶子结点且正好是区间[x,x] { sum[node]+=k; return ; } int mid=(l+r)>>1; if(x<=mid) add(node*2,l,mid,x,k); //看是否在左子树里 else add(node*2+1,mid+1,r,x,k); //否则就在右子树里,注意这里能用else是因为这是单点修改 sum[node]=sum[node*2]+sum[node*2+1]; } int ask(int node,int l,int r,int x,int y) //询问区间和 { if(x<=l&&r<=y) return sum[node]; //[l,r]被完全包含在[x,y]内的话直接返回 int mid=(l+r)>>1; int rnt=0; if(x<=mid) rnt+=ask(node*2,l,mid,x,y); //找左右子树是否有交集 if(y>mid) rnt+=ask(node*2+1,mid+1,r,x,y); return rnt; } int main() { n=read();m=read(); for(int i=1;i<=n;i++) a[i]=read(); build(1,1,n); //建树 for(int i=1;i<=m;i++) { k=read();x=read();y=read(); if(k==1) add(1,1,n,x,y); else printf("%d\n",ask(1,1,n,x,y)); } return 0; }
区间修改,区间询问
#include<iostream> #include<cstdio> #include<algorithm> #include<queue> using namespace std; long long read() { char ch=getchar(); long long a=0,x=1; while(ch<'0'||ch>'9') { if(ch=='-') x=-x; ch=getchar(); } while(ch>='0'&&ch<='9') { a=(a<<1)+(a<<3)+(ch-'0'); ch=getchar(); } return a*x; } const long long N=200005; long long n,m,oper,x,y,z; long long sum[N<<2],lazy[N<<2],a[N]; void build(long long node,long long l,long long r) { if(l==r) { sum[node]=a[l]; return ; } long long mid=(l+r)>>1; build(node<<1,l,mid); build(node<<1|1,mid+1,r); sum[node]=sum[node<<1]+sum[node<<1|1]; } void pushdown(long long node,long long l,long long r) { if(!lazy[node]) return ; lazy[node<<1]+=lazy[node]; lazy[node<<1|1]+=lazy[node]; long long mid=(l+r)>>1; sum[node<<1]+=lazy[node]*(mid-l+1); sum[node<<1|1]+=lazy[node]*(r-mid); lazy[node]=0; } void add(long long node,long long l,long long r,long long x,long long y,long long v) { if(x<=l&&r<=y) { lazy[node]+=v; sum[node]+=(r-l+1)*v; return ; } pushdown(node,l,r); long long mid=(l+r)>>1; if(x<=mid) add(node<<1,l,mid,x,y,v); if(y>mid) add(node<<1|1,mid+1,r,x,y,v); sum[node]=sum[node<<1]+sum[node<<1|1]; } long long query(long long node,long long l,long long r,long long x,long long y) { if(x<=l&&r<=y) return sum[node]; pushdown(node,l,r); long long mid=(l+r)>>1; long long res=0; if(x<=mid) res+=query(node<<1,l,mid,x,y); if(y>mid) res+=query(node<<1|1,mid+1,r,x,y); return res; } int main() { n=read();m=read(); for(long long i=1;i<=n;i++) a[i]=read(); build(1,1,n); for(long long i=1;i<=m;i++) { oper=read();x=read();y=read(); if(oper==1) { z=read(); add(1,1,n,x,y,z); } else printf("%lld\n",query(1,1,n,x,y)); } return 0; }
ST表
#include<iostream> #include<cstdio> #include<cmath> using namespace std; int a[100001],f[100001][20]; int read() { char ch=getchar(); long long a=0; while(ch<'0'||ch>'9') ch=getchar(); while(ch>='0'&&ch<='9') { a=a*10+(ch-'0'); ch=getchar(); } return a; } int main() { int n,m; n=read(),m=read(); for(int i=1;i<=n;i++) { a[i]=read(); f[i][0]=a[i]; //初始化 } for(int j=1;(1<<j)<=n;j++) //注意j在外层 for(int i=1;i+(1<<j)-1<=n;i++) f[i][j]=max(f[i][j-1],f[i+(1<<(j-1))][j-1]); //状态转移方程 for(int i=1;i<=m;i++) { int l=read(); int r=read(); int k=(int)(log((double)(r-l+1))/log(2.0)); int ans=max(f[l][k],f[r-(1<<k)+1][k]); printf("%d\n",ans); } return 0; }
最近公共祖先LCA
#include<iostream> #include<cstdio> #include<cstring> using namespace std; const int maxn=500001; int head[2*maxn],to[2*maxn],next[2*maxn],grand[2*maxn][21],dep[maxn]; int n,m,s,edge_sum=0; void add(int x,int y) //链表存图 { next[++edge_sum]=head[x]; head[x]=edge_sum; to[edge_sum]=y; } void dfs(int v,int deep) //dfs求出每个点的深度 { dep[v]=deep; for(int i=head[v];i>0;i=next[i]) { int u=to[i]; if(!dep[u]) dfs(u,deep+1),grand[u][0]=v; } } int lca(int x,int y) { if(dep[x]<dep[y]) swap(x,y); //让x的深度大于y for(int i=20;i>=0;i--) //跳到同一深度 if(dep[y]<=dep[x]-(1<<i)) x=grand[x][i]; if(x==y) return y; for(int i=20;i>=0;i--) { if(grand[x][i]!=grand[y][i]) //跳不到同一点就往上跳 { x=grand[x][i]; y=grand[y][i]; } } return grand[x][0]; //最后再跳一下肯定是LCA } int read() { char ch=getchar(); int a=0; while(ch<'0'||ch>'9') ch=getchar(); while(ch>='0'&&ch<='9') { a=a*10+(ch-'0'); ch=getchar(); } return a; } int main() { memset(head,0,sizeof(head)); n=read(),m=read(),s=read(); for(int i=1;i<n;i++) { int x=read(),y=read(); add(x,y); add(y,x); } grand[s][0]=s; dfs(s,1); for(int i=1;(1<<i)<=n;i++) for(int j=1;j<=n;j++) grand[j][i]=grand[grand[j][i-1]][i-1]; //状态转移方程 for(int i=1;i<=m;i++) { int x=read(),y=read(); printf("%d\n",lca(x,y)); } return 0; }
分块
#include<iostream> #include<cstdio> #include<cmath> using namespace std; const int N=1e5+5; int n,m,len,sum,opt; int L[N],R[N],pos[N],tag[N]; long long a[N],Sum[N]; void insert(int l,int r,long long x) { int p=pos[l]; int q=pos[r]; if(p==q) { for(int i=l;i<=r;i++) a[i]+=x; Sum[p]+=(r-l+1)*x; } else { for(int i=p+1;i<=q-1;i++) tag[i]+=x; for(int i=l;i<=R[p];i++) a[i]+=x; Sum[p]+=(R[p]-l+1)*x; for(int i=L[q];i<=r;i++) a[i]+=x; Sum[q]+=(r-L[q]+1)*x; } } long long query(int l,int r) { int p=pos[l]; int q=pos[r]; long long ans=0; if(p==q) { for(int i=l;i<=r;i++) ans+=a[i]; ans+=(r-l+1)*tag[p]; } else { for(int i=p+1;i<=q-1;i++) ans+=Sum[i]+(R[i]-L[i]+1)*tag[i]; for(int i=l;i<=R[p];i++) ans+=a[i]; ans+=(R[p]-l+1)*tag[p]; for(int i=L[q];i<=r;i++) ans+=a[i]; ans+=(r-L[q]+1)*tag[q]; } return ans; } int main() { scanf("%d %d",&n,&m); for(int i=1;i<=n;i++) scanf("%lld",&a[i]); len=sqrt(n); //每一块的长度,通常取√n sum=n/len; //块数 for(int i=1;i<=sum;i++) //枚举每一块 { L[i]=(i-1)*len+1; //这一块前面有(i-1)*len个,所以这一块的左端点是第(i-1)*len+1个数 R[i]=i*len; //右端点同理 for(int j=L[i];j<=R[i];j++) //枚举块内的每个数 { pos[j]=i; //预处理每个数属于哪一个块 Sum[i]+=a[j]; //求出每一个块的总和 } } if(R[sum]<n) //如果上面的块并不能覆盖整个数列,我们需要再在最后加上一个块 { sum++; //块数+1 L[sum]=R[sum-1]+1; //这个块的左端点是上一个块的右端点+1 R[sum]=n; //这里末尾块的右端点要设置为n,保证正好覆盖整个数列 for(int i=L[sum];i<=R[sum];i++) { pos[i]=sum; Sum[sum]+=a[i]; } } for(int i=1;i<=m;i++) { int l,r,x; scanf("%d %d %d",&opt,&l,&r); if(opt==1) { scanf("%d",&x); insert(l,r,x); } else printf("%lld\n",query(l,r)); } return 0; }
无旋平衡树(fhq treap)
#include<iostream> #include<ctime> #include<cstdio> #include<cstdlib> using namespace std; int read() { char ch=getchar(); int a=0,x=1; while(ch<'0'||ch>'9') { if(ch=='-') x=-x; ch=getchar(); } while(ch>='0'&&ch<='9') { a=(a<<1)+(a<<3)+(ch-'0'); ch=getchar(); } return a*x; } const int N=1e5+5; int cnt; int size[N],val[N],rnd[N],ch[N][2]; int New(int x) { cnt++; val[cnt]=x; rnd[cnt]=rand(); size[cnt]=1; ch[cnt][0]=ch[cnt][1]=0; return cnt; } void update(int rt) { size[rt]=size[ch[rt][0]]+size[ch[rt][1]]+1; } void split(int rt,int k,int &x,int &y) { if(!rt) { x=y=0; return ; } if(val[rt]<=k) { x=rt; split(ch[rt][1],k,ch[rt][1],y); } else { y=rt; split(ch[rt][0],k,x,ch[rt][0]); } update(rt); } int merge(int x,int y) { if(!x||!y) return x+y; if(rnd[x]<=rnd[y]) { ch[x][1]=merge(ch[x][1],y); update(x); return x; } else { ch[y][0]=merge(x,ch[y][0]); update(y); return y; } } int kth(int rt,int k) { if(size[ch[rt][0]]+1==k) return val[rt]; if(size[ch[rt][0]]>=k) return kth(ch[rt][0],k); else return kth(ch[rt][1],k-size[ch[rt][0]]-1); } int n,k,rt,oper; int main() { srand(time(0)); n=read(); while(n--) { oper=read();k=read(); int x,y,z; if(oper==1) { split(rt,k,x,y); rt=merge(merge(x,New(k)),y); } if(oper==2) { split(rt,k,x,y); split(x,k-1,x,z); z=merge(ch[z][0],ch[z][1]); rt=merge(merge(x,z),y); } if(oper==3) { split(rt,k-1,x,y); printf("%d\n",size[x]+1); rt=merge(x,y); } if(oper==4) { printf("%d\n",kth(rt,k)); } if(oper==5) { split(rt,k-1,x,y); printf("%d\n",kth(x,size[x])); rt=merge(x,y); } if(oper==6) { split(rt,k,x,y); printf("%d\n",kth(y,1)); rt=merge(x,y); } } return 0; }
文艺平衡树(区间翻转)
#include<iostream> #include<cstdlib> #include<cstdio> #include<ctime> using namespace std; int read() { char ch=getchar(); int a=0,x=1; while(ch<'0'||ch>'9') { if(ch=='-') x=-x; ch=getchar(); } while(ch>='0'&&ch<='9') { a=(a<<1)+(a<<3)+(ch-'0'); ch=getchar(); } return a*x; } const int N=1e5+5; int n,m,k,l,r,rt,cnt; int val[N],size[N],ch[N][2],rnd[N],lazy[N]; void update(int node) { size[node]=size[ch[node][0]]+size[ch[node][1]]+1; } int New(int x) { ++cnt; rnd[cnt]=rand(); size[cnt]=1; val[cnt]=x; return cnt; } void pushdown(int node) { if(!lazy[node]) return ; swap(ch[node][0],ch[node][1]); lazy[ch[node][0]]^=1; lazy[ch[node][1]]^=1; lazy[node]=0; } void split(int node,int k,int &x,int &y) { if(!node) { x=y=0; return ; } pushdown(node); if(size[ch[node][0]]<k) { x=node; split(ch[node][1],k-size[ch[node][0]]-1,ch[node][1],y); } else { y=node; split(ch[node][0],k,x,ch[node][0]); } update(node); } int merge(int x,int y) { if(!x||!y) return x+y; if(rnd[x]<=rnd[y]) { pushdown(x); ch[x][1]=merge(ch[x][1],y); update(x); return x; } else { pushdown(y); ch[y][0]=merge(x,ch[y][0]); update(y); return y; } } void dfs(int node) { if(!node) return ; pushdown(node); dfs(ch[node][0]); printf("%d ",val[node]); dfs(ch[node][1]); } int main() { srand(time(0)); n=read();m=read(); for(int i=1;i<=n;i++) rt=merge(rt,New(i)); int x,y,z; for(int i=1;i<=m;i++) { l=read(); r=read(); split(rt,l-1,x,y); split(y,r-l+1,y,z); lazy[y]^=1; rt=merge(merge(x,y),z); } dfs(rt); return 0; }
快读模板
int read() { char ch=getchar(); int a=0,x=1; while(ch<'0'||ch>'9') { if(ch=='-') x=-x; ch=getchar(); } while(ch>='0'&&ch<='9') { a=(a<<3)+(a<<1)+(ch-'0'); ch=getchar(); } return a*x; }
O(1)求二进制数1的个数
int bsrun(int x) { int tmp=x - ((x>>1) &033333333333)-((x>>2) &011111111111); return((tmp+(tmp>>3)) &030707070707) %63; }