[NOI Online #2 提高组]子序列问题

没啥意思的板子题,HH的项链既视感…

反正我没打这个 NOIOL,随便写下好了,没有心路历程

别问为啥没打,问就是周六还呆在学校上文化课

首先这个显然是 HH的项链,如果你把 \(f(l,r)^2\) 变成 \(f(l,r)\) 就是 HH的项链变成区间求和…如果是 \(f(l,r)^2\) 就维护区间平方和的板子。

我们考虑所有以 \(r\) 为右端点的所有 \(\sum f_{i,r}\),假设 \(f_{i,r-1}\) 求的是对的,那么你 \(r\) 的贡献区间是 \([pre_r + 1, r]\),然后 HH的项链的套路这题就做完了啊

因为
\((x+y)^2 = x^2 + y^2 + 2xy\)

所以
只需要记录区间和的 \(sum_p\),以及区间平方和的 \(ans_p\)

很自然的得到了一棵线段树.jpg

struct smt {
	int sum[maxn << 2], ans[maxn << 2], tag[maxn << 2];
#define ls (p << 1)
#define rs (p << 1 | 1)
	void up(int p) {
		sum[p] = add(sum[ls], sum[rs]);
		ans[p] = add(ans[ls], ans[rs]);
	}
	
	void down(int p, int v, int l, int r) {
		ans[p] = add(add(ans[p], mul(2 * v, sum[p])), mul(mul(v, v), r - l + 1));
		sum[p] = add(sum[p], mul(v, r - l + 1));
		tag[p] = add(tag[p], v);
	}
	
	void down(int p, int l, int r) {
		if(!tag[p]) return;
		int mid = l + r >> 1;
		down(ls, tag[p], l, mid);
		down(rs, tag[p], mid + 1, r);
		tag[p] = 0;
	}
	
	void upd(int p, int ql, int qr, int l, int r, int v) {
		if(ql <= l && r <= qr) { down(p, v, l, r); return ; }
		down(p, l, r);
		int mid = l + r >> 1;
		if(ql <= mid) upd(ls, ql, qr, l, mid, v);
		if(qr > mid) upd(rs, ql, qr, mid + 1, r, v);
		up(p);
	}
} smt;

然后就做完了啊?

// by Isaunoya
#include<bits/stdc++.h>
using ll = long long;
using namespace std;
struct io {
	char buf[1 << 26 | 3], *s;
	int f;
	io() {
		f = 0, buf[fread(s = buf, 1, 1 << 26, stdin)] = '\n';
	}
	io& operator >> (int&x) {
		for(x = f = 0; !isdigit(*s); ++s) f |= *s  == '-';
		while(isdigit(*s)) x = x * 10 + (*s++ ^ 48);
		return x = f ? -x : x, *this;
	}
};

const int maxn = 1e6 + 61;
int n;
int a[maxn], b[maxn];
const int mod = 1e9 + 7;

int add(int x, int y) {
	if(x + y >= mod) return (x + y - mod);
	return x + y;
}
ll mul(int x, int y) {
	ll ans = 1ll * x * y;
	if(ans >= mod) return ans % mod;
	return ans;
}

struct smt {
	int sum[maxn << 2], ans[maxn << 2], tag[maxn << 2];
#define ls (p << 1)
#define rs (p << 1 | 1)
	void up(int p) {
		sum[p] = add(sum[ls], sum[rs]);
		ans[p] = add(ans[ls], ans[rs]);
	}
	
	void down(int p, int v, int l, int r) {
		ans[p] = add(add(ans[p], mul(2 * v, sum[p])), mul(mul(v, v), r - l + 1));
		sum[p] = add(sum[p], mul(v, r - l + 1));
		tag[p] = add(tag[p], v);
	}
	
	void down(int p, int l, int r) {
		if(!tag[p]) return;
		int mid = l + r >> 1;
		down(ls, tag[p], l, mid);
		down(rs, tag[p], mid + 1, r);
		tag[p] = 0;
	}
	
	void upd(int p, int ql, int qr, int l, int r, int v) {
		if(ql <= l && r <= qr) { down(p, v, l, r); return ; }
		down(p, l, r);
		int mid = l + r >> 1;
		if(ql <= mid) upd(ls, ql, qr, l, mid, v);
		if(qr > mid) upd(rs, ql, qr, mid + 1, r, v);
		up(p);
	}
} smt;

signed main() {
#ifdef LOCAL
	freopen("testdata.in", "r", stdin);
#endif
	io in;
	in >> n;
	for(int i = 1 ; i <= n ; i ++) in >> a[i];
	for(int i = 1 ; i <= n ; i ++) b[i] = a[i];
	sort(b + 1, b + n + 1); int len = unique(b + 1 , b + n + 1) - b - 1;
	for(int i = 1 ; i <= n ; i ++) a[i] = lower_bound(b + 1, b + len + 1, a[i])  - b;
	static int las[maxn], pre[maxn];
	for(int i = 1 ; i <= n ; i ++) pre[i] = las[a[i]], las[a[i]] = i;
	int ans  = 0;
	for(int i = 1 ; i <= n ; i ++) {
		smt.upd(1, pre[i] + 1, i, 1, n, 1);
		ans = add(ans, smt.ans[1]);
	}
	cout << ans << '\n';
	return 0;
}
posted @ 2020-04-26 12:52  _Isaunoya  阅读(214)  评论(0编辑  收藏  举报