有关主席树的一些 trick

主席树做题记录。

主席树,即可持久化权值线段树。

P3248 [HNOI2016] 树

难爆了这题。题目中会多次把模板树的某个子树放到大树上的某个节点下,我们把这一整个子树看作一个大节点,把模板树、大树分别维护。

具体地,模板树上需要用倍增维护 LCA(用来求两点之间的距离),以及 dfs 序。

大树上需要维护:

  1. 大树上大节点的节点编号的上下界;

  2. 大节点是挂在大树的哪个节点下的;

  3. 该大节点是模板树中以哪个节点为根的子树;

  4. 倍增维护大节点间的距离。

另外我们需要维护大树上的几个操作,然后在大树上倍增求出答案。

  1. 大树上某个节点属于哪个大节点,可以根据大节点编号上下界二分求得;

  2. 大树上某个节点对应的是模板树上的哪个节点:大节点内部的编号是按被复制子树中原编号从小到大依次分配的,所以这实际上是在该子树对应的 dfs 序区间中求编号第 \(k\) 小的节点,可以用主席树实现;

接下来深度剖析在大树上倍增求解答案的过程:

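若 \(u,v\) 在同一个大节点中,直接在模板树上求两者对应节点的距离即可。否则设 \(u\) 所在大节点的深度较大,\(v\) 所在大节点的深度较小。首先求出 \(u\) 跳到自己所在大节点的根的距离,然后让 \(u\) 所在的大节点向上倍增,直到它的深度至多比 \(v\) 所在大节点大 1,同时累加沿途大节点根与根之间的距离。如果此时 \(u\) 所在的大节点恰好挂在 \(v\) 所在的大节点下,则答案再加上 1(连向挂载点的那条边)以及挂载点与 \(v\) 在该大节点内(即模板树上)的距离;否则再让 \(u\) 跳一步使两者同深度,并把 \(v\) 跳到自己所在大节点的根上,然后 \(u,v\) 一起倍增往上跳,直到它们的父大节点相同(即到达 \(lca\) 所在的大节点)。最后加上 2(两条连向挂载点的边)以及这两个挂载点在 \(lca\) 大节点内的距离。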

维护好细节就行了。

#include <bits/stdc++.h>

// #define int long long
#define ll long long
#define ull unsigned long long
#define db double
#define ld long double
#define rep(i,l,r) for (int i = (int)(l); i <= (int)(r); ++ i )
#define rep1(i,l,r) for (int i = (int)(l); i >= (int)(r); -- i )
#define il inline
#define fst first
#define snd second
#define ptc putchar
#define Yes ptc('Y'),ptc('e'),ptc('s'),puts("")
#define No ptc('N'),ptc('o'),puts("")
#define YES ptc('Y'),ptc('E'),ptc('S'),puts("")
#define NO ptc('N'),ptc('O'),puts("")
#define pb emplace_back
#define sz(x) (int)(x.size())
#define all(x) x.begin(),x.end()
#define debug() puts("------------------")

using namespace std;
typedef pair<int,int> PII;
typedef pair<int,PII> PIII;
typedef pair<ll,ll> PLL;
namespace szhqwq {
    // Reads one integer into x: skips characters until a digit (remembering a '-' seen on
    // the way as the sign), then consumes the run of decimal digits.
    // `ch` is an int so that EOF is recognised: stored in a `char`, EOF never matched and the
    // skip loop spun forever once the input ran out. Now x is simply left as 0 at EOF.
    template<class T> inline void read(T &x) {
        x = 0; T f = 1; int ch = getchar();
        while (ch != EOF && (ch < '0' || ch > '9')) { if (ch == '-') f = -1; ch = getchar(); }
        while (ch >= '0' && ch <= '9') { x = (x << 3) + (x << 1) + (ch ^ 48); ch = getchar(); }
        x *= f;
    }
    // Reads and returns one int.
    inline int read() { int x; read(x); return x; }
    // Reads and returns one long long.
    inline long long readl() { long long x; read(x); return x; }
    // Reads several integers, left to right.
    template<class T,class... Args> inline void read(T &x,Args &...x_) { read(x); read(x_...); }
    // Writes x in decimal (leading '-' if negative), without a separator.
    template<class T> inline void print(T x) {
        if (x < 0) putchar('-'), x = -x;
        if (x > 9) print(x / 10);
        putchar(x % 10 + '0');
    }
    // Writes x followed by the separator character ch.
    template<class T,class T_> inline void write(T x,T_ ch) { print(x); putchar(ch); }
    // x = max(x, y) / x = min(x, y). Written with a comparison so that x and y may have
    // different types (std::max / std::min cannot deduce a common type for e.g. int and long long).
    template<class T,class T_> inline void chmax(T &x,T_ y) { if (x < y) x = y; }
    template<class T,class T_> inline void chmin(T &x,T_ y) { if (y < x) x = y; }
    // a^b mod p by binary exponentiation (assumes a * a fits in T).
    template<class T,class T_,class T__> inline T qmi(T a,T_ b,T__ p) {
        T res = 1;
        while (b) {
            if (b & 1) res = res * a % p;
            a = a * a % p; b >>= 1;
        }
        return res;
    }
    // Modular inverse of x via Fermat's little theorem; assumes p is prime.
    template<class T> inline int getinv(T x,T p) { return qmi(x,p - 2,p); }
} using namespace szhqwq;
bool stpos;  // address marker; referenced only by commented-out memory-usage code in solve()
const int N = 1e5 + 10,M = 20,inf = 1e9,mod = 998244353;  // N: array bound; M: binary-lifting table width (indices 0..17 are used)
const ll inff = 1e18;
int n,m,q;  // template-tree size, number of copy operations, number of queries
int h[N],e[N << 1],ne[N << 1],idx;  // template tree as a forward-star adjacency list (filled by add())
PLL cp[N];  // copy operations: (template subtree root, big-tree label to hang the copy under)
// Persistent value segment tree. rt[i] is the version after inserting the first i values;
// tr is the node pool and node 0 acts as the shared empty tree.
struct kvsgt {
    int rt[N],tot;
    struct node { int l,r,val; } tr[N << 5];
    // Recompute the count stored in node u from its two children.
    il void pushup(int u) {
        const int lhs = tr[tr[u].l].val, rhs = tr[tr[u].r].val;
        tr[u].val = lhs + rhs;
    }

    // Build version u from version p with one more copy of value x (node range [l, r]).
    il void modify(int &u,int p,int l,int r,int x) {
        u = ++ tot;
        tr[u] = tr[p];
        if (l != r) {
            const int mid = (l + r) >> 1;
            if (x > mid) modify(tr[u].r,tr[p].r,mid + 1,r,x);
            else modify(tr[u].l,tr[p].l,l,mid,x);
            pushup(u);
        } else {
            tr[u].val += 1;
        }
    }

    // k-th smallest value among those inserted after version p, up to version u.
    il int query(int u,int p,int l,int r,int k) {
        while (l != r) {
            const int mid = (l + r) >> 1;
            const int inLeft = tr[tr[u].l].val - tr[tr[p].l].val;
            if (k <= inLeft) {
                u = tr[u].l; p = tr[p].l; r = mid;
            } else {
                k -= inLeft;
                u = tr[u].r; p = tr[p].r; l = mid + 1;
            }
        }
        return l;
    }
} T;
// Template tree: depths, dfs order (dfn[t] = node entered at time t; [in[v], out[v]] = entry-time
// interval of v's subtree) and a binary-lifting table for LCA / distance queries.
// (id is declared but not used here.)
struct basic_tree {
    int d[N],dfn[N],cnt,id[N],in[N],out[N],fa[N][M];
    // Recursive dfs from u whose parent is f (f = 0 for the root).
    il void dfs(int u,int f) {
        fa[u][0] = f;
        for (int k = 1; k <= 17; ++ k) fa[u][k] = fa[fa[u][k - 1]][k - 1];
        d[u] = d[f] + 1;
        in[u] = ++ cnt;
        dfn[cnt] = u;
        for (int eid = h[u]; eid != -1; eid = ne[eid])
            if (e[eid] != f) dfs(e[eid],u);
        out[u] = cnt;
    }

    // Lowest common ancestor of x and y.
    il int lca(int x,int y) {
        if (d[x] < d[y]) swap(x,y);
        for (int k = 17; k >= 0; -- k)
            if (d[fa[x][k]] >= d[y]) x = fa[x][k];
        if (x == y) return x;
        for (int k = 17; k >= 0; -- k)
            if (fa[x][k] != fa[y][k]) { x = fa[x][k]; y = fa[y][k]; }
        return fa[x][0];
    }

    // Number of edges on the path between x and y.
    il int getdis(int x,int y) {
        const int top = lca(x,y);
        return d[x] - d[top] + d[y] - d[top];
    }
} BT;
// Big tree made of "big nodes": big node 1 is the whole template tree (labels 1..n), and every
// copy operation appends one big node that receives the next block of consecutive labels.
// Per big node b:
//   L[b], R[b] - first and last big-tree label inside b
//   gua[b]     - big-tree label of the node b hangs under (never set, i.e. 0, for b = 1)
//   zi[b]      - template-tree root of the subtree copied into b
//   d[b]       - depth of b among big nodes (big node 1 has depth 1)
//   fa[b][j]   - 2^j-th ancestor big node of b (0 above big node 1)
//   dist[b][j] - big-tree distance from the root of b to the root of fa[b][j]
// u is the number of big nodes created so far.
struct big_tree {
    int u,d[N]; ll L[N],R[N],gua[N],zi[N],fa[N][M],dist[N][M];
    // Big node containing label x: the last b in [1, u] with L[b] <= x (L increases with b).
    // Returns 0 if x < L[1].
    il int getbig(ll x) {
        int l = 1,r = u,res = 0;
        while (l <= r) {
            int mid = l + r >> 1;
            if (L[mid] > x) r = mid - 1;
            else l = mid + 1,res = mid;
        }
        return res;
    }

    // Template-tree node that big-tree label x is a copy of. The code relies on labels inside a
    // big node following the increasing order of the copied template labels, so the answer is
    // the (x - L + 1)-th smallest template label in the dfn interval [in[zi], out[zi]],
    // answered by the persistent segment tree T.
    il int getkth(ll x) {
        int xxx = getbig(x);
        int k = T.query(T.rt[BT.out[zi[xxx]]],T.rt[BT.in[zi[xxx]] - 1],1,n,x - L[xxx] + 1);
        return k;
    }

    // Creates big node 1 (the template tree itself) and one big node per copy operation
    // cp[1..m], filling L/R/gua/zi/d and the binary-lifting tables fa/dist.
    il void init() {
        L[++ u] = 1; R[u] = n; zi[u] = 1; d[u] = 1;
        // cerr << "   " << L[u] << " " << R[u] << endl;
        ll now = n + 1;  // first unused big-tree label
        rep(i,1,m) {
            // x: template subtree root to copy; to: label to hang it under; f: big node containing `to`
            int x = cp[i].fst; ll to = cp[i].snd; int f = getbig(to);
            // the copy occupies (size of x's subtree) consecutive labels starting at now
            L[++ u] = now; R[u] = now + BT.out[x] - BT.in[x]; now = R[u] + 1ll;
            gua[u] = to; zi[u] = x; fa[u][0] = f;
            // one edge from the new root to `to`, plus the depth of `to` below the root of f
            dist[u][0] = 1ll + BT.d[getkth(to)] - BT.d[zi[f]];
            // cerr << u << " " << f << endl;
            d[u] = d[f] + 1;
            rep(j,1,17) {
                fa[u][j] = fa[fa[u][j - 1]][j - 1];
                dist[u][j] = dist[fa[u][j - 1]][j - 1] + dist[u][j - 1];
            }
        }
        // debug();
        // rep(i,1,m + 1) {
        //     rep(j,0,3) cout << fa[i][j] << " ";
        //     cout << endl;
        // }
        // debug();
    }

    // Distance in the big tree between labels x and y.
    il ll getans(ll x,ll y) {
        ll res = 0;
        int xx = getbig(x),yy = getbig(y);
        // cout << "aaa " << x << " " << y << " " << xx << " " << yy << endl;
        // Same big node: the distance equals that of the corresponding template-tree nodes.
        if (xx == yy) return BT.getdis(getkth(x),getkth(y));
        // Make xx the deeper big node (x follows it).
        if (d[xx] < d[yy]) swap(x,y),swap(xx,yy);
        res += BT.d[getkth(x)] - BT.d[zi[xx]]; // distance from x up to the root of its big node
        // cout << "!!! " << res << " " << getkth(x) << " " << zi[xx] << endl;
        // Lift the deeper big node while its ancestor stays strictly deeper than yy,
        // accumulating root-to-root distances.
        rep1(i,17,0)
            if (d[fa[xx][i]] > d[yy]) {
                res += dist[xx][i];
                // cout << "??? " << i << " " << d[fa[xx][i]] << " " << fa[xx][i] << " " << dist[xx][i] << endl;
                xx = fa[xx][i];
            }
        // Special case: xx hangs directly under a node of yy. Add the edge to the hanging point,
        // then the distance from the hanging point to y inside yy (i.e. in the template tree).
        if (getbig(gua[xx]) == yy) {
            res += 1ll + 1ll * BT.getdis(getkth(gua[xx]),getkth(y));
            return res;
        }
        // Bring xx to the same depth as yy.
        if (d[xx] > d[yy]) res += dist[xx][0],xx = fa[xx][0];
        // cout << "??? " << res << endl;
        res += BT.d[getkth(y)] - BT.d[zi[yy]]; // distance from y up to the root of its big node
        // if (xx == yy) return res; // same big node: x is already at the root, so this would just be y's distance to the root
        // Lift both big nodes together until their parents coincide (the LCA big node).
        rep1(i,17,0)
            if (fa[xx][i] != fa[yy][i]) {
                res += dist[xx][i] + dist[yy][i];
                xx = fa[xx][i]; yy = fa[yy][i];
            }
        // +2 for the edges from the roots of xx and yy to their hanging points, plus the distance
        // between the two hanging points inside the LCA big node.
        res += 2ll + 1ll * BT.getdis(getkth(gua[xx]),getkth(gua[yy]));
        return res;
    }
} BGT;

// Append the directed edge a -> b to the forward-star adjacency list (h / e / ne / idx).
il void add(int a,int b) {
    const int slot = idx;
    e[slot] = b;
    ne[slot] = h[a];
    h[a] = slot;
    idx = slot + 1;
}

// Reads the template tree, the copy operations and the queries, builds the template tree,
// the persistent segment tree over its dfs order and the big tree, then prints the big-tree
// distance for every query.
il void solve() {
    //------------code------------
    // clock_t start = clock();
    memset(h,-1,sizeof h);
    read(n,m,q);
    rep(i,1,n - 1) {
        int a,b; read(a,b);
        add(a,b); add(b,a);
    }
    // Copy operation i: (template subtree root, big-tree label to hang the copy under).
    // A braced-init-list is evaluated left to right, so the order of the two reads is well defined.
    rep(i,1,m) cp[i] = {readl(),readl()};
    BT.dfs(1,0);
    // Version i of the persistent tree holds the template labels dfn[1..i].
    rep(i,1,n) T.modify(T.rt[i],T.rt[i - 1],1,n,BT.dfn[i]);
    BGT.init();
    while (q -- ) {
        // Read into named locals: function-argument evaluation order is unspecified, so
        // getans(readl(),readl()) did not guarantee which input number became x.
        const ll x = readl();
        const ll y = readl();
        write(BGT.getans(x,y),'\n');
    }
    // clock_t end = clock();
    // cerr << "Time : " << (db)(end - start) << " ms" << endl;
    // bool endpos;
    // cerr << (&endpos - &stpos) / 1024 / 1024 << endl;
    return ;
}

// Per-run initialisation hook; this problem needs none.
il void init() {}

// Entry point: one test case (reading a test count is left commented out).
signed main() {
    // init();
    int cases = 1;
    // read(cases);
    for (; cases > 0; -- cases) solve();
    return 0;
}

P3293 [SCOI2016] 美味

好题。

每次询问的 \(b,x\) 是固定的,相当于要在 \([l,r]\) 中找到一个 \(a_i\),使 \(a_i + x\) 与 \(b\) 的异或值最大。我们贪心地从高位到低位枚举,并记录当前已经确定的最优 \(a_i + x\) 的高位部分 \(now\)。若 \(b\) 的第 \(i\) 位为 1,则希望 \(a_i + x\) 的第 \(i\) 位为 0,即 \(a_i + x \in [now,now + 2^i - 1]\)(形如 \([...0000...,...0111...]\)),也就是 \(a_i \in [now - x,now - x + 2^i - 1]\);只要存在这样的 \(a_i\),\(b \oplus (a_i + x)\) 在这一位就为 1。

\(b\) 的这一位为 0 时同理,希望 \(a_i + x\) 的第 \(i\) 位为 1,即 \(a_i \in [now - x + 2^i,now - x + 2^{i + 1} - 1]\)。

每次用主席树查询下标在 \([l,r]\) 中、值落在上述区间内的 \(a_i\) 是否存在即可。注意区间端点要截断到值域内,截断后区间可能为空,此时应直接视为不存在。

#include <bits/stdc++.h>

// #define int long long
#define ll long long
#define ull unsigned long long
#define db double
#define ld long double
#define rep(i,l,r) for (int i = (int)(l); i <= (int)(r); ++ i )
#define rep1(i,l,r) for (int i = (int)(l); i >= (int)(r); -- i )
#define il inline
#define fst first
#define snd second
#define ptc putchar
#define Yes ptc('Y'),ptc('e'),ptc('s'),puts("")
#define No ptc('N'),ptc('o'),puts("")
#define YES ptc('Y'),ptc('E'),ptc('S'),puts("")
#define NO ptc('N'),ptc('O'),puts("")
#define pb emplace_back
#define sz(x) (int)(x.size())
#define all(x) x.begin(),x.end()
#define get(x) ((x - 1) / len + 1)
#define debug() puts("------------------")

using namespace std;
typedef pair<int,int> PII;
typedef pair<int,PII> PIII;
typedef pair<ll,ll> PLL;
namespace szhqwq {
    // Reads one integer into x: skips characters until a digit (remembering a '-' seen on
    // the way as the sign), then consumes the run of decimal digits.
    // `ch` is an int so that EOF is recognised: stored in a `char`, EOF never matched and the
    // skip loop spun forever once the input ran out. Now x is simply left as 0 at EOF.
    template<class T> inline void read(T &x) {
        x = 0; T f = 1; int ch = getchar();
        while (ch != EOF && (ch < '0' || ch > '9')) { if (ch == '-') f = -1; ch = getchar(); }
        while (ch >= '0' && ch <= '9') { x = (x << 3) + (x << 1) + (ch ^ 48); ch = getchar(); }
        x *= f;
    }
    // Reads and returns one int.
    inline int read() { int x; read(x); return x; }
    // Reads several integers, left to right.
    template<class T,class... Args> inline void read(T &x,Args &...x_) { read(x); read(x_...); }
    // Writes x in decimal (leading '-' if negative), without a separator.
    template<class T> inline void print(T x) {
        if (x < 0) putchar('-'), x = -x;
        if (x > 9) print(x / 10);
        putchar(x % 10 + '0');
    }
    // Writes x followed by the separator character ch.
    template<class T,class T_> inline void write(T x,T_ ch) { print(x); putchar(ch); }
    // x = max(x, y) / x = min(x, y). Written with a comparison so that x and y may have
    // different types (std::max / std::min cannot deduce a common type for e.g. int and long long).
    template<class T,class T_> inline void chmax(T &x,T_ y) { if (x < y) x = y; }
    template<class T,class T_> inline void chmin(T &x,T_ y) { if (y < x) x = y; }
    // a^b mod p by binary exponentiation (assumes a * a fits in T).
    template<class T,class T_,class T__> inline T qmi(T a,T_ b,T__ p) {
        T res = 1;
        while (b) {
            if (b & 1) res = res * a % p;
            a = a * a % p; b >>= 1;
        }
        return res;
    }
    // Modular inverse of x via Fermat's little theorem; assumes p is prime.
    template<class T> inline int getinv(T x,T p) { return qmi(x,p - 2,p); }
} using namespace szhqwq;
const int N = 2e5 + 10,inf = 1e9,mod = 998244353;
const long long inff = 1e18;
// n, m: sequence length and number of queries; a: the sequence;
// rt[i]: root of the version containing a[1..i]; tot: number of allocated nodes;
// VAL: the tree indexes values in [0, VAL].
int n,m,a[N],rt[N],tot,VAL = 300000;
// Node pool; node 0 is the shared empty tree (all fields zero).
struct node { int l,r,val; } tr[N << 5];

// Recompute node u's count from its children.
inline void pushup(int u) {
    tr[u].val = tr[tr[u].l].val + tr[tr[u].r].val;
}

// Build version u from version p with one more copy of value x (node range [l, r]).
inline void modify(int &u,int p,int l,int r,int x) {
    u = ++ tot; tr[u] = tr[p];
    if (l == r) {
        ++ tr[u].val;
        return ;
    }
    int mid = l + r >> 1;
    if (x <= mid) modify(tr[u].l,tr[p].l,l,mid,x);
    else modify(tr[u].r,tr[p].r,mid + 1,r,x);
    pushup(u);
}

// Number of values in [L, R] inserted after version p, up to version u (node range [l, r]).
// An empty query range (L > R) returns 0. solve() produces one when the greedy's candidate
// interval lies entirely below 0 and is clamped to L = 0, R < 0. Without this guard the
// recursion reaches a leaf that is not covered and keeps re-entering it forever.
inline int query(int u,int p,int l,int r,int L,int R) {
    if (L > R) return 0;
    if (L <= l && r <= R) return tr[u].val - tr[p].val;
    int ret = 0;
    int mid = l + r >> 1;
    if (L <= mid) ret += query(tr[u].l,tr[p].l,l,mid,L,R);
    if (mid < R) ret += query(tr[u].r,tr[p].r,mid + 1,r,L,R);
    return ret;
}

// Reads the sequence and builds one persistent version per prefix. Each query is then answered
// greedily from bit 17 down to bit 0, tracking the best reachable prefix of a_j + x.
il void solve() {
    read(n,m);
    for (int i = 1; i <= n; ++ i) {
        read(a[i]);
        modify(rt[i],rt[i - 1],0,VAL,a[i]);
    }
    while (m -- ) {
        int b,x,l,r; read(b,x,l,r);
        int best = 0;  // decided high bits of the best a_j + x
        for (int bit = 17; bit >= 0; -- bit) {
            const int w = 1 << bit;
            const int lo = best - x;  // smallest a_j whose a_j + x starts at `best`
            if (b >> bit & 1) {
                // want this bit of a_j + x to be 0: a_j + x in [best, best + w - 1]
                if (query(rt[r],rt[l - 1],0,VAL,max(0,lo),min(VAL,lo + w - 1)) == 0) best += w;
            } else {
                // want this bit of a_j + x to be 1: a_j + x in [best + w, best + 2w - 1]
                if (query(rt[r],rt[l - 1],0,VAL,max(0,lo + w),min(VAL,lo + 2 * w - 1)) != 0) best += w;
            }
        }
        write(best ^ b,'\n');
    }
}

// Per-run initialisation hook; this problem needs none.
il void init() {}

// Entry point: one test case (reading a test count is left commented out).
signed main() {
    // init();
    int cases = 1;
    // read(cases);
    for (; cases > 0; -- cases) solve();
    return 0;
}

P3168 [CQOI2015] 任务查询系统

板子。

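以时间作为主席树的版本(每个时刻一个根),把任务的优先级作为权值插入。用差分处理任务:任务 \([s,e]\) 在时刻 \(s\) 加入、在时刻 \(e + 1\) 删除,这样时刻 \(x\) 的版本恰好包含该时刻正在执行的所有任务。查询时在这个版本上求前 \(k\) 小优先级之和即可;若 \(k\) 超过任务总数,答案就是全部优先级之和。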

#include <bits/stdc++.h>

// #define int long long
#define ll long long
#define ull unsigned long long
#define db double
#define ld long double
#define rep(i,l,r) for (int i = (int)(l); i <= (int)(r); ++ i )
#define rep1(i,l,r) for (int i = (int)(l); i >= (int)(r); -- i )
#define il inline
#define fst first
#define snd second
#define ptc putchar
#define Yes ptc('Y'),ptc('e'),ptc('s'),puts("")
#define No ptc('N'),ptc('o'),puts("")
#define YES ptc('Y'),ptc('E'),ptc('S'),puts("")
#define NO ptc('N'),ptc('O'),puts("")
#define pb emplace_back
#define sz(x) (int)(x.size())
#define all(x) x.begin(),x.end()
#define get(x) ((x - 1) / len + 1)
#define debug() puts("------------------")

using namespace std;
typedef pair<int,int> PII;
typedef pair<int,PII> PIII;
typedef pair<ll,ll> PLL;
namespace szhqwq {
    // Reads one integer into x: skips characters until a digit (remembering a '-' seen on
    // the way as the sign), then consumes the run of decimal digits.
    // `ch` is an int so that EOF is recognised: stored in a `char`, EOF never matched and the
    // skip loop spun forever once the input ran out. Now x is simply left as 0 at EOF.
    template<class T> inline void read(T &x) {
        x = 0; T f = 1; int ch = getchar();
        while (ch != EOF && (ch < '0' || ch > '9')) { if (ch == '-') f = -1; ch = getchar(); }
        while (ch >= '0' && ch <= '9') { x = (x << 3) + (x << 1) + (ch ^ 48); ch = getchar(); }
        x *= f;
    }
    // Reads and returns one int.
    inline int read() { int x; read(x); return x; }
    // Reads several integers, left to right.
    template<class T,class... Args> inline void read(T &x,Args &...x_) { read(x); read(x_...); }
    // Writes x in decimal (leading '-' if negative), without a separator.
    template<class T> inline void print(T x) {
        if (x < 0) putchar('-'), x = -x;
        if (x > 9) print(x / 10);
        putchar(x % 10 + '0');
    }
    // Writes x followed by the separator character ch.
    template<class T,class T_> inline void write(T x,T_ ch) { print(x); putchar(ch); }
    // x = max(x, y) / x = min(x, y). Written with a comparison so that x and y may have
    // different types (std::max / std::min cannot deduce a common type for e.g. int and long long).
    template<class T,class T_> inline void chmax(T &x,T_ y) { if (x < y) x = y; }
    template<class T,class T_> inline void chmin(T &x,T_ y) { if (y < x) x = y; }
    // a^b mod p by binary exponentiation (assumes a * a fits in T).
    template<class T,class T_,class T__> inline T qmi(T a,T_ b,T__ p) {
        T res = 1;
        while (b) {
            if (b & 1) res = res * a % p;
            a = a * a % p; b >>= 1;
        }
        return res;
    }
    // Modular inverse of x via Fermat's little theorem; assumes p is prime.
    template<class T> inline int getinv(T x,T p) { return qmi(x,p - 2,p); }
} using namespace szhqwq;
const int N = 1e5 + 10,inf = 1e9,mod = 998244353;
const ll inff = 1e18;
// n: number of time points, which solve() also uses as the number of queries;
// m: number of tasks (solve() reads them as m, n).
// rt[t]: root of the version describing the tasks running at time t; tot: allocated nodes.
// VAL: the tree indexes priorities in [0, VAL].
int n,m,rt[N],tot,VAL = 2e7;
// Node pool (node 0 = empty tree): c = number of tasks whose priority lies in the node's range,
// val = sum of those priorities.
struct node { int l,r,c; ll val; } tr[N << 6];
// v: task priorities, sorted and deduplicated in solve() (the compression that would use it
// is commented out there). upd[t]: (priority, +1 / -1) changes applied at time t.
vector<int> v; vector<PII> upd[N];

// Recompute the task count and priority sum of node u from its two children.
il void pushup(int u) {
    const node &lc = tr[tr[u].l], &rc = tr[tr[u].r];
    tr[u].c = lc.c + rc.c;
    tr[u].val = lc.val + rc.val;
}

// Build version u from version p, adding k (+1 or -1) copies of priority x (node range [l, r]).
il void modify(int &u,int p,int l,int r,int x,int k) {
    u = ++ tot;
    tr[u] = tr[p];
    if (l != r) {
        const int mid = (l + r) >> 1;
        if (x > mid) modify(tr[u].r,tr[p].r,mid + 1,r,x,k);
        else modify(tr[u].l,tr[p].l,l,mid,x,k);
        pushup(u);
        return ;
    }
    tr[u].c += k;
    tr[u].val += (ll)x * k;
}

// Sum of the k smallest priorities in version u (caller guarantees k <= count of u).
il ll query(int u,int l,int r,ll k) {
    ll acc = 0;
    while (l != r) {
        const int mid = (l + r) >> 1;
        const int leftCnt = tr[tr[u].l].c;
        if (k <= leftCnt) {
            u = tr[u].l; r = mid;
        } else {
            acc += tr[tr[u].l].val;
            k -= leftCnt;
            u = tr[u].r; l = mid + 1;
        }
    }
    // at a leaf all tasks share one priority, so val / c is that priority
    return acc + tr[u].val / tr[u].c * k;
}

// Reads m tasks and n queries. Version t of the persistent tree holds the priorities of the
// tasks running at time t, via difference updates (+1 at s, -1 at e + 1). Each query asks for
// the sum of the k smallest priorities at time x, with k derived from the previous answer.
il void solve() {
    //------------code------------
    read(m,n);
    rep(i,1,m) {
        int s,e,p; read(s,e,p);
        upd[s].pb(p,1); upd[e + 1].pb(p,-1);
        v.pb(p);
    }
    sort(all(v)); v.erase(unique(all(v)),v.end());
    // VAL = sz(v);
    rep(i,1,n) {
        rt[i] = rt[i - 1];
        for (const auto &x : upd[i]) modify(rt[i],rt[i],0,VAL,x.fst,x.snd);
    }
    ll pre = 1;  // previous answer (1 before the first query)
    rep(i,1,n) {
        int x,a,b,c; read(x,a,b,c);
        // k = 1 + (a * pre + b) mod c. pre is a sum of priorities and can reach ~1e12 while a
        // reaches ~1e7, so a * pre could overflow long long (undefined behaviour). Reducing
        // pre mod c first gives the same residue with a product below ~1e14.
        ll k = 1ll + (1ll * a * (pre % c) + b) % c;
        if (k > tr[rt[x]].c) write(pre = tr[rt[x]].val,'\n');  // fewer than k tasks: take them all
        else write(pre = query(rt[x],0,VAL,k),'\n');
    }
    // cerr << "Time : " << (db)(end - start) / CLOCKS_PER_SEC << " s" << endl;
    return ;
}

// Per-run initialisation hook; this problem needs none.
il void init() {}

// Entry point: one test case (reading a test count is left commented out).
signed main() {
    // init();
    int cases = 1;
    // read(cases);
    for (; cases > 0; -- cases) solve();
    return 0;
}

CF813E Army Creation

妙妙题。一个好的 trick。

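出现次数不好直接限制。换个角度:对每个位置 \(i\),把 \(a_i\) 在 \(i\) 之前(不包括 \(i\))倒数第 \(k\) 次出现的位置记为 \(p_i\),放到主席树上;不足 \(k\) 次的记 \(p_i = 0\)。区间 \([l,r]\) 中的位置 \(i\) 能被选中,当且仅当 \(p_i < l\)。因此每次询问只需统计下标在 \([l,r]\) 中、且 \(p_i \in [0,l - 1]\) 的位置个数即可。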

#include <bits/stdc++.h>

// #define int long long
#define ll long long
#define ull unsigned long long
#define db double
#define ld long double
#define rep(i,l,r) for (int i = (int)(l); i <= (int)(r); ++ i )
#define rep1(i,l,r) for (int i = (int)(l); i >= (int)(r); -- i )
#define il inline
#define fst first
#define snd second
#define ptc putchar
#define Yes ptc('Y'),ptc('e'),ptc('s'),puts("")
#define No ptc('N'),ptc('o'),puts("")
#define YES ptc('Y'),ptc('E'),ptc('S'),puts("")
#define NO ptc('N'),ptc('O'),puts("")
#define pb emplace_back
#define sz(x) (int)(x.size())
#define all(x) x.begin(),x.end()
#define get(x) ((x - 1) / len + 1)
#define debug() puts("------------------")

using namespace std;
typedef pair<int,int> PII;
typedef pair<int,PII> PIII;
typedef pair<ll,ll> PLL;
namespace szhqwq {
    // Reads one integer into x: skips characters until a digit (remembering a '-' seen on
    // the way as the sign), then consumes the run of decimal digits.
    // `ch` is an int so that EOF is recognised: stored in a `char`, EOF never matched and the
    // skip loop spun forever once the input ran out. Now x is simply left as 0 at EOF.
    template<class T> inline void read(T &x) {
        x = 0; T f = 1; int ch = getchar();
        while (ch != EOF && (ch < '0' || ch > '9')) { if (ch == '-') f = -1; ch = getchar(); }
        while (ch >= '0' && ch <= '9') { x = (x << 3) + (x << 1) + (ch ^ 48); ch = getchar(); }
        x *= f;
    }
    // Reads and returns one int.
    inline int read() { int x; read(x); return x; }
    // Reads several integers, left to right.
    template<class T,class... Args> inline void read(T &x,Args &...x_) { read(x); read(x_...); }
    // Writes x in decimal (leading '-' if negative), without a separator.
    template<class T> inline void print(T x) {
        if (x < 0) putchar('-'), x = -x;
        if (x > 9) print(x / 10);
        putchar(x % 10 + '0');
    }
    // Writes x followed by the separator character ch.
    template<class T,class T_> inline void write(T x,T_ ch) { print(x); putchar(ch); }
    // x = max(x, y) / x = min(x, y). Written with a comparison so that x and y may have
    // different types (std::max / std::min cannot deduce a common type for e.g. int and long long).
    template<class T,class T_> inline void chmax(T &x,T_ y) { if (x < y) x = y; }
    template<class T,class T_> inline void chmin(T &x,T_ y) { if (y < x) x = y; }
    // a^b mod p by binary exponentiation (assumes a * a fits in T).
    template<class T,class T_,class T__> inline T qmi(T a,T_ b,T__ p) {
        T res = 1;
        while (b) {
            if (b & 1) res = res * a % p;
            a = a * a % p; b >>= 1;
        }
        return res;
    }
    // Modular inverse of x via Fermat's little theorem; assumes p is prime.
    template<class T> inline int getinv(T x,T p) { return qmi(x,p - 2,p); }
} using namespace szhqwq;
const int N = 1e5 + 10,inf = 1e9,mod = 998244353;
const long long inff = 1e18;
// n: sequence length, k: occurrence limit, q: number of queries; a: the sequence;
// rt[i]: version after inserting positions 1..i; tot: allocated nodes;
// VAL: the tree indexes stored values in [0, VAL].
int n,k,q,a[N],rt[N],tot,VAL = 100000;
// Node pool (node 0 = empty tree): val counts the stored values in the node's range.
struct node { int l,r,val; } tr[N << 5];
// v[x]: positions where value x has appeared so far (filled by solve()).
std::vector<int> v[N];

// Refresh node u's count from its children.
inline void pushup(int u) {
    const int lv = tr[tr[u].l].val, rv = tr[tr[u].r].val;
    tr[u].val = lv + rv;
}

// Build version u from version p with one more copy of value x (node range [l, r]).
inline void modify(int &u,int p,int l,int r,int x) {
    u = ++ tot;
    tr[u] = tr[p];
    if (l == r) {
        tr[u].val += 1;
        return ;
    }
    const int mid = (l + r) >> 1;
    if (x > mid) modify(tr[u].r,tr[p].r,mid + 1,r,x);
    else modify(tr[u].l,tr[p].l,l,mid,x);
    pushup(u);
}

// Number of values in [L, R] inserted after version p, up to version u (node range [l, r]).
inline int query(int u,int p,int l,int r,int L,int R) {
    if (L <= l && r <= R) return tr[u].val - tr[p].val;
    const int mid = (l + r) >> 1;
    const int fromLeft = L <= mid ? query(tr[u].l,tr[p].l,l,mid,L,R) : 0;
    const int fromRight = R > mid ? query(tr[u].r,tr[p].r,mid + 1,r,L,R) : 0;
    return fromLeft + fromRight;
}

// For each position i, store the position of the k-th previous occurrence of a[i] (0 if there
// are fewer than k). Each query then counts positions in [l, r] whose stored value is < l.
il void solve() {
    read(n,k);
    for (int i = 1; i <= n; ++ i) {
        read(a[i]);
        vector<int> &occ = v[a[i]];
        occ.pb(i);
        const int cnt = sz(occ);
        const int kthPrev = cnt > k ? occ[cnt - k - 1] : 0;
        modify(rt[i],rt[i - 1],0,VAL,kthPrev);
    }
    read(q);
    int last = 0;  // previous answer, used to decode the online query
    while (q -- ) {
        int x,y; read(x,y);
        int l = (x + last) % n + 1, r = (y + last) % n + 1;
        if (r < l) swap(l,r);
        last = query(rt[r],rt[l - 1],0,VAL,0,l - 1);
        write(last,'\n');
    }
}

// Per-run initialisation hook; this problem needs none.
il void init() {}

// Entry point: one test case (reading a test count is left commented out).
signed main() {
    // init();
    int cases = 1;
    // read(cases);
    for (; cases > 0; -- cases) solve();
    return 0;
}

P4755 Beautiful Pair

妙妙题。

考虑转化题目给定的问题:固定区间最大值 \(a_i\) 之后,再去统计乘积不超过这个最大值的数对 \(a_j \times a_k\)。

考虑单调栈,对于 \(a_i\) 求出向左第一个 \(\geq a_i\) 的位置 \(le_i\) 以及向右第一个 \(> a_i\) 的位置 \(ri_i\),注意此处一方取等一方不取的原因是为了不重复计算,取等的方向反过来也是可行的。

考虑枚举以 \(a_i\) 为最大值的范围 \((le_i,ri_i)\) 中,\(i\) 左侧和右侧里长度较小的一侧。对这一侧的每个 \(a_j\),只需用主席树统计另一侧满足 \(a_k \leq \lfloor \frac{a_i}{a_j} \rfloor\) 的 \(k\) 的个数即可。由于每次只枚举较短的一侧,均摊下来每个位置被枚举 \(\mathcal{O}(\log n)\) 次,总时间复杂度 \(\mathcal{O}(n \log^2 n)\)。

#include <bits/stdc++.h>

// #define int long long
#define ll long long
#define ull unsigned long long
#define db double
#define ld long double
#define rep(i,l,r) for (int i = (int)(l); i <= (int)(r); ++ i )
#define rep1(i,l,r) for (int i = (int)(l); i >= (int)(r); -- i )
#define il inline
#define fst first
#define snd second
#define ptc putchar
#define Yes ptc('Y'),ptc('e'),ptc('s'),puts("")
#define No ptc('N'),ptc('o'),puts("")
#define YES ptc('Y'),ptc('E'),ptc('S'),puts("")
#define NO ptc('N'),ptc('O'),puts("")
#define pb emplace_back
#define sz(x) (int)(x.size())
#define all(x) x.begin(),x.end()
#define get(x) ((x - 1) / len + 1)
#define debug() puts("------------------")

using namespace std;
typedef pair<int,int> PII;
typedef pair<int,PII> PIII;
typedef pair<ll,ll> PLL;
namespace szhqwq {
    // Reads one integer into x: skips characters until a digit (remembering a '-' seen on
    // the way as the sign), then consumes the run of decimal digits.
    // `ch` is an int so that EOF is recognised: stored in a `char`, EOF never matched and the
    // skip loop spun forever once the input ran out. Now x is simply left as 0 at EOF.
    template<class T> inline void read(T &x) {
        x = 0; T f = 1; int ch = getchar();
        while (ch != EOF && (ch < '0' || ch > '9')) { if (ch == '-') f = -1; ch = getchar(); }
        while (ch >= '0' && ch <= '9') { x = (x << 3) + (x << 1) + (ch ^ 48); ch = getchar(); }
        x *= f;
    }
    // Reads and returns one int.
    inline int read() { int x; read(x); return x; }
    // Reads several integers, left to right.
    template<class T,class... Args> inline void read(T &x,Args &...x_) { read(x); read(x_...); }
    // Writes x in decimal (leading '-' if negative), without a separator.
    template<class T> inline void print(T x) {
        if (x < 0) putchar('-'), x = -x;
        if (x > 9) print(x / 10);
        putchar(x % 10 + '0');
    }
    // Writes x followed by the separator character ch.
    template<class T,class T_> inline void write(T x,T_ ch) { print(x); putchar(ch); }
    // x = max(x, y) / x = min(x, y). Written with a comparison so that x and y may have
    // different types (std::max / std::min cannot deduce a common type for e.g. int and long long).
    template<class T,class T_> inline void chmax(T &x,T_ y) { if (x < y) x = y; }
    template<class T,class T_> inline void chmin(T &x,T_ y) { if (y < x) x = y; }
    // a^b mod p by binary exponentiation (assumes a * a fits in T).
    template<class T,class T_,class T__> inline T qmi(T a,T_ b,T__ p) {
        T res = 1;
        while (b) {
            if (b & 1) res = res * a % p;
            a = a * a % p; b >>= 1;
        }
        return res;
    }
    // Modular inverse of x via Fermat's little theorem; assumes p is prime.
    template<class T> inline int getinv(T x,T p) { return qmi(x,p - 2,p); }
} using namespace szhqwq;
const int N = 1e5 + 10,inf = 1e9,mod = 998244353;
const long long inff = 1e18;
bool sta;  // address marker; referenced only by commented-out memory-usage code in solve()
// n: sequence length; a: the sequence; VAL: the tree indexes values in [1, VAL];
// rt[i]: version containing a[1..i]; tot: allocated nodes.
int n,a[N],VAL = 2e9,rt[N],tot;
// Node pool (node 0 = empty tree): val = number of inserted values in the node's range.
struct node { int l,r; long long val; } tr[N << 6];
// le[i] / ri[i] (computed in solve()): nearest position to the left with a value >= a[i] /
// to the right with a value > a[i].
int le[N],ri[N];
// Monotonic stack used to compute le / ri.
int st[N],top;

// Build version u from version p plus one copy of x; counts are bumped on the way down.
// mid is l + (r - l) / 2 because with VAL = 2e9 the sum l + r exceeds INT_MAX in the right
// half of the domain (e.g. l = 1e9 + 1, r = 2e9), which is signed-overflow undefined behaviour.
inline void modify(int &u,int p,int l,int r,int x) {
    u = ++ tot; tr[u] = tr[p]; ++ tr[u].val;
    if (l == r) return ;
    int mid = l + (r - l) / 2;
    if (x <= mid) modify(tr[u].l,tr[p].l,l,mid,x);
    else modify(tr[u].r,tr[p].r,mid + 1,r,x);
}

// Number of values <= x inserted after version p, up to version u (node range [l, r]).
inline long long query(int u,int p,int l,int r,int x) {
    if (l == r) return tr[u].val - tr[p].val;
    int mid = l + (r - l) / 2;
    if (x <= mid) return query(tr[u].l,tr[p].l,l,mid,x);
    return tr[tr[u].l].val - tr[tr[p].l].val + query(tr[u].r,tr[p].r,mid + 1,r,x);
}

// Counts pairs j <= k with a[j] * a[k] <= max(a[j..k]). For every i, the values strictly inside
// (le[i], ri[i]) never exceed a[i]. The shorter side of i is enumerated, and partners on the
// other side with a value <= a[i] / a[j] are counted through the persistent tree.
il void solve() {
    //------------code------------
    read(n);
    rep(i,1,n) {
        read(a[i]);
        modify(rt[i],rt[i - 1],1,VAL,a[i]);
    }
    // ri[i]: first position to the right with a value strictly greater than a[i] (n + 1 if none)
    rep(i,1,n) {
        while (top && a[st[top]] < a[i]) ri[st[top --]] = i;
        st[++ top] = i; ri[i] = n + 1;
    }
    // le[i]: first position to the left with a value >= a[i] (stays 0 if none)
    top = 0;
    rep1(i,n,1) {
        while (top && a[st[top]] <= a[i]) le[st[top --]] = i;
        st[++ top] = i;
    }
    ll ret = 0;
    rep(i,1,n) {
        if (i - le[i] < ri[i] - i) {
            // left side is shorter: j in (le[i], i], partners k in [i, ri[i])
            rep(j,le[i] + 1,i) ret += query(rt[ri[i] - 1],rt[i - 1],1,VAL,a[i] / a[j]);
        } else {
            // right side is shorter: j in [i, ri[i]), partners k in (le[i], i]
            rep(j,i,ri[i] - 1) ret += query(rt[i],rt[le[i]],1,VAL,a[i] / a[j]);
        }
    }
    write(ret,'\n');
    // cerr << "Time : " << (db)(end - start) / CLOCKS_PER_SEC << " s" << endl;
    return ;
}

// Per-run initialisation hook; this problem needs none.
il void init() {}

// Entry point: one test case (reading a test count is left commented out).
signed main() {
    // init();
    int cases = 1;
    // read(cases);
    for (; cases > 0; -- cases) solve();
    return 0;
}
posted @ 2024-08-16 17:18  songszh  阅读(10)  评论(0编辑  收藏  举报