HDU6964 I Love Counting(2021HDU多校第二场1004)(平衡树/树状数组+二维数点+字典树)
题意:
给出一个序列。
每次询问一个区间内有多少个不同的数异或a<=b。
题解:
首先有个前置知识,就是不带区间的情况下有多少个不同的数异或a<=b,这是一个经典的字典树上DP的模型,找到对应的子树统计信息即可,这里不再赘述。
然后考虑区间,如果把不同的数这个条件去掉,可以直接上可持久化字典树。
但是任何一个可持久化数据结构都无法处理不同这个条件。
做法一
比赛的时候想了一个莫队套字典树的做法,就是在莫队的过程中维护一颗字典树,这个思路比较好想,时间复杂度\(O(nlogn\sqrt{n})\)
比赛中居然有人用这个时间复杂度卡过去了?
但是也有一种莫队好像可以把log去掉,不得不说是真的nb
做法二
对字典树上的每个节点维护子树内所有数的前驱。这里我用的Splay树维护每个节点的前驱集合。
然后从左往右更新数组,先把与当前元素有关的所有节点的Splay更新,然后处理以当前下标为右端点的询问,与每个询问相关的子树数量是log级的,对这些子树的Splay查询比询问的左端点大的前驱数量。
求和就是答案,这样搞时间复杂度是\(O(nlognlogn)\)的,但是对每个节点维护一颗Splay,好像复杂度并不能均摊,在HDU上稳T,在luogu上跑1.67s。
做法三
在做法二的基础上用空间换时间。对每个节点维护两个链表,一个表示与这个节点相关的所有数和它们的位置,一个表示与这个节点相关的询问。
然后遍历所有节点,先把所有节点的询问按右端点从小到大排序,然后从左往右遍历询问,同时在节点的元素链表里维护一个指针,每次把出现位置小于等于当前询问右端点的前驱全部更新到Splay树里,然后在Splay上询问比左端点大的数的数量。
这里由于不断要对Splay树做插入删除的操作,导致内存爆炸,还要手写一个垃圾回收。
时间复杂度\(O(nlognlogn)\),但是只用一颗Splay树,常数得到进一步优化,在HDU上1.9s AC。
做法四
在做法二的基础上,用树状数组代替Splay树,树状数组常数是真的很小,问题得以解决。在洛谷上1.25s,在HDU上1.4s。这好像也是std的做法。
总结
这道题没用什么很高级的思想,就是不断的通过一些小技巧优化常数,真的学到许多,是一道很好的数据结构练习题。
代码
这里贴上做法三和做法四的代码
做法三:
#include<bits/stdc++.h>
using namespace std;
const int maxn=1e5+100;
const int M=1e5*30;
int fa[M],ch[M][2],val[M],cnt[M],sz[M],tot,hs[M],ts;
struct Splay {
int rt=0;
void maintain (int x) {
sz[x]=sz[ch[x][0]]+sz[ch[x][1]]+cnt[x];
}
bool get (int x) {
return x==ch[fa[x]][1];
}
void clear (int x) {
ch[x][0]=ch[x][1]=fa[x]=val[x]=sz[x]=cnt[x]=0;
}
void rotate (int x) {
int y=fa[x];
int z=fa[y];
int chk=get(x);
ch[y][chk]=ch[x][chk^1];
if (ch[x][chk^1]) {
fa[ch[x][chk^1]]=y;
}
ch[x][chk^1]=y;
fa[y]=x;
fa[x]=z;
if (z) {
ch[z][y==ch[z][1]]=x;
}
maintain(x);
maintain(y);
}
void splay (int x) {
for (int f=fa[x];f=fa[x];rotate(x)) {
if (fa[f]) {
rotate(get(x)==get(f)?f:x);
}
}
rt=x;
}
void ins (int k) {
if (!rt) {
if (ts==0) {
val[++tot]=k;
cnt[tot]++;
rt=tot;
}
else {
val[hs[ts]]=k;
cnt[hs[ts]]++;
rt=hs[ts];
ts--;
}
maintain(rt);
return;
}
int cur=rt,f=0;
while (1) {
if (val[cur]==k) {
cnt[cur]++;
maintain(cur);
maintain(f);
splay(cur);
break;
}
f=cur;
cur=ch[cur][val[cur]<k];
if (!cur) {
if (ts==0) {
val[++tot]=k;
cnt[tot]++;
fa[tot]=f;
ch[f][val[f]<k]=tot;
maintain(tot);
maintain(f);
splay(tot);
}
else {
val[hs[ts]]=k;
cnt[hs[ts]]++;
fa[hs[ts]]=f;
ch[f][val[f]<k]=tot;
maintain(hs[ts]);
maintain(f);
splay(hs[ts]);
ts--;
}
break;
}
}
}
int rk (int k) {
int res=0;
int cur=rt;
while (1) {
if (!cur) {
return res+1;
}
if (k<val[cur]) {
cur=ch[cur][0];
}
else
{
res+=sz[ch[cur][0]];
if (k==val[cur]) {
splay(cur);
return res+1;
}
res+=cnt[cur];
cur=ch[cur][1];
}
}
}
int pre () {
int cur=ch[rt][0];
if (!cur) return cur;
while (ch[cur][1]) {
cur=ch[cur][1];
}
splay(cur);
return cur;
}
void del (int k) {
rk(k);
if (cnt[rt]>1) {
cnt[rt]--;
maintain(rt);
return;
}
if (!ch[rt][0]&&!ch[rt][1]) {
clear(rt);
rt=0;
return;
}
if (!ch[rt][0]) {
int cur=rt;
rt=ch[rt][1];
fa[rt]=0;
clear(cur);
return;
}
if (!ch[rt][1]) {
int cur=rt;
rt=ch[rt][0];
fa[rt]=0;
clear(cur);
return;
}
int cur=rt;
int x=pre();
fa[ch[cur][1]]=x;
ch[x][1]=ch[cur][1];
clear(cur);
maintain(rt);
}
};
Splay * splay;
int tr[M][2],tol;
vector<int> g[maxn];//数字i的二进制形式
void zh (int x) {
if (g[x].size()) return;
int u=x;
while (x) {
g[u].push_back(x%2);
x/=2;
}
while (g[u].size()<17) g[u].push_back(0);
int uu=0;
for (int i=16;i>=0;i--) {
if (!tr[uu][g[u][i]]) tr[uu][g[u][i]]=++tol;
uu=tr[uu][g[u][i]];
}
}
int Pre[maxn];//保存每个数的前驱
int n,a[maxn];
struct qnode {
int id,a,b,l,r;
bool operator < (const qnode &x) const {
return r<x.r;
}
};
int ans[maxn];
vector<pair<int,int> > ys[M];//对每个节点维护一个元素数组
vector<qnode> xy[M];//对每个节点维护一个询问数组
void insert (int u,int x,int dep,int i) {
if (u) {
ys[u].push_back(make_pair(a[i],i));
}
if (dep<0) return;
insert(tr[u][g[x][dep]],x,dep-1,i);
}
void query (int u,int a,int b,int dep,int l,int r,int id) {
//在字典树上找到对应的子树
if (dep<0) {
//return splay[u].rk(r+1)-1;
xy[u].push_back({id,a,b,l,r});//这个询问涉及到的节点
}
if (g[a][dep]==1&&g[b][dep]==0) {
if (tr[u][1])query(tr[u][1],a,b,dep-1,l,r,id);
}
else if (g[a][dep]==1&&g[b][dep]==1) {
if (tr[u][1]) xy[tr[u][1]].push_back({id,a,b,l,r});
if (tr[u][0])query(tr[u][0],a,b,dep-1,l,r,id);
}
else if (g[a][dep]==0&&g[b][dep]==1) {
if (tr[u][0]) xy[tr[u][0]].push_back({id,a,b,l,r});
if (tr[u][1])query(tr[u][1],a,b,dep-1,l,r,id);
}
else if (g[a][dep]==0&&g[b][dep]==0) {
if (tr[u][0])query(tr[u][0],a,b,dep-1,l,r,id);
}
}
inline int read()
{
int x=0,f=1;char c=getchar();
while(c<'0'||c>'9') {if(c=='-') f=-1;c=getchar();}
while (c>='0'&&c<='9') x=(x<<3)+(x<<1)+(c^48),c=getchar();
return x*f;
}
int main () {
//100000*20*20*10
splay=new Splay();
n=read();
for (int i=1;i<=n;i++) a[i]=read();
for (int i=1;i<=n;i++) {
zh(a[i]);
}
int m;
m=read();
for (int i=1;i<=m;i++) {
int l,r,A,B;
//scanf("%d%d%d%d",&l,&r,&A,&B);
l=read();r=read();A=read();B=read();
zh(A);
zh(B);
query(0,A,B,16,l,r,i);
}
for (int i=1;i<=n;i++) {
insert(0,a[i],16,i);
}
for (int i=1;i<=tol;i++) {
//遍历每个节点
sort(xy[i].begin(),xy[i].end());
int l=0;
for (qnode it:xy[i]) {
while (l<ys[i].size()&&ys[i][l].second<=it.r) {
int x=ys[i][l].first;
if (Pre[x]) splay->del(Pre[x]);
Pre[x]=ys[i][l].second;
splay->ins(Pre[x]);
l++;
}
ans[it.id]+=splay->rk(1e9)-splay->rk(it.l);
}
for (pair<int,int> it:ys[i]) splay->del(Pre[it.first]),Pre[it.first]=0;
}
for (int i=1;i<=m;i++) printf("%d\n",ans[i]);
}
做法四:
#include<bits/stdc++.h>
using namespace std;
const int maxn=1e5+100;
const int M=1e5*30;
int c[maxn];
int lowbit (int x) {
return x&-x;
}
void up (int p,int v) {
for (int i=p;i<maxn;i+=lowbit(i)) c[i]+=v;
}
int getsum (int p) {
int ans=0;
for (int i=p;i;i-=lowbit(i)) ans+=c[i];
return ans;
}
int tr[M][2],tol;
vector<int> g[maxn];//数字i的二进制形式
void zh (int x) {
if (g[x].size()) return;
int u=x;
while (x) {
g[u].push_back(x%2);
x/=2;
}
while (g[u].size()<17) g[u].push_back(0);
int uu=0;
for (int i=16;i>=0;i--) {
if (!tr[uu][g[u][i]]) tr[uu][g[u][i]]=++tol;
uu=tr[uu][g[u][i]];
}
}
int Pre[maxn];//保存每个数的前驱
int n,a[maxn];
struct qnode {
int id,a,b,l,r;
bool operator < (const qnode &x) const {
return r<x.r;
}
};
int ans[maxn];
vector<pair<int,int> > ys[M];//对每个节点维护一个元素数组
vector<qnode> xy[M];//对每个节点维护一个询问数组
void insert (int u,int x,int dep,int i) {
if (u) {
ys[u].push_back(make_pair(a[i],i));
}
if (dep<0) return;
insert(tr[u][g[x][dep]],x,dep-1,i);
}
void query (int u,int a,int b,int dep,int l,int r,int id) {
//在字典树上找到对应的子树
if (dep<0) {
//return splay[u].rk(r+1)-1;
xy[u].push_back({id,a,b,l,r});//这个询问涉及到的节点
}
if (g[a][dep]==1&&g[b][dep]==0) {
if (tr[u][1])query(tr[u][1],a,b,dep-1,l,r,id);
}
else if (g[a][dep]==1&&g[b][dep]==1) {
if (tr[u][1]) xy[tr[u][1]].push_back({id,a,b,l,r});
if (tr[u][0])query(tr[u][0],a,b,dep-1,l,r,id);
}
else if (g[a][dep]==0&&g[b][dep]==1) {
if (tr[u][0]) xy[tr[u][0]].push_back({id,a,b,l,r});
if (tr[u][1])query(tr[u][1],a,b,dep-1,l,r,id);
}
else if (g[a][dep]==0&&g[b][dep]==0) {
if (tr[u][0])query(tr[u][0],a,b,dep-1,l,r,id);
}
}
inline int read()
{
int x=0,f=1;char c=getchar();
while(c<'0'||c>'9') {if(c=='-') f=-1;c=getchar();}
while (c>='0'&&c<='9') x=(x<<3)+(x<<1)+(c^48),c=getchar();
return x*f;
}
int main () {
//100000*20*20*10
n=read();
for (int i=1;i<=n;i++) a[i]=read();
for (int i=1;i<=n;i++) {
zh(a[i]);
}
int m;
m=read();
for (int i=1;i<=m;i++) {
int l,r,A,B;
//scanf("%d%d%d%d",&l,&r,&A,&B);
l=read();r=read();A=read();B=read();
zh(A);
zh(B);
query(0,A,B,16,l,r,i);
}
for (int i=1;i<=n;i++) {
insert(0,a[i],16,i);
}
for (int i=1;i<=tol;i++) {
//遍历每个节点
sort(xy[i].begin(),xy[i].end());
int l=0;
for (qnode it:xy[i]) {
while (l<ys[i].size()&&ys[i][l].second<=it.r) {
int x=ys[i][l].first;
if (Pre[x]) up(Pre[x],-1);
Pre[x]=ys[i][l].second;
up(Pre[x],1);
l++;
}
ans[it.id]+=getsum(n+1)-getsum(it.l-1);
}
for (pair<int,int> it:ys[i]) if (Pre[it.first])up(Pre[it.first],-1),Pre[it.first]=0;
}
for (int i=1;i<=m;i++) printf("%d\n",ans[i]);
}