do_while_true

一言(ヒトコト)

「题解」Codeforces 1553F Pairwise Modulo

很多人是根号分治,来一发 polylog 的做法。

线段树/树状数组好题啊。

因为懒下面复杂度均视作 \(a=n^{\mathcal{O}(1)}\)

首先考虑求出 \(p\) 的差分数组,也就是求出:

\[p'_k=\sum_{1\leq i\leq n} a_i\bmod a_j+a_j\bmod a_i \]

分类讨论一下:

  1. \(a_i \bmod a_j,\ a_i<a_j\)
  2. \(a_j \bmod a_i,\ a_i>a_j\)

直接树状数组,前者贡献均为 \(a_i\),记录一下前面大于 \(a_i\) 的有多少个。后者贡献均为 \(a_j\),记录一下前面小于 \(a_i\)\(a_j\) 的和。

  1. \(a_i \bmod a_j,\ a_i>a_j\)

维护一个 \(cost_i\),代表考虑到当前位,前面的数对 \(i\) 的贡献是多少。

发现每次是在 \(a_i\) 的倍数中间的一段都会有个等差数列加上去,直接用线段树维护差分数组一段一段暴力改,由于 \(a_i\) 互不相同,所以等差数列个数是调和级数 \(\mathcal{O}(n\log n)\) 的,总复杂度是 \(\mathcal{O}(n\log^2 n)\) 的。

  1. \(a_j \mod a_i, a_i<a_j\)

考虑枚举 \(a_i\) 的倍数分出的一个个段,统计这个段里 \(a_j\)\(a_i\) 的贡献。

那么答案就是

\[\begin{aligned} &\sum_{l\leq i< r}(i-l)\cdot a_i \\ =&\sum_{l\leq i< r}i\cdot a_i-l\cdot a_i \end{aligned} \]

只需要维护 \(i\cdot a_i\)\(a_i\) 在值域维的区间和即可。枚举 \(a_i\) 倍数的次数是调和级数 \(\mathcal{O}(n\log n)\) 的,用树状数组维护,总复杂度是 \(\mathcal{O}(n\log^2n)\) 的。

\(\mathcal{Code}\)

#include<iostream>
#include<cstdio>
#include<vector>
#include<algorithm>
typedef long long ll;
template <typename T> T Max(T x, T y) { return x > y ? x : y; }
template <typename T> T Min(T x, T y) { return x < y ? x : y; }
template <typename T>
T& read(T& r) {
	r = 0; bool w = 0; char ch = getchar();
	while(ch < '0' || ch > '9') w = ch == '-' ? 1 : 0, ch = getchar();
	while(ch >= '0' && ch <= '9') r = r * 10 + (ch ^ 48), ch = getchar();
	return r = w ? -r : r;
}
inline int lowbit(int x) { return x & (-x); }
const int N = 300010;
int n, a[N];
ll p[N];
struct BIT {
	int m;
	ll tree[N];
	void init(int x) {
		m = x;
		for(int i = 0; i <= m; ++i) tree[i] = 0;
	}
	void modify(int x, ll v) { for(; x <= m; x += lowbit(x)) tree[x] += v; }
	ll querysum(int x) { ll sumq = 0; for(; x; x -= lowbit(x)) sumq += tree[x]; return sumq; }
	ll query(int x, int y) { return querysum(y)-querysum(x-1); }
}bit1, bit2, tr2, tr3;
struct segment_tree {
	#define ls tree[x].lson
	#define rs tree[x].rson
	#define tl tree[x].l
	#define tr tree[x].r
	int trnt;
	struct SGT {
		int l, r, lson, rson;
		ll sum, tag;
	}tree[N << 1];
	inline void pushup(int x) { tree[x].sum = tree[ls].sum + tree[rs].sum; }
	inline void pushdown(int x) {
		if(tree[x].tag) {
			int p = tree[x].tag;
			tree[ls].sum += (tree[ls].r - tree[ls].l + 1) * p;
			tree[rs].sum += (tree[rs].r - tree[rs].l + 1) * p;
			tree[ls].tag += p;
			tree[rs].tag += p;
			tree[x].tag = 0;
		}
	}
	int build(int l, int r) {
		int x = ++trnt; tree[x].sum = 0;
		tl = l; tr = r;
		if(l == r) return x;
		ls = build(l, (l+r)>>1); rs = build(tree[ls].r+1, r);
		pushup(x); return x;
	}
	void modify(int x, int l, int r, ll v) {
		if(tl >= l && tr <= r) {
			tree[x].sum += (tree[x].r - tree[x].l + 1) * v;
			tree[x].tag += v;
			return ;
		}
		int mid = (tree[x].l + tree[x].r) >> 1;
		pushdown(x);
		if(mid >= l) modify(ls, l, r, v);
		if(mid < r) modify(rs, l, r , v);
		pushup(x);
	}
	ll query(int x, int l, int r) {
		if(tl >= l && tr <= r) return tree[x].sum;
		int mid = (tree[x].l + tree[x].r) >> 1; ll sumq = 0;
		pushdown(x);
		if(mid >= l) sumq += query(ls, l, r);
		if(mid < r) sumq += query(rs, l, r);
		pushup(x);
		return sumq;
	}
}tr1;
#undef ls
#undef rs
#undef tl
#undef tr
int mx = 0;
signed main() {
	read(n);
	for(int i = 1; i <= n; ++i) read(a[i]), mx = Max(mx, a[i]);
	bit1.init(mx); bit2.init(mx); tr2.init(mx); tr3.init(mx);
	for(int i = 1; i <= n; ++i) {
		p[i] += 1ll * a[i] * bit1.query(a[i]+1, mx);
		p[i] += bit2.query(1, a[i]-1);
		bit1.modify(a[i], 1);
		bit2.modify(a[i], a[i]);
	}
	tr1.build(1, mx);
	for(int i = 1; i <= n; ++i) {
		p[i] += tr1.query(1, 1, a[i]);
		for(int j = 1; j * a[i] <= mx; ++j) {
			p[i] -= j*a[i] * tr2.query(j*a[i], Min(mx, (j+1)*a[i]-1));
			p[i] += tr3.query(j*a[i], Min(mx, (j+1)*a[i]-1));
		}
		for(int j = 1; j * a[i] + 1 <= mx; ++j) {
			tr1.modify(1, j*a[i]+1, Min(mx, (j+1)*a[i]-1), 1);
			if((j+1)*a[i] <= mx)tr1.modify(1, (j+1)*a[i], (j+1)*a[i], -(Min(mx, (j+1)*a[i]-1) - (j*a[i]+1) + 1));
		}
		if(a[i] == 1) continue ;
		tr2.modify(a[i], 1);
		tr3.modify(a[i], a[i]);
	}
	for(int i = 1; i <= n; ++i) p[i] += p[i-1], printf("%lld ", p[i]);
	return 0;
}
posted @ 2021-07-23 10:30  do_while_true  阅读(66)  评论(0编辑  收藏  举报