[BZOJ5291][BJOI2018]链上二次求和(线段树)

感觉自己做的麻烦了,但常数似乎不算差。(只是Luogu最慢的点不到2s本地要跑10+s)

感觉我的想法是最自然的,但不明白为什么网上似乎找不到这种做法。(不过当然所有的做法都是分类大讨论,而我的方法手算部分较为麻烦)

每次询问考虑每个位置的贡献,拆分成求所有长度<=R的区间的贡献次数和减去长度<L的区间贡献次数和。

分成两大类考虑,设当前考虑长度在[1,r]的所有区间,当前要计算数a[k]的贡献次数:

一: $2r\leq n$

  1.$k\leq r$ 观察所有包含k的长度不超过r的区间,发现答案为$1+2+...+i+i+i+...i=\frac{1}{2}[(2r+1)i-i^2]$

  2.$r<k<n-r+1$ 左右两边都可以延伸k的长度,于是答案为$1+2+...+r=\frac{r(r+1)}{2}$

  3.$k\geq n-r+1$ 和情况一类似,答案为$1+2+...+(n-i)+(n-i)+(n-i)+...=\frac{1}{2}[2nr-n^2-n+2r+(2n-2r+1)i-i^2]$

二:$2r>n$

  1.$k\leq n-r+1$观察发现和上面情况一是一样的:$\frac{1}{2}[(2r+1)i-i^2]$

  2.$n-r+1<k<n/2$

    答案为$1+2+...+i+i+...+i+(i-1)+(i-2)+...=\frac{1}{2}[2nr-n^2-r^2+r-n+(2n+n)i-2i^2]$

   $n/2\leq k<r$

    一波复杂的带入化简发现答案同上:$\frac{1}{2}[2nr-n^2-r^2+r-n+(2n+n)i-2i^2]$

  3.$k\geq r$ 观察发现和上面情况三是一样的:$\frac{1}{2}[2nr-n^2-n+2r+(2n-2r+1)i-i^2]$

于是分别维护$\sum a_i$,$\sum a_i*i$,$\sum a_i*i^2$即可。

  1 #include<cstdio>
  2 #include<algorithm>
  3 #define ls (x<<1)
  4 #define rs (ls|1)
  5 #define lson ls,L,mid
  6 #define rson rs,mid+1,R
  7 #define rep(i,l,r) for (int i=(l); i<=(r); i++)
  8 using namespace std;
  9 
 10 const int N=200010,mod=1e9+7,inv2=(mod+1)/2,inv6=(mod+1)/6;
 11 int n,m,op,l,r,x,a[N];
 12 
 13 int rd(){
 14     int x=0; char ch=getchar(); bool f=0;
 15     while (ch<'0' || ch>'9') f|=(ch=='-'),ch=getchar();
 16     while (ch>='0' && ch<='9') x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
 17     return f ? -x : x;
 18 }
 19 
 20 struct P{ int a[3],tag; }v[N<<2];
 21 inline void inc(int &x,int y){ x+=y; (x>=mod)?x-=mod:0; }
 22 
 23 P operator +(P a,P b){
 24     inc(a.a[0],b.a[0]); inc(a.a[1],b.a[1]);
 25     inc(a.a[2],b.a[2]); a.tag=0; return a;
 26 }
 27 
 28 int cal1(int x){ return 1ll*x*(x+1)/2%mod; }
 29 int cal2(int x){ return 1ll*x*(x+1)*(2*x+1)%mod*inv6%mod; }
 30 
 31 void put(int x,int L,int R,int k){
 32     inc(v[x].a[0],1ll*(R-L+1)*k%mod);
 33     inc(v[x].a[1],1ll*(cal1(R)-cal1(L-1)+mod)*k%mod);
 34     inc(v[x].a[2],1ll*(cal2(R)-cal2(L-1)+mod)*k%mod);
 35     inc(v[x].tag,k);
 36 }
 37 
 38 void push(int x,int L,int R){
 39     if (!v[x].tag) return;
 40     int mid=(L+R)>>1;
 41     put(lson,v[x].tag); put(rson,v[x].tag); v[x].tag=0;
 42 }
 43 
 44 void build(int x,int L,int R){
 45     if (L==R){
 46         v[x].a[0]=a[L]; v[x].a[1]=1ll*a[L]*L%mod;
 47         v[x].a[2]=1ll*a[L]*L%mod*L%mod; return;
 48     }
 49     int mid=(L+R)>>1;
 50     build(lson); build(rson); v[x]=v[ls]+v[rs];
 51 }
 52 
 53 void mdf(int x,int L,int R,int l,int r,int k){
 54     if (L==l && r==R){ put(x,L,R,k); return; }
 55     int mid=(L+R)>>1; push(x,L,R);
 56     if (r<=mid) mdf(lson,l,r,k);
 57     else if (l>mid) mdf(rson,l,r,k);
 58         else mdf(lson,l,mid,k),mdf(rson,mid+1,r,k);
 59     v[x]=v[ls]+v[rs];
 60 }
 61 
 62 P que(int x,int L,int R,int l,int r){
 63     if (L==l && r==R) return v[x];
 64     int mid=(L+R)>>1; push(x,L,R);
 65     if (r<=mid) return que(lson,l,r);
 66     else if (l>mid) return que(rson,l,r);
 67         else return que(lson,l,mid)+que(rson,mid+1,r);
 68 }
 69 
 70 int Q1(int d,int r){
 71     P t=que(1,1,n,1,d);
 72     return (1ll*t.a[1]*(2ll*r+1)%mod-t.a[2]+mod)%mod*inv2%mod;
 73 }
 74 
 75 int Q2(int d,int r){
 76     if (d>n) return 0;
 77     P t=que(1,1,n,d,n);
 78     return (((2ll*n*r-1ll*n*n-n+2ll*r)%mod*t.a[0]%mod+(2ll*n-2ll*r+1)%mod*t.a[1]%mod-t.a[2])%mod+mod)%mod*inv2%mod;
 79 }
 80 
 81 int Q3(int L,int R,int r){
 82     P t=que(1,1,n,L,R); return 1ll*r*(r+1)/2%mod*t.a[0]%mod;
 83 }
 84 
 85 int Q4(int L,int R,int r){
 86     if (L>R) return 0;
 87     P t=que(1,1,n,L,R);
 88     return (((2ll*n*r-1ll*n*n-1ll*r*r+r-n)%mod*t.a[0]%mod+(2ll*n+2)*t.a[1]%mod-2ll*t.a[2])%mod+mod)%mod*inv2%mod;
 89 }
 90 
 91 int solve(int r){
 92     if (!r) return 0;
 93     int L=min(r,n-r+1),R=max(r,n-r+1);
 94     if (L==R) R++;
 95     int r1=Q1(L,r),r2=Q2(R,r);
 96     int r3=(L+1<=R-1)?((r<=n-r+1)?Q3(L+1,R-1,r):Q4(L+1,R-1,r)):0;
 97     return (1ll*r1+r2+r3)%mod;
 98 }
 99 
100 int main(){
101     freopen("sum.in","r",stdin);
102     freopen("sum.out","w",stdout);
103     n=rd(); m=rd();
104     rep(i,1,n) a[i]=rd();
105     build(1,1,n);
106     while (m--){
107         op=rd(); l=rd(); r=rd();
108         if (l>r) swap(l,r);
109         if (op==1) x=rd(),mdf(1,1,n,l,r,x);
110             else printf("%d\n",(solve(r)-solve(l-1)+mod)%mod);
111     }
112     return 0;
113 }

 

posted @ 2019-02-16 09:44  HocRiser  阅读(310)  评论(0编辑  收藏  举报