CF798E 题解
Solution
挺有意思的,我们不难想到,我们可以通过 \(a_{1,2,..,n}\) 建立起大小关系,为了方便,我们从小向大连,然后通过 topo 序来确定 \(p_{1,2,...,n}\)。连边的话显然有 \(\Theta(n^2)\) 的做法,即 \(x:1\to n\),每一次如果有 \(a_x\),那么 \([1,a_x)\) 中除了未被标记的点都连向 \(x\),\(x\) 连向 \(a_x\)。
我们定义 \(b_{a_x}=x\),那么注意到,如果我们每次能够快速找到未经过的反边,然后暴力删除,总复杂度就是对的,而我们注意到 \(x\) 会连出去的反边只有 \((x,b_x)\) 以及 \((x,i)\) 使得 \(1\le i<a_x\wedge b_i>x\) 的 \(i\),所以我们可以维护区间 \(b_{i}\) 最大值的位置,然后暴力删除就好了。
复杂度 \(\Theta(n\log n)\)。
Code
#include <bits/stdc++.h>
using namespace std;
#define Int register int
#define MAXN 500005
//char buf[1<<21],*p1=buf,*p2=buf;
//#define getchar() (p1==p2 && (p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
template <typename T> void read (T &x){char c = getchar ();x = 0;int f = 1;while (c < '0' || c > '9') f = (c == '-' ? -1 : 1),c = getchar ();while (c >= '0' && c <= '9') x = x * 10 + c - '0',c = getchar ();x *= f;}
template <typename T,typename ... Args> void read (T &x,Args& ... args){read (x),read (args...);}
template <typename T> void write (T x){if (x < 0) x = -x,putchar ('-');if (x > 9) write (x / 10);putchar (x % 10 + '0');}
template <typename T> void chkmax (T &a,T b){a = max (a,b);}
int n,cnt,p[MAXN],a[MAXN],b[MAXN];
struct node{
int maxv,pos;
node operator + (const node &p)const{
if (maxv > p.maxv) return *this;
else return p;
}
};
struct Segment{
node sum[MAXN << 2];
void build (int x,int l,int r){
if (l == r) return sum[x] = node{b[l],l},void ();
int mid = l + r >> 1;
build (x << 1,l,mid),build (x << 1 | 1,mid + 1,r),sum[x] = sum[x << 1] + sum[x << 1 | 1];
}
node findit (int x,int l,int r,int ql,int qr){
if (l >= ql && r <= qr) return sum[x];
int mid = l + r >> 1;
if (qr <= mid) return findit (x << 1,l,mid,ql,qr);
else if (ql > mid) return findit (x << 1 | 1,mid + 1,r,ql,qr);
else return findit (x << 1,l,mid,ql,qr) + findit (x << 1 | 1,mid + 1,r,ql,qr);
}
void del (int x,int l,int r,int pos){
if (l == r) return b[l] = 0,sum[x] = node{0,l},void ();
int mid = l + r >> 1;
if (pos <= mid) del (x << 1,l,mid,pos);
else del (x << 1 | 1,mid + 1,r,pos);
sum[x] = sum[x << 1] + sum[x << 1 | 1];
}
}tree;
void dfs (int x){
int k = b[x];
tree.del (1,1,n,x);
if (k != n + 1 && b[k]) dfs (k);
while (1){
node it = tree.findit (1,1,n,1,a[x]);
if (it.maxv <= x) break;
dfs (it.pos);
}
p[x] = ++ cnt;
}
signed main(){
read (n);
for (Int i = 1;i <= n;++ i){
read (a[i]);
if (~a[i]) b[a[i]] = i;
else a[i] = n + 1;
}
for (Int i = 1;i <= n;++ i) if (!b[i]) b[i] = n + 1;
tree.build (1,1,n);
for (Int i = 1;i <= n;++ i) if (!p[i]) dfs (i);
for (Int i = 1;i <= n;++ i) write (p[i]),putchar (' ');putchar ('\n');
return 0;
}