CF1261F Xor-Set
一、题目
二、解法
从计算答案的角度入手,我们不能逐个数地考虑它们能否构造出来,但是以防算重我们需要以异或结果的数为主体来考虑,建议给出的数是区间的形式,那么我们考虑一段一段地考虑数。
具体来说我们需要利用拆位的思想,我们将给定的区间分解成 \([k\cdot 2^y,(k+1)\cdot 2^y)\) 的形式,也就是后面 \(y\) 位是从 \(00..00\) 到 \(11..11\) 的全排列,那么任意两个拆位后的区间异或之后得到的结果还是区间,因为前面的位都是直接异或的结果,后面的位得到的是全排列。
可以把原区间丢到权值线段树上完成拆分的操作,但是枚举两两配对会得到 \(O(n^2\log^2 a)\) 个区间,再对这些区间求并显然复杂度爆炸,需要进一步优化。
观察到上面的求并是有很多无效操作的,也就是对于两个长度分别为 \(2^A,2^B(A>B)\) 的区间,后 \(A\) 位都显示为全排列,所以很多 \(B\) 不同的区间也会有相同的效果。那么我们把长度为 \(B\) 的扩充成 \(A\),尝试把多次合并成一次做。
扩充长度对应到线段树上就是跳父亲,那么优化考虑只做线段树上深度相同的区间,也就是对于第一个序列我们把区间放在线段树的终止节点上,第二个序列则把区间放在线段树经过的所有节点上,先做一遍再对换过来即可。
时间复杂度 \(O(n^2\log a)\)
//My heart travels at the speed of light
#include <cstdio>
#include <vector>
#include <iostream>
#include <algorithm>
using namespace std;
const int M = 100005;
#define pb push_back
#define int long long
const int inf = (1ll<<60)-1;
const int MOD = 998244353;
const int inv2 = (MOD+1)/2;
int read()
{
int x=0,f=1;char c;
while((c=getchar())<'0' || c>'9') {if(c=='-') f=-1;}
while(c>='0' && c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
return x*f;
}
int n,m,ans,cnt,rt,ls[M],rs[M];
struct node
{
int l,r;
bool operator < (const node &b) const
{
if(l==b.l) return r<b.r;
return l<b.l;
}
}a[M],b[M];
vector<node> p[70],q[70],v;
void add1(int &x,int l,int r,int L,int R,int d)
{
if(!x) x=++cnt,p[d].pb(node{l,r});
if(L<=l && r<=R) return ;
int mid=(l+r)>>1;
if(L<=mid) add1(ls[x],l,mid,L,R,d+1);
if(R>mid) add1(rs[x],mid+1,r,L,R,d+1);
}
void add2(int l,int r,int L,int R,int d)
{
if(L<=l && r<=R) {q[d].pb(node{l,r});return ;}
int mid=(l+r)>>1;
if(L<=mid) add2(l,mid,L,R,d+1);
if(R>mid) add2(mid+1,r,L,R,d+1);
}
void solve()
{
for(int i=1;i<=cnt;i++) ls[i]=rs[i]=0;
for(int i=0;i<=60;i++) p[i].clear(),q[i].clear();
cnt=rt=0;
for(int i=1;i<=n;i++) add1(rt,0,inf,a[i].l,a[i].r,0);
for(int i=1;i<=m;i++) add2(0,inf,b[i].l,b[i].r,0);
for(int w=0;w<=60;w++)
for(auto x:p[w]) for(auto y:q[w])
{
int low=y.l^y.r;
v.pb(node{(x.l^y.l)&(~low),(x.l^y.l)|low});
}
}
int calc(int l,int r)
{
return ((l+r)%MOD)*((r-l+1)%MOD)%MOD*inv2%MOD;
}
signed main()
{
n=read();
for(int i=1;i<=n;i++)
a[i].l=read(),a[i].r=read();
m=read();
for(int i=1;i<=m;i++)
b[i].l=read(),b[i].r=read();
solve();swap(n,m);swap(a,b);solve();
if(v.empty()) {puts("0");return 0;}
sort(v.begin(),v.end());int l=0,r=0;
for(auto x:v)
{
if(x.l<=r+1) r=max(r,x.r);
else ans=(ans+calc(l,r))%MOD,l=x.l,r=x.r;
}
ans=(ans+calc(l,r))%MOD;
printf("%lld\n",ans);
}