[NOI Online #2 提高组]子序列问题
没啥意思的板子题,HH的项链既视感…
反正我没打这个 NOIOL,随便写下好了,没有心路历程
别问为啥没打,问就是周六还呆在学校上文化课
首先这个显然是 HH的项链,如果你把 \(f(l,r)^2\) 变成 \(f(l,r)\) 就是 HH的项链变成区间求和…如果是 \(f(l,r)^2\) 就维护区间平方和的板子。
我们考虑所有以 \(r\) 为右端点的所有 \(\sum f_{i,r}\),假设 \(f_{i,r-1}\) 求的是对的,那么你 \(r\) 的贡献区间是 \([pre_r + 1, r]\),然后 HH的项链的套路这题就做完了啊
因为
\((x+y)^2 = x^2 + y^2 + 2xy\)
所以
只需要记录区间和的 \(sum_p\),以及区间平方和的 \(ans_p\)。
很自然的得到了一棵线段树.jpg
struct smt {
int sum[maxn << 2], ans[maxn << 2], tag[maxn << 2];
#define ls (p << 1)
#define rs (p << 1 | 1)
void up(int p) {
sum[p] = add(sum[ls], sum[rs]);
ans[p] = add(ans[ls], ans[rs]);
}
void down(int p, int v, int l, int r) {
ans[p] = add(add(ans[p], mul(2 * v, sum[p])), mul(mul(v, v), r - l + 1));
sum[p] = add(sum[p], mul(v, r - l + 1));
tag[p] = add(tag[p], v);
}
void down(int p, int l, int r) {
if(!tag[p]) return;
int mid = l + r >> 1;
down(ls, tag[p], l, mid);
down(rs, tag[p], mid + 1, r);
tag[p] = 0;
}
void upd(int p, int ql, int qr, int l, int r, int v) {
if(ql <= l && r <= qr) { down(p, v, l, r); return ; }
down(p, l, r);
int mid = l + r >> 1;
if(ql <= mid) upd(ls, ql, qr, l, mid, v);
if(qr > mid) upd(rs, ql, qr, mid + 1, r, v);
up(p);
}
} smt;
然后就做完了啊?
// by Isaunoya
#include<bits/stdc++.h>
using ll = long long;
using namespace std;
struct io {
char buf[1 << 26 | 3], *s;
int f;
io() {
f = 0, buf[fread(s = buf, 1, 1 << 26, stdin)] = '\n';
}
io& operator >> (int&x) {
for(x = f = 0; !isdigit(*s); ++s) f |= *s == '-';
while(isdigit(*s)) x = x * 10 + (*s++ ^ 48);
return x = f ? -x : x, *this;
}
};
const int maxn = 1e6 + 61;
int n;
int a[maxn], b[maxn];
const int mod = 1e9 + 7;
int add(int x, int y) {
if(x + y >= mod) return (x + y - mod);
return x + y;
}
ll mul(int x, int y) {
ll ans = 1ll * x * y;
if(ans >= mod) return ans % mod;
return ans;
}
struct smt {
int sum[maxn << 2], ans[maxn << 2], tag[maxn << 2];
#define ls (p << 1)
#define rs (p << 1 | 1)
void up(int p) {
sum[p] = add(sum[ls], sum[rs]);
ans[p] = add(ans[ls], ans[rs]);
}
void down(int p, int v, int l, int r) {
ans[p] = add(add(ans[p], mul(2 * v, sum[p])), mul(mul(v, v), r - l + 1));
sum[p] = add(sum[p], mul(v, r - l + 1));
tag[p] = add(tag[p], v);
}
void down(int p, int l, int r) {
if(!tag[p]) return;
int mid = l + r >> 1;
down(ls, tag[p], l, mid);
down(rs, tag[p], mid + 1, r);
tag[p] = 0;
}
void upd(int p, int ql, int qr, int l, int r, int v) {
if(ql <= l && r <= qr) { down(p, v, l, r); return ; }
down(p, l, r);
int mid = l + r >> 1;
if(ql <= mid) upd(ls, ql, qr, l, mid, v);
if(qr > mid) upd(rs, ql, qr, mid + 1, r, v);
up(p);
}
} smt;
signed main() {
#ifdef LOCAL
freopen("testdata.in", "r", stdin);
#endif
io in;
in >> n;
for(int i = 1 ; i <= n ; i ++) in >> a[i];
for(int i = 1 ; i <= n ; i ++) b[i] = a[i];
sort(b + 1, b + n + 1); int len = unique(b + 1 , b + n + 1) - b - 1;
for(int i = 1 ; i <= n ; i ++) a[i] = lower_bound(b + 1, b + len + 1, a[i]) - b;
static int las[maxn], pre[maxn];
for(int i = 1 ; i <= n ; i ++) pre[i] = las[a[i]], las[a[i]] = i;
int ans = 0;
for(int i = 1 ; i <= n ; i ++) {
smt.upd(1, pre[i] + 1, i, 1, n, 1);
ans = add(ans, smt.ans[1]);
}
cout << ans << '\n';
return 0;
}