【CF266E】More Queries to Array... - 线段树
题目描述
You've got an array consisting of \(n\) integers: \(a_{1},a_{2},...,a_{n}\). Your task is to quickly run queries of two types:
-
Assign value \(x\) to all elements from \(l\) to \(r\) inclusive. After such query the values of the elements of array \(a_{l},a_{l+1},...,a_{r}\) become equal to \(x\).
-
Calculate and print the sum \(\sum_{i=l}^{r} a_i \cdot (i-l+1)^{k}\), where \(k\) doesn't exceed \(5\). As the value of the sum can be rather large, you should print it modulo \(1000000007\ (10^{9}+7)\).
题目大意
一段序列 \(a_1,a_2......a_n\)
维护两种操作:
\(=\ l\ r\ x\) 表示将区间 \([l,r]\) 的值赋为 \(x\)
\(?\ l\ r\ k\) 表示输出 \(\sum_{i=l}^{r} a_i(i-l+1)^k \bmod (10^9+7)\)
思路
用二项式定理展开一下
\[\begin{align*}
\sum_{i=l}^ra_i[i+(1-l)]^k\\
\end{align*}\]
\[\begin{align*}
=&\sum_{i=l}^ra_i\sum_{j=0}^ki^j(1-l)^{k-j}C_k^j\\
\end{align*}\]
\[\begin{align*}
=&\sum_{j=0}^k(1-l)^{k-j}C_k^j\sum_{i=l}^ra_ii^j
\end{align*}\]
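所以用线段树对每个区间维护 \(\sum a_i i^j\ (j\in[0,5])\) 就好了。区间 \([l,r]\) 被赋值为 \(x\) 时,该区间的值变为 \(x\sum_{i=l}^{r} i^j\),其中 \(\sum_{i=l}^{r} i^j\) 可以用自然数幂和公式的前缀和之差 \(O(1)\) 求出。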
#include <cstdio>
// Binomial coefficient table: c[k][j] = C(k, j) for 0 <= j <= k <= 5 (zero elsewhere).
const int c[6][6] = { { 1,0,0,0,0,0 },{ 1,1,0,0,0,0 },{ 1,2,1,0,0,0 },{ 1,3,3,1,0,0 },{ 1,4,6,4,1,0 },{ 1,5,10,10,5,1 } };
// Capacity bound used to size the segment-tree arrays below.
const int maxn = 1e5 + 10;
// Modulus for every stored sum; also stored in laz[] as the "nothing pending" marker.
const int mod = 1e9 + 7;
typedef long long ll;
// n, m: the two integers read first in main (tree covers positions 1..n; m queries follow).
// laz[v]: value of a pending range assignment on node v, or `mod` when none is pending.
int n,m,laz[maxn<<3];
// sum[v][j]: sum of a_i * i^j (mod `mod`) over the positions covered by node v, j = 0..5.
ll sum[maxn<<3][6];
/*
 * Prefix power sum 1^k + 2^k + ... + n^k modulo `mod`, via closed forms.
 *
 * k = 1 and k = 2 use exact integer division; k = 3, 4, 5 multiply by the
 * modular inverses of 4, 30 and 12 respectively. Any other k (in particular
 * k = 0) returns n unchanged, i.e. the number of terms.
 */
inline ll powerkth(ll n,int k) {
    switch (k) {
    case 1:
        return n*(n+1)/2%mod;
    case 2:
        return n*(n+1)*(2*n+1)/6%mod;
    case 3: {
        // n^2 (n+1)^2 / 4
        ll head = n*n%mod*(n+1)%mod*(n+1)%mod;
        return head*250000002ll%mod;
    }
    case 4: {
        // n (n+1) (2n+1) (3n^2 + 3n - 1) / 30
        ll head = n*(n+1)%mod*(2*n+1)%mod;
        ll poly = 3*n*n%mod+3*n%mod-1;
        return head*poly%mod*233333335ll%mod;
    }
    case 5: {
        // n^2 (n+1)^2 (2n^2 + 2n - 1) / 12
        ll head = n*n%mod*(n+1)%mod*(n+1)%mod;
        ll poly = 2*n*n%mod+2*n%mod-1;
        return head*poly%mod*83333334ll%mod;
    }
    default:
        return n;
    }
}
/* Recompute all six stored sums of node `root` from its two children. */
inline void pushup(int root) {
    const int lc = root<<1;
    const int rc = root<<1|1;
    for (int j = 0; j < 6; ++j) {
        sum[root][j] = (sum[lc][j]+sum[rc][j])%mod;
    }
}
inline void pushdown(int root,int l,int r) {
int mid = l+r>>1;
if (laz[root] ^ mod) {
laz[root<<1] = laz[root];
laz[root<<1|1] = laz[root];
for (int i = 0;i <= 5;i++) {
sum[root<<1][i] = laz[root]*((powerkth(mid,i)-powerkth(l-1,i)+mod)%mod)%mod;
sum[root<<1|1][i] = laz[root]*((powerkth(r,i)-powerkth(mid,i)+mod)%mod)%mod;
}
laz[root] = mod;
}
}
inline void build(int l,int r,int root) {
laz[root] = mod;
if (l == r) {
scanf("%lld",&sum[root][0]);
for (int i = 1;i <= 5;i++) sum[root][i] = sum[root][i-1]*l%mod;
return;
}
int mid = l+r>>1;
build(l,mid,root<<1);
build(mid+1,r,root<<1|1);
pushup(root);
}
inline void update(int l,int r,int ul,int ur,int root,ll x) {
if (l > ur || r < ul) return;
if (ul <= l && r <= ur) {
laz[root] = x;
for (int i = 0;i <= 5;i++) sum[root][i] = x*((powerkth(r,i)-powerkth(l-1,i)+mod)%mod)%mod;
return;
}
pushdown(root,l,r);
int mid = l+r>>1;
update(l,mid,ul,ur,root<<1,x);
update(mid+1,r,ul,ur,root<<1|1,x);
pushup(root);
}
/*
 * Return the sum of a_i * i^k (mod `mod`) over positions in [ql, qr].
 * `root` is the current node and covers [l, r]. Pending assignments are
 * pushed down before descending into a partially covered node.
 */
inline ll query(int l,int r,int ql,int qr,int root,int k) {
    if (qr < l || r < ql) {
        return 0;
    }
    if (ql <= l && r <= qr) {
        return sum[root][k];
    }
    pushdown(root,l,r);
    const int mid = (l+r)>>1;
    const ll leftPart = query(l,mid,ql,qr,root<<1,k);
    const ll rightPart = query(mid+1,r,ql,qr,root<<1|1,k);
    return (leftPart+rightPart)%mod;
}
int main() {
for (scanf("%d%d",&n,&m),build(1,n,1);m--;) {
char ch; int l,r,k;
scanf("%s%d%d%d",&ch,&l,&r,&k);
if (ch == '=') update(1,n,l,r,1,k);
else {
if (l == 1) { printf("%lld\n",query(1,n,l,r,1,k)); continue; }
ll ans = 0;
for (int i = 0;i <= k;i++) {
ll tmp = 1;
for (int j = 1;j <= k-i;j++) tmp = (tmp*(1-l)%mod+mod)%mod;
(ans += tmp*c[k][i]%mod*query(1,n,l,r,1,i)%mod) %= mod;
}
printf("%lld\n",ans);
}
}
return 0;
}