Luogu P6477 [NOI Online #2 提高组]子序列问题
\(\large{题目链接}\)
\(\\\)
题意:
给定一个长度为\(n\)的正整数序列,定义函数\(f_{l,r}\)表示在下标在\(\left[l,r\right]\)的子区间中不同整数的个数。
求:\(\sum \limits^{n}_{l=1} \sum \limits ^{n}_{r=l}f\left( l,r\right)^{2} \left(\mod 1e9 + 7\right)\)
\(1 \leq n \leq 10^6\)
\(\\\)
思路:
首先看到\(10^9\)的值域,而且关心的只是数值相等不相等,与具体值无关,先离散化一下。
我们枚举左端点\(l\),考虑当左端点为\(l\)的区间对答案的贡献,把这些贡献全部加在一起就是最终的答案。
那么题目就变成求 \(\sum \limits _{i = 1} ^ {n} f(l,i)^2\)。
因为\(n\)的范围是\(10^6\),显然要找到一种方法能够维护答案。
对于\(\left[l,n\right]\)中出现过的数\(x\),设它在\(\left[l,n\right]\)出现的最左位置为\(pos_x\)。记\(t_i\)为\(f(l,i)\)的值。
考虑倒序循环\(l\),那么左端点由\(l+1\)变为\(l\)的时候,会发生两种事。
1.\(t_l,t_{l+1},...,t_{pos_x-1}\)都加1。
2.\(pos_x\)变为l。
那么所需要解决的问题就变为了:
1.支持区间修改。
2.求区间的平方和。
可以用线段树维护。如果区间加上\(k\),那么平方和变为:
\[\left( a_{l}+k\right) ^{2}+\left( a_{l+1}+k\right) ^{2}+\ldots +\left( a_{r}+k\right) ^{2}
\]
\[= a^{2}_{l}+2ka_{l} + k^{2} + a^{2}_{l+1}+2ka_{l+1} + k^{2} +...+ a^{2}_{r}+2ka_{r} + k^{2}
\]
\[= \left( a^{2}_{l}+a^{2}_{l+1}+\ldots +a^{2}_{r}\right) + 2k(a_{l}+a_{l+1}+\ldots +a_{r}) + (r- l+ 1)\times k ^ 2
\]
维护区间和和区间平方和即可。
\(\\\)
代码:
#include <bits/stdc++.h>
#define ls (x << 1)
#define rs (x << 1 | 1)
using namespace std;
typedef long long ll;
const int N = 1e6 + 5;
const int p = 1e9 + 7;
int n, a[N], pos[N];
struct Node {
int id, val;
}b[N];
int read() {
int x = 0;
char c = getchar();
for (; !isdigit(c); c = getchar());
for (; isdigit(c); c = getchar()) x = x * 10 + c - '0';
return x;
}
bool cmp(Node x, Node y) { return x.val < y.val; }
struct Segment_tree {
int tl[N << 2], tr[N << 2];
ll t[N << 2], lz[N << 2], s[N << 2];
void build(int x, int l, int r) {
tl[x] = l, tr[x] = r;
if (l == r) return;
int mid = (l + r) >> 1;
build(ls, l, mid);
build(rs, mid + 1, r);
}
void up(int x) {
s[x] = s[ls] + s[rs];
if (s[x] > p) s[x] -= p;
t[x] = t[ls] + t[rs];
if (t[x] > p) t[x] -= p;
}
void down(int x) {
if (!lz[x]) return;
s[ls] = (s[ls] + 2 * lz[x] * t[ls] % p + (tr[ls] - tl[ls] + 1) * lz[x] * lz[x] % p) % p;
t[ls] = (t[ls] + (tr[ls] - tl[ls] + 1) * lz[x] % p) % p;
lz[ls] += lz[x];
s[rs] = (s[rs] + 2 * lz[x] * t[rs] % p + (tr[rs] - tl[rs] + 1) * lz[x] * lz[x] % p) % p;
t[rs] = (t[rs] + (tr[rs] - tl[rs] + 1) * lz[x] % p) % p;
lz[rs] += lz[x];
lz[x] = 0;
}
void update(int x, int l, int r, ll k) {
if (l <= tl[x] && r >= tr[x]) {
s[x] = (s[x] + 2 * k * t[x] % p + (tr[x] - tl[x] + 1) * k * k % p) % p;
t[x] = (t[x] + (tr[x] - tl[x] + 1) * k % p) % p;
lz[x] += k;
return;
}
down(x);
int mid = (tl[x] + tr[x]) >> 1;
if (l <= mid) update(ls, l, r, k);
if (r >= mid + 1) update(rs, l, r, k);
up(x);
}
ll query(int x, int l, int r) {
if (l <= tl[x] && r >= tr[x]) return s[x];
ll ret = 0;
int mid = (tl[x] + tr[x]) >> 1;
if (l <= mid) ret = query(ls, l, r);
if (r >= mid + 1) ret = (ret + query(rs, l, r)) % p;
return ret;
}
}T;
int main() {
n = read();
for (int i = 1; i <= n; ++i) b[i].id = i;
for (int i = 1; i <= n; ++i) b[i].val = read();
sort(b + 1, b + 1 + n, cmp);
int cnt = 0;
b[0].val = b[1].val - 1;
for (int i = 1; i <= n; ++i) b[i].val == b[i - 1].val ? a[b[i].id] = cnt : a[b[i].id] = ++cnt;
for (int i = 1; i <= cnt; ++i) pos[i] = n + 1;
T.build(1, 1, n);
ll ans = 0;
for (int i = n; i >= 1; --i) {
T.update(1, i, pos[a[i]] - 1, 1);
ans = (ans + T.query(1, i, n)) % p;
pos[a[i]] = i;
}
printf("%lld\n", ans);
return 0;
}