「THUPC 2022 初赛」搬砖 nsqrtn 做法
抽象一下问题,设一个变量 \(x\),初始为 \(0\),每次令 \(x\gets x+d\),并同时将 \(d\) 减去砖数在 \((x,x+d]\) 内的砖的代价,如果没有砖就停止,询问一共可以跳几步。
定义:"零砖" 为代价为 \(0\) 的砖,"非零砖" 为代价 \(>0\) 的砖。
若 \(d>\sqrt n\),则最多跳 \(\sqrt n\) 步,所以可以这一部分可以暴力跳,每次要查询一个区间和,用 \(\mathcal{O}(\sqrt n)\) 单点修改,\(\mathcal{O}(1)\) 区间查询的分块维护即可,这一部分总复杂度就被平衡到了 \(\mathcal{O}(n\sqrt n)\)。
若 \(d\leq \sqrt n\),则 \(d\) 最多减少 \(\sqrt n\) 次,若每次能 \(\mathcal{O}(x)\) 定位 \(d\) 在哪里减少了(或者跳了一个空的区间导致结束),然后在那里暴力跳,这一部分复杂度就是 \(\mathcal{O}(x\sqrt n)\).
假设我们现在以 \(d\) 为块长,与 \((x+1)\) 模 \(d\) 意义下同余的位置作为块起始端点,进行分块。
定位到 \(d\) 在哪里减少了,相当于查询当前位置后面第一个非零砖的位置,然后算出他所在的块即可。这一部分相当于 01 序列,单点置 1,查询一个位置后面第一个 1,由于修改有 \(\mathcal{O}(n)\) 次,查询有 \(\mathcal{O}(n\sqrt n)\) 次,可以用分块来平衡复杂度做到 \(\mathcal{O}(\sqrt n)\) 修改 \(\mathcal{O}(1)\) 查询。
看是否会跳一个空区间导致结束,相当于看当前这种分块方案下后面第一个不合法的块(合法定义为块里面有砖)是哪一个。注意到是按照块的端点模 \(d\) 的余数分类的,所以所有分块方案的总块数是 \(\sum d\times \frac{n}{d}=\mathcal{O}(n\sqrt n)\).
那么在每一次加入砖的时候,设 \(t\) 为新变成合法块的块数,若能在 \(\mathcal{O}(t)\) 的复杂度内找到遍历所有的新合法块,那么要解决的问题就是 01 序列,单点置 1,查询一个位置后面第一个 0(也就是一个块新变成合法,或者查询一个位置后面第一个不合法块),这个可以并查集维护。
如何 \(\mathcal{O}(t)\) 找到所有新合法块?在每次新加入一个砖的时候,找到它前面第一个砖,后面第一个砖,然后枚举块长 \(d\),就很容易枚举出来了。
综上所述,总复杂度为 \(\mathcal{O}(n\sqrt n\alpha(n))\),采用线性树上并查集即可做到时间复杂度 \(\mathcal{O}(n\sqrt n)\).
代码能力太弱了,我的常数有点大,好像跑得不是很快...
//O(n sqrtn a(n))
#include<cstdio>
#include<vector>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<set>
#define pb emplace_back
#define mp std::make_pair
#define fi first
#define se second
#define int long long
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int>pii;
typedef pair<ll,int>pli;
typedef pair<ll,ll>pll;
typedef vector<int>vi;
typedef vector<ll>vll;
typedef vector<pii>vpii;
template<typename T>T cmax(T &x, T y){return x=x>y?x:y;}
template<typename T>T cmin(T &x, T y){return x=x<y?x:y;}
template<typename T>
T &read(T &r){
r=0;bool w=0;char ch=getchar();
while(ch<'0'||ch>'9')w=ch=='-'?1:0,ch=getchar();
while(ch>='0'&&ch<='9')r=r*10+(ch^48),ch=getchar();
return r=w?-r:r;
}
template<typename T1,typename... T2>
void read(T1 &x, T2& ...y){ read(x); read(y...); }
const int N=100010;
const int K=100000;
const int B=500;
const int M=510;
inline int lowbit(int x){return x&(-x);}
struct Block{
ll sum[N],suf[N];
void modify(int x,int v){
int b=(x-1)/B;
for(int i=b;i<=(K-1)/B;i++)sum[i]+=v;
for(int i=x;i>b*B;i--)suf[i]+=v;
}
ll sumq(int x){
return x<=0 ? 0 : sum[(x-1)/B]-(x%B==0?0:suf[x+1]);
}
ll query(int x,int y){
cmin(y,K);
if(x>y)return 0;
return sumq(y)-sumq(x-1);
}
}tr1,tr2;
int one[N],zero[N];
int tag[N],va[N];
void ins(int x){
int b=(x-1)/B;
for(int i=b-1;~i;--i)cmin(tag[i],x);
for(int i=x;i>b*B;i--)cmin(va[i],x);
}
int que(int x){
return min(va[x],tag[(x-1)/B]);
}
set<int>zs,rzs;
vi f[M][M];
int find(int p,int q,int x){
return f[p][q][x]!=x?f[p][q][x]=find(p,q,f[p][q][x]):x;
}
void merge(int p,int q,int x,int y){
int fx=find(p,q,x),fy=find(p,q,y);
f[p][q][fx]=fy;
}
int calc(int x,int y){
int len;
if(!y)len=K-(x-1);
else len=K-(y-1);
return (len+x-1)/x;
}
int calc2(int pos,int x,int y){
if(!y)pos-=x-1;
else pos-=y-1;
return (pos-1)/x+1;
}
int solve(int d){
int x=0;
int ans=1;
while(d>B&&x<=K){
if(!tr2.query(x+1,x+d))return ans;
ll tmp=tr1.query(x+1,x+d);
x+=d;
++ans;
if(d<=tmp)return ans-1;
d-=tmp;
}
if(x>K)return ans;
while(x<=K){
int pos1=que(x+1);
int xd=(x+1)%d;
int now=calc2(x+1,d,xd);
int nex=find(d,xd,now);
nex=min(nex,calc2(pos1,d,xd));
ans+=nex-now;
x+=(nex-now)*d;
//
if(!tr2.query(x+1,x+d))return ans;
ll tmp=tr1.query(x+1,x+d);
x+=d;
++ans;
if(d<=tmp)return ans-1;
d-=tmp;
}
return ans;
}
signed main(){
for(int i=0;i<=K+1;i++)tag[i]=va[i]=K+1;
zs.insert(K+1);
rzs.insert(K+1);
for(int i=1;i<=B;i++){
for(int j=0;j<i;j++){
int len=calc(i,j);
f[i][j].resize(len+2);
for(int k=1;k<=len+1;k++)
f[i][j][k]=k;
}
}
int T;read(T);while(T--){
int op;read(op);
if(op==1){
int x,y;read(x,y);
if(!one[x]&&y){
one[x]=1;
ins(x);
}
tr1.modify(x,y);
tr2.modify(x,1);
if(!y){
if(zero[x])continue;
zero[x]=1;
int ne=*zs.lower_bound(x);
int la=K-(*rzs.lower_bound(K-x));
if(la<0)la=0;
zs.insert(x);
rzs.insert(K-x);
for(int d=1;d<=B;d++){
int p=max(la+1,x-d+1);
for(int i=p;i<=x&&i+d<=ne;i++){
int tmp=calc2(i,d,i%d);
merge(d,i%d,tmp,tmp+1);
}
}
}
}
else{
int d;read(d);
cout << solve(d) << '\n';
}
}
return 0;
}