题解 铃原露露
赛时死在了一个奇怪的地方
首先有一个枚举 \(x, y\),用 \(z\) ban 掉不合法方案的做法
这样是 \(O(n^2)\) 的
然后优化:可以枚举 \(z\),用分属两个不同子树的点对 \((x, y)\) 更新答案
这个分属两个不同子树看着枚举最大的子树就很多余
然而我只想到枚举一个 \(x\),在其余子树中查后继
这样仍然无法避免对最大子树的遍历,于是寄了
然后看题解,题解说不只要查后继,还要查前驱,于是最大子树不用遍历了
这部分 set 实现就 \(O(n\log^2 n)\) 了
然后是怎么扫描线
发现是区间加,区间历史 0 的个数
可以转化为区间加,区间历史最值及个数
- 关于区间加,区间历史最值及个数:
发现覆盖了整个区间的加法是不影响区间内最小值的个数及取到的位置的
于是打一个加入历史的标记,仅在子区间取到最小值时下传即可
复杂度 \(O(n\log^2n)\)
点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 200010
#define fir first
#define sec second
#define pb push_back
#define ll long long
//#define int long long
char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
inline int read() {
int ans=0, f=1; char c=getchar();
while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
return ans*f;
}
int n, m;
vector<int> to[N];
int a[N], id[N], head[N], dep[N], lg[N], fa[21][N], ecnt;
struct edge{int to, next;}e[N<<1];
inline void add(int s, int t) {e[++ecnt]={t, head[s]}; head[s]=ecnt;}
void dfs(int u) {
for (int i=1; i<21; ++i)
if (dep[u]>=1<<i) fa[i][u]=fa[i-1][fa[i-1][u]];
else break;
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
dep[v]=dep[u]+1;
fa[0][v]=u;
dfs(v);
}
}
int lca(int a, int b) {
if (dep[a]<dep[b]) swap(a, b);
while (dep[a]>dep[b]) a=fa[lg[dep[a]-dep[b]]-1][a];
if (a==b) return a;
for (int i=lg[dep[a]]-1; ~i; --i)
if (fa[i][a]!=fa[i][b])
a=fa[i][a], b=fa[i][b];
return fa[0][a];
}
namespace force{
bool check(int l, int r) {
for (int i=l; i<=r; ++i)
for (int j=i+1; j<=r; ++j) {
int z=lca(id[i], id[j]);
if (a[z]<l||a[z]>r) return 0;
}
return 1;
}
void solve() {
for (int i=1,l,r,ans; i<=m; ++i) {
l=read(); r=read(); ans=0;
for (int j=l; j<=r; ++j)
for (int k=j; k<=r; ++k)
ans+=check(j, k);
printf("%d\n", ans);
}
}
}
namespace task1{
int dlt[N], val[N], sum[N], ans[N], pre[N];
vector<pair<int, int>> add[N], del[N], que[N];
void solve() {
for (int i=1; i<=n; ++i) {
for (int j=i+1; j<=n; ++j) {
int z=lca(id[i], id[j]);
if (a[z]<i) {
add[i].pb({j, n});
del[a[z]+1].pb({j, n});
}
if (a[z]>j) {
add[i].pb({j, a[z]-1});
del[1].pb({j, a[z]-1});
}
}
}
for (int i=1,l,r; i<=m; ++i) {
l=read(); r=read();
que[l].pb({r, i});
}
for (int i=n; i; --i) {
for (int j=i; j<=n; ++j) dlt[j]=0;
for (auto it:add[i]) ++dlt[it.fir], --dlt[it.sec+1];
for (int j=i; j<=n; ++j) val[j]+=(dlt[j]+=dlt[j-1]);
for (int j=i; j<=n; ++j) if (!val[j]) ++sum[j];
for (int j=i; j<=n; ++j) pre[j]=pre[j-1]+sum[j];
for (auto it:que[i]) ans[it.sec]=pre[it.fir];
for (int j=i; j<=n; ++j) dlt[j]=0;
for (auto it:del[i]) --dlt[it.fir], ++dlt[it.sec+1];
for (int j=i; j<=n; ++j) val[j]+=(dlt[j]+=dlt[j-1]);
}
for (int i=1; i<=m; ++i) printf("%d\n", ans[i]);
}
}
namespace task{
set<int> s[N];
#define tl(p) tl[p]
#define tr(p) tr[p]
int dlt[N], val[N], sum[N], ans[N], pre[N];
ll his[N<<2], add_tag[N<<2], his_tag[N<<2];
int tl[N<<2], tr[N<<2], mn[N<<2], cnt[N<<2];
vector<pair<int, int>> add[N], del[N], que[N];
inline void pushup(int p) {
his[p]=his[p<<1]+his[p<<1|1];
mn[p]=min(mn[p<<1], mn[p<<1|1]); cnt[p]=0;
if (mn[p<<1]==mn[p]) cnt[p]+=cnt[p<<1];
if (mn[p<<1|1]==mn[p]) cnt[p]+=cnt[p<<1|1];
}
void spread(int p) {
mn[p<<1]+=add_tag[p]; add_tag[p<<1]+=add_tag[p];
mn[p<<1|1]+=add_tag[p]; add_tag[p<<1|1]+=add_tag[p];
add_tag[p]=0;
int minn=min(mn[p<<1], mn[p<<1|1]);
if (mn[p<<1]==minn) his[p<<1]+=cnt[p<<1]*his_tag[p], his_tag[p<<1]+=his_tag[p];
if (mn[p<<1|1]==minn) his[p<<1|1]+=cnt[p<<1|1]*his_tag[p], his_tag[p<<1|1]+=his_tag[p];
his_tag[p]=0;
}
void build(int p, int l, int r) {
tl(p)=l; tr(p)=r;
if (l==r) return void(cnt[p]=1);
int mid=(l+r)>>1;
build(p<<1, l, mid);
build(p<<1|1, mid+1, r);
pushup(p);
}
void upd(int p, int l, int r, int dat) {
if (l<=tl(p)&&r>=tr(p)) {mn[p]+=dat; add_tag[p]+=dat; return ;}
spread(p);
int mid=(tl(p)+tr(p))>>1;
if (l<=mid) upd(p<<1, l, r, dat);
if (r>mid) upd(p<<1|1, l, r, dat);
pushup(p);
}
void snapshot(int p, int l, int r) {
if (mn[p]) return ;
if (l<=tl(p)&&r>=tr(p)) {his[p]+=cnt[p]; ++his_tag[p]; return ;}
spread(p);
int mid=(tl(p)+tr(p))>>1;
if (l<=mid) snapshot(p<<1, l, r);
if (r>mid) snapshot(p<<1|1, l, r);
pushup(p);
}
ll query(int p, int l, int r) {
if (l<=tl(p)&&r>=tr(p)) return his[p];
spread(p);
int mid=(tl(p)+tr(p))>>1; ll ans=0;
if (l<=mid) ans+=query(p<<1, l, r);
if (r>mid) ans+=query(p<<1|1, l, r);
return ans;
}
void cover(int l1, int r1, int l2, int r2) {
add[r1].pb({l2, r2});
del[l1].pb({l2, r2});
}
void dfs(int u) {
// cout<<"dfs: "<<u<<endl;
if (!to[u].size()) {s[u].insert(a[u]); return ;}
for (auto v:to[u]) dfs(v);
sort(to[u].begin(), to[u].end(), [](int a, int b){return s[a].size()<s[b].size();});
// cout<<"mson: "<<to[u].back()<<endl;
swap(s[u], s[to[u].back()]);
for (auto v:to[u]) if (v!=to[u].back())
for (auto it:s[v]) s[u].insert(it);
for (auto v:to[u]) if (v!=to[u].back()) {
// cout<<"v: "<<v<<endl;
for (auto it:s[v]) s[u].erase(it);
for (auto it:s[v]) {
if (it>a[u]) {
auto tem=s[u].lower_bound(it);
if (tem!=s[u].end()) cover(a[u]+1, it, *tem, n);
if (tem!=s[u].begin() && *--tem>a[u]) cover(a[u]+1, *tem, it, n);
}
else {
auto tem=s[u].lower_bound(it);
if (tem!=s[u].end() && *tem<a[u]) cover(1, it, *tem, a[u]-1);
if (tem!=s[u].begin()) cover(1, *--tem, it, a[u]-1);
}
}
for (auto it:s[v]) s[u].insert(it);
}
s[u].insert(a[u]);
}
void solve() {
dfs(1);
for (int i=1,l,r; i<=m; ++i) {
l=read(); r=read();
que[l].pb({r, i});
}
build(1, 1, n);
for (int i=n; i; --i) {
for (auto it:add[i]) upd(1, it.fir, it.sec, 1); //, cout<<it.fir<<' '<<it.sec<<endl;
snapshot(1, i, n);
for (auto it:que[i]) ans[it.sec]=query(1, 1, it.fir);
for (auto it:del[i]) upd(1, it.fir, it.sec, -1);
}
for (int i=1; i<=m; ++i) printf("%d\n", ans[i]);
}
}
signed main()
{
freopen("b.in", "r", stdin);
freopen("b.out", "w", stdout);
n=read(); m=read();
memset(head, -1, sizeof(head));
for (int i=1; i<=n; ++i) id[a[i]=read()]=i;
for (int i=2,f; i<=n; ++i) add(f=read(), i), to[f].pb(i);
for (int i=1; i<=n; ++i) lg[i]=lg[i-1]+(1<<lg[i-1]==i);
dep[1]=1; dfs(1);
// force::solve();
// task1::solve();
task::solve();
return 0;
}