有关主席树的一些 trick
主席树做题记录。
主席树,即可持久化权值线段树。
P3248 [HNOI2016] 树
难爆了这题。题目中会多次把模板树的某个子树放到大树上的某个节点下,我们把这一整个子树看作一个大节点,把模板树、大树分别维护。
具体地,模板树上需要用倍增维护两点之间的距离,并求出 dfs 序。
大树上需要维护:
- 大树上每个大节点所含节点编号的上下界;
- 该大节点挂在大树的哪个节点下;
- 该大节点是模板树中以哪个节点为根的子树;
- 用倍增维护大节点间的距离。
另外我们需要维护大树上的几个操作,然后在大树上倍增求出答案。
- 大树上某个节点属于哪个大节点:可以根据大节点编号的上下界二分求得;
- 大树上某个节点对应模板树上的哪个节点:这实际上是求对应模板子树中编号第 \(k\) 小的节点,可以在 dfs 序上用主席树实现。
接下来深度剖析在大树上倍增求解答案的过程:
若 \(u,v\) 本来就在同一大节点中,直接在模板树上求距离即可。否则设 \(u\) 所在大节点深度较大,\(v\) 所在大节点深度较小。首先求出 \(u\) 跳到自己所在大节点的根的距离,然后将 \(u\) 所在的大节点倍增往上跳到与 \(v\) 所在大节点同深度(或恰好深一层),并累加经过的大节点间的距离。如果此时 \(u\) 所在的大节点直接挂在 \(v\) 所在的大节点中,则在该大节点内直接求挂载点与 \(v\) 的距离(再加上挂载边 1)即可;否则将 \(u\) 再跳到与 \(v\) 同深度,并将 \(v\) 跳到自己所在大节点的根上,然后 \(u,v\) 一起倍增往上跳,直到它们所在的大节点挂在同一个大节点(即 \(lca\) 大节点)中。最后加上这两个挂载点在模板树中的距离,以及两条挂载边(共 2)。
维护好细节就行了。
点击查看代码
#include <bits/stdc++.h>
// #define int long long
#define ll long long
#define ull unsigned long long
#define db double
#define ld long double
#define rep(i,l,r) for (int i = (int)(l); i <= (int)(r); ++ i )
#define rep1(i,l,r) for (int i = (int)(l); i >= (int)(r); -- i )
#define il inline
#define fst first
#define snd second
#define ptc putchar
#define Yes ptc('Y'),ptc('e'),ptc('s'),puts("")
#define No ptc('N'),ptc('o'),puts("")
#define YES ptc('Y'),ptc('E'),ptc('S'),puts("")
#define NO ptc('N'),ptc('O'),puts("")
#define pb emplace_back
#define sz(x) (int)(x.size())
#define all(x) x.begin(),x.end()
#define debug() puts("------------------")
using namespace std;
typedef pair<int,int> PII;
typedef pair<int,PII> PIII;
typedef pair<ll,ll> PLL;
namespace szhqwq {
// Reads one integer into x: skips non-digit characters (remembering any '-'
// seen while skipping), then accumulates decimal digits. Stops at EOF instead
// of spinning forever; x stays 0 if no digit was found.
template<class T> inline void read(T &x) {
    x = 0; T f = 1; int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) { if (ch == '-') f = -1; ch = getchar(); }
    while (ch >= '0' && ch <= '9') { x = (x << 3) + (x << 1) + (ch ^ 48); ch = getchar(); }
    x *= f;
}
// Convenience readers returning the parsed value.
inline int read() { int x; read(x); return x; }
inline long long readl() { long long x; read(x); return x; }
// Reads several integers in argument order.
template<class T,class... Args> inline void read(T &x,Args &...x_) { read(x); read(x_...); }
// Prints x in decimal. Negative values are emitted digit by digit without
// negating x, so the minimum value of a signed T no longer overflows.
template<class T> inline void print(T x) {
    if (x < 0) {
        putchar('-');
        if (x <= -10) print(-(x / 10));
        putchar('0' - x % 10);
        return;
    }
    if (x > 9) print(x / 10);
    putchar(x % 10 + '0');
}
// Prints x followed by the separator character ch.
template<class T,class T_> inline void write(T x,T_ ch) { print(x); putchar(ch); }
template<class T,class T_> inline void chmax(T &x,T_ y) { x = std::max(x,y); }
template<class T,class T_> inline void chmin(T &x,T_ y) { x = std::min(x,y); }
// a^b mod p by repeated squaring; a * a (mod p) must fit in T.
template<class T,class T_,class T__> inline T qmi(T a,T_ b,T__ p) {
    T res = 1;
    while (b) {
        if (b & 1) res = res * a % p;
        a = a * a % p; b >>= 1;
    }
    return res;
}
// Modular inverse via Fermat's little theorem (p assumed prime). Returns T
// instead of int so inverses above INT_MAX are not truncated.
template<class T> inline T getinv(T x,T p) { return qmi(x,p - 2,p); }
} using namespace szhqwq;
// Marker whose address was used by the (commented-out) memory probe in solve().
bool stpos;
// Array capacity, binary-lifting table width, and generic constants.
const int N = 1e5 + 10;
const int M = 20;
const int inf = 1e9;
const int mod = 998244353;
const ll inff = 1e18;
// n: template-tree size, m: number of copy operations, q: number of queries.
int n, m, q;
// Adjacency list of the template tree: head h, edge target e, next edge ne.
int h[N], e[N << 1], ne[N << 1], idx;
// cp[i] = {template node whose subtree is copied, big-tree label it hangs under}.
PLL cp[N];
struct kvsgt {
int rt[N],tot;
struct node { int l,r,val; } tr[N << 5];
il void pushup(int u) {
tr[u].val = tr[tr[u].l].val + tr[tr[u].r].val;
}
il void modify(int &u,int p,int l,int r,int x) {
u = ++ tot;
tr[u] = tr[p];
if (l == r) {
++ tr[u].val;
return ;
}
int mid = l + r >> 1;
if (x <= mid) modify(tr[u].l,tr[p].l,l,mid,x);
else modify(tr[u].r,tr[p].r,mid + 1,r,x);
pushup(u);
}
il int query(int u,int p,int l,int r,int k) {
if (l == r) return l;
int mid = l + r >> 1;
int cntl = tr[tr[u].l].val - tr[tr[p].l].val;
if (k <= cntl) return query(tr[u].l,tr[p].l,l,mid,k);
return query(tr[u].r,tr[p].r,mid + 1,r,k - cntl);
}
} T;
struct basic_tree {
int d[N],dfn[N],cnt,id[N],in[N],out[N],fa[N][M];
il void dfs(int u,int f) {
fa[u][0] = f;
rep(i,1,17) fa[u][i] = fa[fa[u][i - 1]][i - 1];
d[u] = d[f] + 1;
dfn[++ cnt] = u;
in[u] = cnt;
for (int i = h[u]; ~i; i = ne[i]) {
int j = e[i];
if (j == f) continue;
dfs(j,u);
}
out[u] = cnt;
}
il int lca(int x,int y) {
if (d[x] < d[y]) swap(x,y);
rep1(i,17,0) if (d[fa[x][i]] >= d[y]) x = fa[x][i];
if (x == y) return x;
rep1(i,17,0) if (fa[x][i] != fa[y][i]) x = fa[x][i],y = fa[y][i];
return fa[x][0];
}
il int getdis(int x,int y) {
return d[x] + d[y] - 2 * d[lca(x,y)];
}
} BT;
// The "big tree": each copy operation appends one big node that stands for a
// whole copied template subtree; big node 1 is the initial template tree.
struct big_tree {
    // u: number of big nodes built so far; d[i]: depth of big node i (big node 1 has depth 1).
    // [L[i], R[i]]: contiguous range of big-tree labels owned by big node i.
    // gua[i]: big-tree label that big node i is hung under.
    // zi[i]: template node whose subtree big node i copies.
    // fa[i][j]: 2^j-th ancestor big node of i (0 once past big node 1).
    // dist[i][j]: distance from the root of big node i to the root of fa[i][j].
    int u,d[N]; ll L[N],R[N],gua[N],zi[N],fa[N][M],dist[N][M];
    // Binary search for the big node containing big-tree label x: the largest i
    // with L[i] <= x (L grows with i). Returns 0 if x < L[1].
    il int getbig(ll x) {
        int l = 1,r = u,res = 0;
        while (l <= r) {
            int mid = l + r >> 1;
            if (L[mid] > x) r = mid - 1;
            else l = mid + 1,res = mid;
        }
        return res;
    }
    // Maps big-tree label x to the template node it copies: the (x - L + 1)-th
    // smallest template label inside subtree zi of its big node, found with the
    // persistent tree over the DFS range [in, out] of that subtree.
    il int getkth(ll x) {
        int xxx = getbig(x);
        int k = T.query(T.rt[BT.out[zi[xxx]]],T.rt[BT.in[zi[xxx]] - 1],1,n,x - L[xxx] + 1);
        return k;
    }
    // Builds big node 1 (labels 1..n, the whole template tree) and then one big
    // node per copy operation cp[i], allocating labels consecutively after n.
    il void init() {
        L[++ u] = 1; R[u] = n; zi[u] = 1; d[u] = 1;
        // cerr << " " << L[u] << " " << R[u] << endl;
        ll now = n + 1; // next unused big-tree label
        rep(i,1,m) {
            int x = cp[i].fst; ll to = cp[i].snd; int f = getbig(to);
            // New big node owns as many labels as the template subtree of x has nodes.
            L[++ u] = now; R[u] = now + BT.out[x] - BT.in[x]; now = R[u] + 1ll;
            gua[u] = to; zi[u] = x; fa[u][0] = f;
            // Root of u -> label `to` is one edge, then `to` lies this deep below the root of f.
            dist[u][0] = 1ll + BT.d[getkth(to)] - BT.d[zi[f]];
            // cerr << u << " " << f << endl;
            d[u] = d[f] + 1;
            rep(j,1,17) {
                fa[u][j] = fa[fa[u][j - 1]][j - 1];
                dist[u][j] = dist[fa[u][j - 1]][j - 1] + dist[u][j - 1];
            }
        }
        // debug();
        // rep(i,1,m + 1) {
        // rep(j,0,3) cout << fa[i][j] << " ";
        // cout << endl;
        // }
        // debug();
    }
    // Distance between big-tree labels x and y.
    il ll getans(ll x,ll y) {
        ll res = 0;
        int xx = getbig(x),yy = getbig(y);
        // cout << "aaa " << x << " " << y << " " << xx << " " << yy << endl;
        // Same big node: distance between the corresponding template nodes.
        if (xx == yy) return BT.getdis(getkth(x),getkth(y));
        // Make xx the deeper big node (x follows it).
        if (d[xx] < d[yy]) swap(x,y),swap(xx,yy);
        res += BT.d[getkth(x)] - BT.d[zi[xx]]; // lift x to the root of its big node
        // cout << "!!! " << res << " " << getkth(x) << " " << zi[xx] << endl;
        rep1(i,17,0) // binary-lift the deeper big node while its ancestor stays strictly deeper than yy
            if (d[fa[xx][i]] > d[yy]) {
                res += dist[xx][i];
                // cout << "??? " << i << " " << d[fa[xx][i]] << " " << fa[xx][i] << " " << dist[xx][i] << endl;
                xx = fa[xx][i];
            }
        if (getbig(gua[xx]) == yy) {
            res += 1ll + 1ll * BT.getdis(getkth(gua[xx]),getkth(y)); // special case: xx hangs directly inside yy; add the hanging edge plus the distance inside yy
            return res;
        }
        // Otherwise bring xx up to the depth of yy.
        if (d[xx] > d[yy]) res += dist[xx][0],xx = fa[xx][0];
        // cout << "??? " << res << endl;
        res += BT.d[getkth(y)] - BT.d[zi[yy]]; // lift y to the root of its big node
        // if (xx == yy) return res; // distance inside one big node: x is already at the root here, so it is just y's distance to the root
        rep1(i,17,0) // lift both big nodes together while their ancestors differ
            if (fa[xx][i] != fa[yy][i]) {
                res += dist[xx][i] + dist[yy][i];
                xx = fa[xx][i]; yy = fa[yy][i];
            }
        // xx and yy now hang under the same (LCA) big node: add the two hanging
        // edges plus the template distance between their hang points.
        res += 2ll + 1ll * BT.getdis(getkth(gua[xx]),getkth(gua[yy]));
        return res;
    }
} BGT;
// Prepends the directed edge a -> b to a's adjacency list.
il void add(int a,int b) {
    const int slot = idx++;
    e[slot] = b;
    ne[slot] = h[a];
    h[a] = slot;
}
// Reads the template tree, the copy operations and the queries, and prints
// the big-tree distance for every query pair.
il void solve() {
    memset(h, -1, sizeof h);
    read(n, m, q);
    for (int i = 1; i < n; ++i) {
        int a, b;
        read(a, b);
        add(a, b);
        add(b, a);
    }
    for (int i = 1; i <= m; ++i) {
        const ll from = readl();
        const ll to = readl();
        cp[i] = {from, to};
    }
    BT.dfs(1, 0);
    // Version i of the persistent tree holds the first i template nodes in DFS order.
    for (int i = 1; i <= n; ++i) T.modify(T.rt[i], T.rt[i - 1], 1, n, BT.dfn[i]);
    BGT.init();
    for (; q--; ) {
        const ll x = readl();
        const ll y = readl();
        write(BGT.getans(x, y), '\n'); // distance is symmetric in x and y
    }
}
// One-time initialisation hook; nothing is needed for this problem.
il void init() {
    return ;
}
// Runs a single test case (reading a test count is intentionally disabled).
signed main() {
    int cases = 1;
    // read(cases);
    for (; cases > 0; --cases) solve();
    return 0;
}
P3293 [SCOI2016] 美味
好题。
每次的 \(b,x\) 是固定的,相当于要找到一个最优的 \(a_i + x\) 使其与 \(b\) 的异或值最大。我们贪心地从高位到低位枚举,并记录当前已确定的最优 \(a_i + x\) 的高位部分 \(now\)。对于 \(b\) 的第 \(i\) 位:如果这一位为 1,则只要存在 \(a_i + x \in [now, now + 2^i - 1]\)(即 \(a_i + x\) 这一位为 0),\(b \oplus (a_i + x)\) 在这一位就为 1,即需要 \(a_i \in [now - x, now - x + 2^i - 1]\);若不存在,则 \(a_i + x\) 这一位只能取 1。
如果 \(b\) 的这一位为 0,同理需要 \(a_i + x\) 这一位为 1,即 \(a_i \in [now - x + 2^i, now - x + 2^{i + 1} - 1]\)。
主席树每次看有没有这个区间的满足条件的 \(a_i\) 即可。
点击查看代码
#include <bits/stdc++.h>
// #define int long long
#define ll long long
#define ull unsigned long long
#define db double
#define ld long double
#define rep(i,l,r) for (int i = (int)(l); i <= (int)(r); ++ i )
#define rep1(i,l,r) for (int i = (int)(l); i >= (int)(r); -- i )
#define il inline
#define fst first
#define snd second
#define ptc putchar
#define Yes ptc('Y'),ptc('e'),ptc('s'),puts("")
#define No ptc('N'),ptc('o'),puts("")
#define YES ptc('Y'),ptc('E'),ptc('S'),puts("")
#define NO ptc('N'),ptc('O'),puts("")
#define pb emplace_back
#define sz(x) (int)(x.size())
#define all(x) x.begin(),x.end()
#define get(x) ((x - 1) / len + 1)
#define debug() puts("------------------")
using namespace std;
typedef pair<int,int> PII;
typedef pair<int,PII> PIII;
typedef pair<ll,ll> PLL;
namespace szhqwq {
// Reads one integer into x: skips non-digit characters (remembering any '-'
// seen while skipping), then accumulates decimal digits. Stops at EOF instead
// of spinning forever; x stays 0 if no digit was found.
template<class T> inline void read(T &x) {
    x = 0; T f = 1; int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) { if (ch == '-') f = -1; ch = getchar(); }
    while (ch >= '0' && ch <= '9') { x = (x << 3) + (x << 1) + (ch ^ 48); ch = getchar(); }
    x *= f;
}
// Convenience reader returning the parsed value.
inline int read() { int x; read(x); return x; }
// Reads several integers in argument order.
template<class T,class... Args> inline void read(T &x,Args &...x_) { read(x); read(x_...); }
// Prints x in decimal. Negative values are emitted digit by digit without
// negating x, so the minimum value of a signed T no longer overflows.
template<class T> inline void print(T x) {
    if (x < 0) {
        putchar('-');
        if (x <= -10) print(-(x / 10));
        putchar('0' - x % 10);
        return;
    }
    if (x > 9) print(x / 10);
    putchar(x % 10 + '0');
}
// Prints x followed by the separator character ch.
template<class T,class T_> inline void write(T x,T_ ch) { print(x); putchar(ch); }
template<class T,class T_> inline void chmax(T &x,T_ y) { x = std::max(x,y); }
template<class T,class T_> inline void chmin(T &x,T_ y) { x = std::min(x,y); }
// a^b mod p by repeated squaring; a * a (mod p) must fit in T.
template<class T,class T_,class T__> inline T qmi(T a,T_ b,T__ p) {
    T res = 1;
    while (b) {
        if (b & 1) res = res * a % p;
        a = a * a % p; b >>= 1;
    }
    return res;
}
// Modular inverse via Fermat's little theorem (p assumed prime). Returns T
// instead of int so inverses above INT_MAX are not truncated.
template<class T> inline T getinv(T x,T p) { return qmi(x,p - 2,p); }
} using namespace szhqwq;
// Array capacity and generic constants.
const int N = 2e5 + 10,inf = 1e9,mod = 998244353;
const long long inff = 1e18;
// n values, m queries; rt[i] is the persistent-tree version holding a[1..i];
// values are stored on the range [0, VAL].
int n,m,a[N],rt[N],tot,VAL = 300000;
// val: how many inserted values fall in this node's range.
struct node { int l,r,val; } tr[N << 5];
// Recompute node u's count from its children.
inline void pushup(int u) {
    tr[u].val = tr[tr[u].l].val + tr[tr[u].r].val;
}
// Create version u = version p plus one occurrence of value x.
inline void modify(int &u,int p,int l,int r,int x) {
    u = ++ tot; tr[u] = tr[p];
    if (l == r) {
        ++ tr[u].val;
        return ;
    }
    int mid = (l + r) >> 1;
    if (x <= mid) modify(tr[u].l,tr[p].l,l,mid,x);
    else modify(tr[u].r,tr[p].r,mid + 1,r,x);
    pushup(u);
}
// Number of values in [L, R] present in version u but not in version p.
// An empty interval (L > R) or one disjoint from [l, r] yields 0. Without this
// guard, an interval with R < L = 0 (produced by solve() when
// ret - x + 2^i - 1 < 0) kept descending into the leftmost leaf forever.
inline int query(int u,int p,int l,int r,int L,int R) {
    if (L > R || R < l || r < L) return 0;
    if (L <= l && r <= R) return tr[u].val - tr[p].val;
    int ret = 0;
    int mid = (l + r) >> 1;
    if (L <= mid) ret += query(tr[u].l,tr[p].l,l,mid,L,R);
    if (mid < R) ret += query(tr[u].r,tr[p].r,mid + 1,r,L,R);
    return ret;
}
// For each query (b, x, l, r), greedily fixes the bits of the best a_i + x
// (i in [l, r]) from high to low so that b xor (a_i + x) is maximised.
il void solve() {
    read(n, m);
    for (int i = 1; i <= n; ++i) {
        read(a[i]);
        modify(rt[i], rt[i - 1], 0, VAL, a[i]);
    }
    while (m--) {
        int b, x, l, r;
        read(b, x, l, r);
        int best = 0; // high bits of the chosen a_i + x decided so far
        for (int bit = 17; bit >= 0; --bit) {
            const int step = 1 << bit;
            if ((b >> bit) & 1) {
                // Prefer this bit of a_i + x to be 0.
                const int hits = query(rt[r], rt[l - 1], 0, VAL, max(0, best - x), min(VAL, best - x + step - 1));
                if (hits == 0) best += step;
            } else {
                // Prefer this bit of a_i + x to be 1.
                const int hits = query(rt[r], rt[l - 1], 0, VAL, max(0, best - x + step), min(VAL, best - x + 2 * step - 1));
                if (hits != 0) best += step;
            }
        }
        write(best ^ b, '\n');
    }
}
// One-time initialisation hook; nothing is needed for this problem.
il void init() {
    return ;
}
// Runs a single test case (reading a test count is intentionally disabled).
signed main() {
    int cases = 1;
    // read(cases);
    for (; cases > 0; --cases) solve();
    return 0;
}
P3168 [CQOI2015] 任务查询系统
板子。
以时间为版本(每个时刻对应一个根),把优先级放到主席树上查询即可。任务的开始与结束用差分处理:在开始时刻 \(s\) 插入,在 \(e + 1\) 时刻删除。
点击查看代码
#include <bits/stdc++.h>
// #define int long long
#define ll long long
#define ull unsigned long long
#define db double
#define ld long double
#define rep(i,l,r) for (int i = (int)(l); i <= (int)(r); ++ i )
#define rep1(i,l,r) for (int i = (int)(l); i >= (int)(r); -- i )
#define il inline
#define fst first
#define snd second
#define ptc putchar
#define Yes ptc('Y'),ptc('e'),ptc('s'),puts("")
#define No ptc('N'),ptc('o'),puts("")
#define YES ptc('Y'),ptc('E'),ptc('S'),puts("")
#define NO ptc('N'),ptc('O'),puts("")
#define pb emplace_back
#define sz(x) (int)(x.size())
#define all(x) x.begin(),x.end()
#define get(x) ((x - 1) / len + 1)
#define debug() puts("------------------")
using namespace std;
typedef pair<int,int> PII;
typedef pair<int,PII> PIII;
typedef pair<ll,ll> PLL;
namespace szhqwq {
// Reads one integer into x: skips non-digit characters (remembering any '-'
// seen while skipping), then accumulates decimal digits. Stops at EOF instead
// of spinning forever; x stays 0 if no digit was found.
template<class T> inline void read(T &x) {
    x = 0; T f = 1; int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) { if (ch == '-') f = -1; ch = getchar(); }
    while (ch >= '0' && ch <= '9') { x = (x << 3) + (x << 1) + (ch ^ 48); ch = getchar(); }
    x *= f;
}
// Convenience reader returning the parsed value.
inline int read() { int x; read(x); return x; }
// Reads several integers in argument order.
template<class T,class... Args> inline void read(T &x,Args &...x_) { read(x); read(x_...); }
// Prints x in decimal. Negative values are emitted digit by digit without
// negating x, so the minimum value of a signed T no longer overflows.
template<class T> inline void print(T x) {
    if (x < 0) {
        putchar('-');
        if (x <= -10) print(-(x / 10));
        putchar('0' - x % 10);
        return;
    }
    if (x > 9) print(x / 10);
    putchar(x % 10 + '0');
}
// Prints x followed by the separator character ch.
template<class T,class T_> inline void write(T x,T_ ch) { print(x); putchar(ch); }
template<class T,class T_> inline void chmax(T &x,T_ y) { x = std::max(x,y); }
template<class T,class T_> inline void chmin(T &x,T_ y) { x = std::min(x,y); }
// a^b mod p by repeated squaring; a * a (mod p) must fit in T.
template<class T,class T_,class T__> inline T qmi(T a,T_ b,T__ p) {
    T res = 1;
    while (b) {
        if (b & 1) res = res * a % p;
        a = a * a % p; b >>= 1;
    }
    return res;
}
// Modular inverse via Fermat's little theorem (p assumed prime). Returns T
// instead of int so inverses above INT_MAX are not truncated.
template<class T> inline T getinv(T x,T p) { return qmi(x,p - 2,p); }
} using namespace szhqwq;
// Array capacity and generic constants.
const int N = 1e5 + 10, inf = 1e9, mod = 998244353;
const long long inff = 1e18;
// m tasks, n time points; rt[t] is the persistent-tree version for time t,
// with priorities stored on the value range [0, VAL].
int n, m, rt[N], tot, VAL = 2e7;
// c: number of stored priorities in this range; val: their sum.
struct node { int l, r, c; long long val; } tr[N << 6];
// v collects priorities; upd[t] lists (priority, +1/-1) events applied at time t.
std::vector<int> v; std::vector<std::pair<int, int>> upd[N];
// Recompute count and sum of node u from its children.
inline void pushup(int u) {
    const int lc = tr[u].l, rc = tr[u].r;
    tr[u].c = tr[lc].c + tr[rc].c;
    tr[u].val = tr[lc].val + tr[rc].val;
}
// Version u = version p with k (possibly negative) copies of value x added.
inline void modify(int &u, int p, int l, int r, int x, int k) {
    u = ++tot;
    tr[u] = tr[p];
    if (l < r) {
        const int mid = (l + r) >> 1;
        if (x > mid) modify(tr[u].r, tr[p].r, mid + 1, r, x, k);
        else modify(tr[u].l, tr[p].l, l, mid, x, k);
        pushup(u);
        return;
    }
    tr[u].c += k;
    tr[u].val += 1ll * x * k;
}
// Sum of the k smallest stored values of version u (caller ensures k <= tr[u].c).
inline long long query(int u, int l, int r, long long k) {
    long long acc = 0;
    while (l != r) {
        const int mid = (l + r) >> 1;
        const int lc = tr[u].l;
        if (k <= tr[lc].c) {
            u = lc;
            r = mid;
        } else {
            acc += tr[lc].val;
            k -= tr[lc].c;
            u = tr[u].r;
            l = mid + 1;
        }
    }
    return acc + tr[u].val / tr[u].c * k;
}
// Builds one persistent-tree version per time point: a task with priority p
// running on [s, e] is inserted at time s and removed at time e + 1.
// Then answers the n online queries: the sum of the k smallest priorities among
// tasks running at time x, where k is decoded using the previous answer.
// Priorities are stored directly on [0, VAL], so no coordinate compression
// (the old, unused collection/sort of v) is performed.
il void solve() {
    read(m, n);
    for (int i = 1; i <= m; ++i) {
        int s, e, p;
        read(s, e, p);
        upd[s].pb(p, 1);
        upd[e + 1].pb(p, -1);
    }
    for (int i = 1; i <= n; ++i) {
        rt[i] = rt[i - 1];
        for (auto ev : upd[i]) modify(rt[i], rt[i], 0, VAL, ev.fst, ev.snd);
    }
    ll pre = 1;
    for (int i = 1; i <= n; ++i) {
        int x, a, b, c;
        read(x, a, b, c);
        // k = 1 + (a * pre + b) mod c. pre is a sum of priorities, so a * pre
        // can overflow long long; reducing both factors modulo c first keeps the
        // product below c^2 and gives the same result (inputs assumed non-negative).
        const ll k = 1ll + (1ll * (a % c) * (pre % c) + b) % c;
        if (k > tr[rt[x]].c) pre = tr[rt[x]].val;
        else pre = query(rt[x], 0, VAL, k);
        write(pre, '\n');
    }
}
// One-time initialisation hook; nothing is needed for this problem.
il void init() {
    return ;
}
// Runs a single test case (reading a test count is intentionally disabled).
signed main() {
    int cases = 1;
    // read(cases);
    for (; cases > 0; --cases) solve();
    return 0;
}
CF813E Army Creation
妙妙题。一个好的 trick。
出现次数不好直接限制,但我们可以把每个位置上的数往前数第 \(k\) 次出现的位置(不包括当前这个位置)作为权值放到主席树上,不足 \(k\) 次的放 0。这样,区间 \([l, r]\) 中一个位置会被计入答案,当且仅当它的权值不超过 \(l - 1\)。因此每次询问直接统计主席树上区间 \([l, r]\) 中权值落在 \([0, l - 1]\) 的位置有多少个即可。
点击查看代码
#include <bits/stdc++.h>
// #define int long long
#define ll long long
#define ull unsigned long long
#define db double
#define ld long double
#define rep(i,l,r) for (int i = (int)(l); i <= (int)(r); ++ i )
#define rep1(i,l,r) for (int i = (int)(l); i >= (int)(r); -- i )
#define il inline
#define fst first
#define snd second
#define ptc putchar
#define Yes ptc('Y'),ptc('e'),ptc('s'),puts("")
#define No ptc('N'),ptc('o'),puts("")
#define YES ptc('Y'),ptc('E'),ptc('S'),puts("")
#define NO ptc('N'),ptc('O'),puts("")
#define pb emplace_back
#define sz(x) (int)(x.size())
#define all(x) x.begin(),x.end()
#define get(x) ((x - 1) / len + 1)
#define debug() puts("------------------")
using namespace std;
typedef pair<int,int> PII;
typedef pair<int,PII> PIII;
typedef pair<ll,ll> PLL;
namespace szhqwq {
// Reads one integer into x: skips non-digit characters (remembering any '-'
// seen while skipping), then accumulates decimal digits. Stops at EOF instead
// of spinning forever; x stays 0 if no digit was found.
template<class T> inline void read(T &x) {
    x = 0; T f = 1; int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) { if (ch == '-') f = -1; ch = getchar(); }
    while (ch >= '0' && ch <= '9') { x = (x << 3) + (x << 1) + (ch ^ 48); ch = getchar(); }
    x *= f;
}
// Convenience reader returning the parsed value.
inline int read() { int x; read(x); return x; }
// Reads several integers in argument order.
template<class T,class... Args> inline void read(T &x,Args &...x_) { read(x); read(x_...); }
// Prints x in decimal. Negative values are emitted digit by digit without
// negating x, so the minimum value of a signed T no longer overflows.
template<class T> inline void print(T x) {
    if (x < 0) {
        putchar('-');
        if (x <= -10) print(-(x / 10));
        putchar('0' - x % 10);
        return;
    }
    if (x > 9) print(x / 10);
    putchar(x % 10 + '0');
}
// Prints x followed by the separator character ch.
template<class T,class T_> inline void write(T x,T_ ch) { print(x); putchar(ch); }
template<class T,class T_> inline void chmax(T &x,T_ y) { x = std::max(x,y); }
template<class T,class T_> inline void chmin(T &x,T_ y) { x = std::min(x,y); }
// a^b mod p by repeated squaring; a * a (mod p) must fit in T.
template<class T,class T_,class T__> inline T qmi(T a,T_ b,T__ p) {
    T res = 1;
    while (b) {
        if (b & 1) res = res * a % p;
        a = a * a % p; b >>= 1;
    }
    return res;
}
// Modular inverse via Fermat's little theorem (p assumed prime). Returns T
// instead of int so inverses above INT_MAX are not truncated.
template<class T> inline T getinv(T x,T p) { return qmi(x,p - 2,p); }
} using namespace szhqwq;
// Array capacity and generic constants.
const int N = 1e5 + 10, inf = 1e9, mod = 998244353;
const long long inff = 1e18;
// n: array length, k: per-value cap, q: number of queries;
// rt[i]: persistent-tree version after inserting the keys of positions 1..i,
// keys stored on the range [0, VAL].
int n, k, q, a[N], rt[N], tot, VAL = 100000;
struct node { int l, r, val; } tr[N << 5];
// v[x]: positions (in increasing order) where value x has appeared so far.
std::vector<int> v[N];
// Recompute node u's count from its children.
inline void pushup(int u) {
    const node &lc = tr[tr[u].l];
    const node &rc = tr[tr[u].r];
    tr[u].val = lc.val + rc.val;
}
// Version u = version p with one more occurrence of key x.
inline void modify(int &u, int p, int l, int r, int x) {
    u = ++tot;
    tr[u] = tr[p];
    if (l == r) {
        tr[u].val += 1;
        return;
    }
    const int mid = (l + r) >> 1;
    if (x > mid) modify(tr[u].r, tr[p].r, mid + 1, r, x);
    else modify(tr[u].l, tr[p].l, l, mid, x);
    pushup(u);
}
// Number of keys in [L, R] present in version u but not in version p.
inline int query(int u, int p, int l, int r, int L, int R) {
    if (L <= l && r <= R) return tr[u].val - tr[p].val;
    const int mid = (l + r) >> 1;
    const int fromLeft = L <= mid ? query(tr[u].l, tr[p].l, l, mid, L, R) : 0;
    const int fromRight = R > mid ? query(tr[u].r, tr[p].r, mid + 1, r, L, R) : 0;
    return fromLeft + fromRight;
}
// Each position stores, as its key, the position of the occurrence of the same
// value k occurrences earlier (0 if there are fewer than k earlier ones).
// Query [l, r] counts positions in [l, r] whose key is at most l - 1;
// queries are decoded online with the previous answer.
il void solve() {
    read(n, k);
    for (int i = 1; i <= n; ++i) {
        read(a[i]);
        vector<int> &occ = v[a[i]];
        occ.pb(i);
        const int cnt = (int)occ.size();
        const int key = cnt <= k ? 0 : occ[cnt - k - 1];
        modify(rt[i], rt[i - 1], 0, VAL, key);
    }
    read(q);
    int last = 0;
    while (q--) {
        int x, y;
        read(x, y);
        int l = (x + last) % n + 1;
        int r = (y + last) % n + 1;
        if (l > r) swap(l, r);
        last = query(rt[r], rt[l - 1], 0, VAL, 0, l - 1);
        write(last, '\n');
    }
}
// One-time initialisation hook; nothing is needed for this problem.
il void init() {
    return ;
}
// Runs a single test case (reading a test count is intentionally disabled).
signed main() {
    int cases = 1;
    // read(cases);
    for (; cases > 0; --cases) solve();
    return 0;
}
P4755 Beautiful Pair
妙妙题。
考虑转化题目给定的问题,可以固定最大值之后再去找不超过最大值的 \(a_i \times a_j\)。
考虑单调栈,对于 \(a_i\) 求出向左第一个 \(\geq a_i\) 的位置 \(le_i\) 以及向右第一个 \(> a_i\) 的位置 \(ri_i\),注意此处一方取等一方不取的原因是为了不重复计算,取等的方向反过来也是可行的。
考虑去枚举从 \(i\) 向左的区间和向右的区间中长度更小的那一侧,对其中的每个 \(a_j\),只需要在另一侧统计满足 \(a_k \leq \lfloor \frac{a_i}{a_j} \rfloor\) 的 \(k\) 的个数即可(用主席树查询)。由于每次只枚举较短的一侧(类似启发式合并),每个位置总共只会被枚举 \(\mathcal{O}(\log n)\) 次,再乘上主席树查询的 \(\log\),时间复杂度为 \(\mathcal{O}(n \log^2 n)\)。
点击查看代码
#include <bits/stdc++.h>
// #define int long long
#define ll long long
#define ull unsigned long long
#define db double
#define ld long double
#define rep(i,l,r) for (int i = (int)(l); i <= (int)(r); ++ i )
#define rep1(i,l,r) for (int i = (int)(l); i >= (int)(r); -- i )
#define il inline
#define fst first
#define snd second
#define ptc putchar
#define Yes ptc('Y'),ptc('e'),ptc('s'),puts("")
#define No ptc('N'),ptc('o'),puts("")
#define YES ptc('Y'),ptc('E'),ptc('S'),puts("")
#define NO ptc('N'),ptc('O'),puts("")
#define pb emplace_back
#define sz(x) (int)(x.size())
#define all(x) x.begin(),x.end()
#define get(x) ((x - 1) / len + 1)
#define debug() puts("------------------")
using namespace std;
typedef pair<int,int> PII;
typedef pair<int,PII> PIII;
typedef pair<ll,ll> PLL;
namespace szhqwq {
// Reads one integer into x: skips non-digit characters (remembering any '-'
// seen while skipping), then accumulates decimal digits. Stops at EOF instead
// of spinning forever; x stays 0 if no digit was found.
template<class T> inline void read(T &x) {
    x = 0; T f = 1; int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) { if (ch == '-') f = -1; ch = getchar(); }
    while (ch >= '0' && ch <= '9') { x = (x << 3) + (x << 1) + (ch ^ 48); ch = getchar(); }
    x *= f;
}
// Convenience reader returning the parsed value.
inline int read() { int x; read(x); return x; }
// Reads several integers in argument order.
template<class T,class... Args> inline void read(T &x,Args &...x_) { read(x); read(x_...); }
// Prints x in decimal. Negative values are emitted digit by digit without
// negating x, so the minimum value of a signed T no longer overflows.
template<class T> inline void print(T x) {
    if (x < 0) {
        putchar('-');
        if (x <= -10) print(-(x / 10));
        putchar('0' - x % 10);
        return;
    }
    if (x > 9) print(x / 10);
    putchar(x % 10 + '0');
}
// Prints x followed by the separator character ch.
template<class T,class T_> inline void write(T x,T_ ch) { print(x); putchar(ch); }
template<class T,class T_> inline void chmax(T &x,T_ y) { x = std::max(x,y); }
template<class T,class T_> inline void chmin(T &x,T_ y) { x = std::min(x,y); }
// a^b mod p by repeated squaring; a * a (mod p) must fit in T.
template<class T,class T_,class T__> inline T qmi(T a,T_ b,T__ p) {
    T res = 1;
    while (b) {
        if (b & 1) res = res * a % p;
        a = a * a % p; b >>= 1;
    }
    return res;
}
// Modular inverse via Fermat's little theorem (p assumed prime). Returns T
// instead of int so inverses above INT_MAX are not truncated.
template<class T> inline T getinv(T x,T p) { return qmi(x,p - 2,p); }
} using namespace szhqwq;
// Array capacity and generic constants.
const int N = 1e5 + 10,inf = 1e9,mod = 998244353;
const long long inff = 1e18;
// Marker whose address was used by the (commented-out) memory probe in solve().
bool sta;
// rt[i]: persistent value tree over a[1..i]; values live in [1, VAL].
int n,a[N],VAL = 2e9,rt[N],tot;
// val: number of inserted values falling in this node's range.
struct node { int l,r; long long val; } tr[N << 6];
int le[N],ri[N];   // nearest >= to the left / nearest > to the right of each position
int st[N],top;     // monotonic stack
// Version u = version p with one more occurrence of value x.
// The midpoint is l + (r - l) / 2: with VAL = 2e9 the sum l + r exceeds
// INT_MAX in the right half of the tree (signed overflow, undefined behaviour).
inline void modify(int &u,int p,int l,int r,int x) {
    u = ++ tot; tr[u] = tr[p]; ++ tr[u].val;
    if (l == r) return ;
    int mid = l + (r - l) / 2;
    if (x <= mid) modify(tr[u].l,tr[p].l,l,mid,x);
    else modify(tr[u].r,tr[p].r,mid + 1,r,x);
}
// Number of values <= x present in version u but not in version p.
inline long long query(int u,int p,int l,int r,int x) {
    if (l == r) return tr[u].val - tr[p].val;
    int mid = l + (r - l) / 2; // overflow-free midpoint, see modify()
    if (x <= mid) return query(tr[u].l,tr[p].l,l,mid,x);
    return tr[tr[u].l].val - tr[tr[p].l].val + query(tr[u].r,tr[p].r,mid + 1,r,x);
}
// Counts pairs (j, k), j <= k, with a[j] * a[k] <= max(a[j..k]).
// For every i taken as the interval maximum, enumerates the shorter side of
// (le[i], ri[i]) and counts partners on the other side via the persistent tree.
// (The unused, truncating `int lst = ret;` and unused `bool ed;` were dropped.)
il void solve() {
    read(n);
    for (int i = 1; i <= n; ++i) read(a[i]), modify(rt[i], rt[i - 1], 1, VAL, a[i]);
    // ri[i]: first position right of i with a value > a[i] (n + 1 if none).
    for (int i = 1; i <= n; ++i) {
        while (top && a[st[top]] < a[i]) ri[st[top --]] = i;
        st[++ top] = i; ri[i] = n + 1;
    }
    // le[i]: first position left of i with a value >= a[i] (stays 0 if none).
    top = 0;
    for (int i = n; i >= 1; --i) {
        while (top && a[st[top]] <= a[i]) le[st[top --]] = i;
        st[++ top] = i;
    }
    ll ret = 0;
    for (int i = 1; i <= n; ++i) {
        // Partner condition a[j] * a[k] <= a[i] is equivalent to a[k] <= a[i] / a[j].
        if (i - le[i] < ri[i] - i) {
            for (int j = le[i] + 1; j <= i; ++j)
                ret += query(rt[ri[i] - 1], rt[i - 1], 1, VAL, a[i] / a[j]);
        } else {
            for (int j = i; j <= ri[i] - 1; ++j)
                ret += query(rt[i], rt[le[i]], 1, VAL, a[i] / a[j]);
        }
    }
    write(ret, '\n');
}
// One-time initialisation hook; nothing is needed for this problem.
il void init() {
    return ;
}
// Runs a single test case (reading a test count is intentionally disabled).
signed main() {
    int cases = 1;
    // read(cases);
    for (; cases > 0; --cases) solve();
    return 0;
}