乘积求和
【题目描述】
给出N,以及N各数,求\(\sum_{l = 1}^N\sum_{r = 1}^N\sum_{i = l}^r\sum_{j = i + 1, a[j] < a[i]}^ra[i]*a[j]\)
输出答案对$10^{12} +7 $取模的结果
40%:\(N \leq 50\)
60%:\(N \leq100\)
80%:\(N\leq 1000\)
90%:\(1 \leq a_i \leq 10^5\)
100%:\(N\leq4 * 10^4,1 \leq a_i \leq 10^{12}\)
Solution
首先观察式子发现他让我们求序列所有子区间的严格逆序对乘积和
最简单的暴力枚举,\(N^4\)
手推一下发现,对于每一个数,都要与其后面的比他小的数相乘多次,这个次数就是包含他们的区间个数。那么显然有i * (n - j +1)个区间,其中 i 是大的数的位置, j 是小的数的位置。
这样是\(N^2\)的,可以得到80pt
#include <iostream>
#include <cstdio>
using namespace std;
inline long long read() {
long long x = 0; int f = 0; char c = getchar();
while (c < '0' || c > '9') f |= c == '-', c = getchar();
while (c >= '0' && c <= '9') x = (x << 1) + (x << 3) + (c ^ 48), c = getchar();
return f ? -x : x;
}
const long long mod = 1e12 + 7;
int n;
long long a[40004], ans;
int main() {
freopen("multiplication.in", "r", stdin);
freopen("multiplication.out", "w", stdout);
n = read();
for (int i = 1; i <= n; ++i) a[i] = read();
for (int i = 1; i <= n; ++i) {
long long sum = 0;
for (int j = i + 1; j <= n; ++j) {
if (a[j] < a[i]) ans = (ans + a[i] * a[j] % mod * (n - j + 1) % mod * i % mod) % mod;
}
}
printf("%lld\n", ans);
return 0;
}
然后我们发现对于每一个大的数,它的贡献是a[i] * i,每一个小的数的贡献是a[j] * (n - j + 1),他们两两没有关系。于是可以边加入边计算。
搞一棵权值线段树,每加入一个点,首先计算比他先插入的且比他大的数的贡献,然后再乘以(n - i + 1),再将a[i] * i 加入线段树就好了。
但会发现这么搞只有90pt,因为模数太大了,一乘就暴long long 了,所以再加一个快速龟速乘就好了
#include <iostream>
#include <cstdio>
using namespace std;
inline long long read() {
long long x = 0; int f = 0; char c = getchar();
while (c < '0' || c > '9') f |= c == '-', c = getchar();
while (c >= '0' && c <= '9') x = (x << 1) + (x << 3) + (c ^ 48), c = getchar();
return f ? -x : x;
}
struct szh {
int l, r;
long long sum;
szh() { l = 0; r = 0; sum = 0; }
}a[1600004];
#define mid ((l + r) >> 1)
const long long mod = 1e12 + 7;
int n, cnt = 1;
inline void pushup(int u) {
a[u].sum = (a[a[u].l].sum + a[a[u].r].sum) % mod;
}
inline long long get_sum(long long l, long long r, long long u, long long L, long long R) {
if (!u) return 0;
if (L <= l && r <= R){
return a[u].sum;
}
long long ans = 0;
if (L <= mid) ans = get_sum(l, mid, a[u].l, L, R);
if (mid < R) ans = (ans + get_sum(mid + 1, r, a[u].r, L ,R)) % mod;
return ans;
}
inline void add(long long l, long long r, long long u, long long p, long long v) {
if (l == r) {
a[u].sum = (a[u].sum + v) % mod; return;
}
if (p <= mid) {
if (!a[u].l) a[u].l = ++cnt;
add(l, mid, a[u].l, p, v);
}
else {
if (!a[u].r) a[u].r = ++cnt;
add(mid + 1, r, a[u].r, p, v);
}
pushup(u);
}
inline long long mul(long long x, long long b) {
long long ans = 0;
while (b) {
if (b & 1) ans = (ans + x) % mod;
x <<= 1; x %= mod;
b >>= 1;
}
return ans;
}
int main() {
freopen("multiplication.in", "r", stdin);
freopen("multiplication.out", "w", stdout);
n = read();
long long ans = 0;
for (int i = 1; i <= n; ++i) {
long long x = read();
long long xx = x;
x = mul(x, get_sum(1, mod, 1, x + 1, mod));
x = x * (n - i + 1) % mod;
ans = (ans + x) % mod;
add(1, mod, 1, xx, 1ll * xx * i % mod);
}
printf("%lld\n", ans);
return 0;
}