luogu P5072 [Ynoi2015] 盼君勿忘
https://www.luogu.com.cn/problem/P5072
先考虑没有p怎么做?
显然可以容斥,考虑用总的贡献减去不合法的贡献
总贡献就是这个区间里的所有不重复的数之和
×
2
l
e
n
\times 2^{len}
×2len
不合法的就是这个数没有被选进子序列里的个数
考虑莫队,显然容易实现,但是不同的数可能有很多个,时间复杂度会暴毙
考虑根号分治,出现次数
>
n
>\sqrt{n}
>n的单独拿出来维护,小于的维护出现次数为
x
x
x的不同的数字之和,然后就可以做了
现在考虑有p怎么做,主要是中间要
×
2
i
m
o
d
p
\times 2^i \mod p
×2imodp
同样可以采用根号预处理
把它理解为
2
n
进
制
数
2^{\sqrt{n}}进制数
2n进制数
2
i
=
2
a
×
n
×
2
b
2^i=2^{a\times\sqrt n}\times2^b
2i=2a×n×2b
于是乎总时间复杂度就是一个根号的
code:
#include<bits/stdc++.h>
#define ll long long
#define N 200050
using namespace std;
int n, m, bel[N], cnt[N], gs[N], mod, c[N], sz, blo, a[N];
ll sum[N], pw1[N], pw2[N], ans[N];
ll poww(int x) {
return 1ll * pw2[x / blo] * pw1[x % blo] % mod;
}
struct Q {
int l, r, p, id;
} q[N];
int cmp(Q x, Q y) {
if(bel[x.l] == bel[y.l]) return (bel[x.l] & 1)? x.r < y.r : x.r > y.r;
return x.l < y.l;
}
void add(int x) {
if(cnt[x] > blo) gs[x] ++;
else sum[gs[x]] -= x, sum[++ gs[x]] += x;
}
void del(int x) {
if(cnt[x] > blo) gs[x] --;
else sum[gs[x]] -= x, sum[-- gs[x]] += x;
}
int main() {
scanf("%d%d", &n, &m); blo = sqrt(n) + 1;
for(int i = 1; i <= n; i ++) {
scanf("%d", &a[i]);
cnt[a[i]] ++;
bel[i] = (i - 1) / blo + 1;
}
for(int i = 1; i <= 100000; i ++) if(cnt[i] > blo) c[++ sz] = i;
for(int i = 1; i <= m; i ++) scanf("%d%d%d", &q[i].l, &q[i].r, &q[i].p), q[i].id = i;
sort(q + 1, q + 1 + m, cmp);
int l = 1, r = 0;
for(int i = 1; i <= m; i ++) {
for(; q[i].l < l ;) add(a[-- l]);
for(; q[i].r > r ;) add(a[++ r]);
for(; q[i].l > l ;) del(a[l ++]);
for(; q[i].r < r ;) del(a[r --]);
mod = q[i].p;
pw1[0] = pw2[0] = 1;
for(int j = 1; j <= blo; j ++) pw1[j] = pw1[j - 1] * 2ll % mod;
for(int j = 1; j <= blo; j ++) pw2[j] = 1ll * pw2[j - 1] * pw1[blo] % mod;
ll len = r - l + 1, minus = 0, plus = 0;
for(int j = 1; j <= sz; j ++) {
int x = c[j];
minus = (minus + 1ll * x * poww(len - gs[x])) % mod;
plus += x;
}
for(int j = 1; j <= blo; j ++) {
minus = (minus + 1ll * sum[j] * poww(len - j)) % mod;
plus += sum[j];
}
// for(int j = 1; j <= blo; j ++) printf(" %lld ", pw1[j]); printf("\n");
// for(int j = 1; j <= blo; j ++) printf(" %lld ", pw2[j]); printf(" %d\n", blo);
// printf("%d %lld\n", mod, poww(5));
plus %= mod;
ans[q[i].id] = (plus * poww(len) % mod - minus + mod) % mod;
}
for(int i = 1; i <= m; i ++) printf("%lld\n", ans[i]);
return 0;
}