洛谷P4458 /loj#2512.[BJOI2018]链上二次求和(线段树)
题面
题解
我果然是人傻常数大的典型啊……
//minamoto
#include<bits/stdc++.h>
#define R register
#define ls (p<<1)
#define rs (p<<1|1)
#define fp(i,a,b) for(R int i=a,I=b+1;i<I;++i)
#define fd(i,a,b) for(R int i=a,I=b-1;i>I;--i)
#define go(u) for(int i=head[u],v=e[i].v;i;i=e[i].nx,v=e[i].v)
template<class T>inline bool cmin(T&a,const T&b){return a>b?a=b,1:0;}
using namespace std;
char buf[1<<21],*p1=buf,*p2=buf;
inline char getc(){return p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++;}
int read(){
R int res,f=1;R char ch;
while((ch=getc())>'9'||ch<'0')(ch=='-')&&(f=-1);
for(res=ch-'0';(ch=getc())>='0'&&ch<='9';res=res*10+ch-'0');
return res*f;
}
char sr[1<<21],z[20];int K=-1,Z=0;
inline void Ot(){fwrite(sr,1,K+1,stdout),K=-1;}
void print(R int x){
if(K>1<<20)Ot();if(x<0)sr[++K]='-',x=-x;
while(z[++Z]=x%10+48,x/=10);
while(sr[++K]=z[Z],--Z);sr[++K]='\n';
}
const int N=2e5+5,P=1e9+7,inv2=500000004,inv6=166666668;
inline int add(R int x,R int y){return x+y>=P?x+y-P:x+y;}
inline int dec(R int x,R int y){return x-y<0?x-y+P:x-y;}
inline int mul(R int x,R int y){return 1ll*x*y-1ll*x*y/P*P;}
int ksm(R int x,R int y){
R int res=1;
for(;y;y>>=1,x=mul(x,x))if(y&1)res=mul(res,x);
return res;
}
struct node{int l,r,a,b,sum[2];}tr[N<<2];
int a[N],n,m;
inline int f(R int x){return 1ll*x*(x+1)%P*((x<<1ll)+1)%P*inv6%P;}
void ppd(int p,int a,int b){
int sz=tr[p].r-tr[p].l;
tr[p].sum[0]=add(tr[p].sum[0],1ll*add(mul(2,a),mul(b,sz))*(sz+1)%P*inv2%P);
tr[p].sum[1]=add(tr[p].sum[1],1ll*a*(tr[p].l+tr[p].r)%P*(sz+1)%P*inv2%P);
tr[p].sum[1]=add(tr[p].sum[1],1ll*tr[p].l*b%P*sz%P*(sz+1)%P*inv2%P);
tr[p].sum[1]=add(tr[p].sum[1],mul(b,f(sz)));
tr[p].a=add(tr[p].a,a),tr[p].b=add(tr[p].b,b);
}
void upd(int p){
tr[p].sum[0]=add(tr[ls].sum[0],tr[rs].sum[0]);
tr[p].sum[1]=add(tr[ls].sum[1],tr[rs].sum[1]);
}
void pd(int p){
int mid=(tr[p].l+tr[p].r)>>1;
ppd(ls,tr[p].a,tr[p].b);
ppd(rs,add(tr[p].a,mul(mid-tr[p].l+1,tr[p].b)),tr[p].b);
tr[p].a=tr[p].b=0;
}
void build(int p,int l,int r){
tr[p].l=l,tr[p].r=r;
if(l==r)return tr[p].sum[0]=a[l],tr[p].sum[1]=mul(l,a[l]),void();
int mid=(l+r)>>1;
build(ls,l,mid),build(rs,mid+1,r);
upd(p);
}
int query(int p,int ql,int qr,int id){
if(ql<=tr[p].l&&qr>=tr[p].r)return tr[p].sum[id];
int mid=(tr[p].l+tr[p].r)>>1;
if(tr[p].a||tr[p].b)pd(p);
int res=0;
if(ql<=mid)res=add(res,query(ls,ql,qr,id));
if(qr>mid)res=add(res,query(rs,ql,qr,id));
return res;
}
void update(int p,int l,int r,int a,int b){
if(l==tr[p].l&&r==tr[p].r)return ppd(p,a,b);
if(tr[p].a||tr[p].b)pd(p);
int mid=(tr[p].l+tr[p].r)>>1;
if(r<=mid)update(ls,l,r,a,b);
else if(l>mid)update(rs,l,r,a,b);
else update(ls,l,mid,a,b),update(rs,mid+1,r,add(a,mul(b,mid+1-l)),b);
upd(p);
}
int calc(int x){
if(!x)return 0;
int res=0;
res=add(res,mul(query(1,x,n,0),x));
if(x!=n)res=add(res,P-mul(query(1,1,n-x,0),x));
if(x!=1)res=add(res,query(1,1,x-1,1));
res=add(res,P-add(mul(n,query(1,n-x+1,n,0)),P-query(1,n-x+1,n,1)));
return res;
}
int main(){
// freopen("testdata.in","r",stdin);
n=read(),m=read();
fp(i,1,n)a[i]=read(),a[i]=add(a[i],a[i-1]);
build(1,1,n);
int op,l,r,d;
while(m--){
op=read(),l=read(),r=read();
if(op==1){
d=read();if(l>r)swap(l,r);
update(1,l,r,d,d);
if(r<n)update(1,r+1,n,mul(d,r-l+1),0);
}else print(dec(calc(r),calc(l-1)));
}
return Ot(),0;
}
深深地明白自己的弱小