杭电第八场 1004 counting stars (hdu7059
题意:给定n个数,有三种操作,1:询问l到r的和,2.ax=ax-lowbit(ax)3.ax二进制最高位左移1
思路:很明显就是一道数据结构题,最高位1左移非常好维护,区间和乘2就行,但是最低位1删除非常麻烦,很难在线段树上直接搞。
事实上只要把除最高位以外的和进行区间维护,然后对于每个2操作单点修改就行了,不管q多大,2操作最多把每个点修改nlogn次。
所以是当时nt了。
下附代码:
#include<bits/stdc++.h> #define ll long long #define p2 (p<<1) #define p3 (p<<1|1) using namespace std; const int maxn=1e5+5; const ll mod=998244353; vector<ll> num[maxn]; int cnt[maxn],fa[maxn]; ll a[maxn],x[maxn],n; ll tr[4*maxn],lz[4*maxn],sum[4*maxn]; int find(int x){ if (fa[x]==x) return fa[x]; else return fa[x]=find(fa[x]); } ll qpow(ll a,ll b){ ll ret=1; while (b){ if (b&1) ret=ret*a%mod; a=a*a%mod; b>>=1; } return ret; } void build(int l,int r,int p){ if (l==r){ tr[p]=x[l]; sum[p]=a[l]-x[l]; lz[p]=0; return; } int mid=(l+r)>>1; build(l,mid,p2); build(mid+1,r,p3); tr[p]=(tr[p2]+tr[p3])%mod; sum[p]=(sum[p2]+sum[p3])%mod; lz[p]=0; } void upd(int l,int r,int k,int p){ if (l==k && r==k){ tr[p]=0; return ; } if (lz[p]!=0){ tr[p2]=tr[p2]*qpow(2,lz[p])%mod,tr[p3]=tr[p3]*qpow(2,lz[p])%mod; lz[p2]=lz[p2]+lz[p]; lz[p3]=lz[p3]+lz[p]; lz[p]=0; } int mid=l+r>>1; if(k<=mid) upd(l,mid,k,p2); else upd(mid+1,r,k,p3); tr[p]=(tr[p2]+tr[p3])%mod; } void update(int l,int r,int le,int ri,int p){ if (l==le && r==ri){ tr[p]=tr[p]*2%mod; lz[p]++; return; } if (lz[p]!=0){ tr[p2]=tr[p2]*qpow(2,lz[p])%mod,tr[p3]=tr[p3]*qpow(2,lz[p])%mod; lz[p2]=lz[p2]+lz[p]; lz[p3]=lz[p3]+lz[p]; lz[p]=0; } int mid=l+r>>1; if (ri<=mid) update(l,mid,le,ri,p2); else if (le>mid) update(mid+1,r,le,ri,p3); else update(l,mid,le,mid,p2),update(mid+1,r,mid+1,ri,p3); tr[p]=(tr[p2]+tr[p3])%mod; } int query(int l,int r,int le,int ri,int p){ if (l==le && r==ri){ return tr[p]; } if (lz[p]!=0){ tr[p2]=tr[p2]*qpow(2,lz[p])%mod,tr[p3]=tr[p3]*qpow(2,lz[p])%mod; lz[p2]=lz[p2]+lz[p]; lz[p3]=lz[p3]+lz[p]; lz[p]=0; } int mid=l+r>>1,ret=0; if (ri<=mid) ret=query(l,mid,le,ri,p2); else if (le>mid) ret=query(mid+1,r,le,ri,p3); else ret=(query(l,mid,le,mid,p2)+query(mid+1,r,mid+1,ri,p3))%mod; tr[p]=(tr[p2]+tr[p3])%mod; return ret; } void change(int l,int r,int k,int p,int v){ if (l==k && r==k){ sum[p]=(sum[p]-v+mod)%mod; return ; } int mid=l+r>>1; if(k<=mid) change(l,mid,k,p2,v); else change(mid+1,r,k,p3,v); sum[p]=(sum[p2]+sum[p3])%mod; } int qsum(int l,int r,int le,int ri,int p){ if (l==le && r==ri){ return sum[p]; } int mid=l+r>>1,ret=0; if (ri<=mid) ret=qsum(l,mid,le,ri,p2); else if (le>mid) ret=qsum(mid+1,r,le,ri,p3); else ret=(qsum(l,mid,le,mid,p2)+qsum(mid+1,r,mid+1,ri,p3))%mod; sum[p]=(sum[p2]+sum[p3])%mod; return ret; } int main(){ int T; scanf("%d",&T); while (T--){ scanf("%d",&n); for (int i=1; i<=n; i++){ fa[i]=i; num[i].clear(); scanf("%d",&a[i]); cnt[i]=0; int tmp=a[i],cc=0; while (tmp){ if (tmp&1){ cnt[i]++; num[i].push_back(1<<cc); } tmp/=2; cc++; } x[i]=num[i][num[i].size()-1]; num[i].pop_back(); reverse(num[i].begin(),num[i].end()); } build(1,n,1); int q; scanf("%d",&q); while (q--){ int c,l,r; scanf("%d%d%d",&c,&l,&r); if (c==1){ ll res=(query(1,n,l,r,1)+qsum(1,n,l,r,1))%mod; printf("%lld\n",res); } else if(c==2){ for (int i=l; i<=r; i++){ int x=find(i); if (x>r) break; cnt[x]--; if (cnt[x]==0){ upd(1,n,x,1); int v=find(x+1); fa[x]=v; } else change(1,n,x,1,num[x][cnt[x]-1]); i=x; } } else { update(1,n,l,r,1); } } } }