题解 第 k 大查询
看起来需要求出每个数在哪些区间内是第 \(k\) 大
不会求,就炸了
- 关于形如「求序列内所有数左/右边」前 \(k\) 个大于/小于这个数的数:
需要避免从一个数向两边爆扫的时候扫到小于这个数的数
于是对这个序列建立一个双向链表,将整个序列复制下来排序,从小到大枚举
枚举到一个数时,向两边爆扫 \(k\) 个数,然后将这个数删除
每次剩下的数一定是大于当前枚举到的数的,所以复杂度是 \(O(nk)\)
于是求出每个数左右 \(k\) 个大于这个数的数
然后双指针,保证两个指针之间只有 \(k-1\) 个数,然后统计区间个数
所以整体复杂度 \(O(nk)\)
Code:
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 500010
#define ll long long
#define fir first
#define sec second
#define make make_pair
//#define int long long
char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
inline int read() {
int ans=0, f=1; char c=getchar();
while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
return ans*f;
}
int n, k;
ll a[N];
namespace task1{
ll ans;
int rot[N], ls[N*16], rs[N*16], cnt[N*16], tot;
#define ls(p) ls[p]
#define rs(p) rs[p]
#define cnt(p) cnt[p]
#define pushup(p) cnt(p)=cnt(ls(p))+cnt(rs(p))
void upd(int& p1, int p2, int tl, int tr, int pos) {
if (!p1) p1=++tot;
if (tl==tr) {cnt(p1)=cnt(p2)+1; return ;}
int mid=(tl+tr)>>1;
if (pos<=mid) upd(ls(p1), ls(p2), tl, mid, pos), rs(p1)=rs(p2);
else upd(rs(p1), rs(p2), mid+1, tr, pos), ls(p1)=ls(p2);
pushup(p1);
}
int query(int p1, int p2, int tl, int tr, int k) {
// cout<<"query: "<<tl<<' '<<tr<<' '<<cnt(p1)<<' '<<cnt(p2)<<' '<<k<<endl;
if (tl==tr) return tl;
int mid=(tl+tr)>>1;
if (cnt(rs(p1))-cnt(rs(p2))>=k) return query(rs(p1), rs(p2), mid+1, tr, k);
else return query(ls(p1), ls(p2), tl, mid, k-(cnt(rs(p1))-cnt(rs(p2))));
}
void solve() {
// cout<<double(sizeof(rs)*3+sizeof(a)*2+sizeof(rot)*4)/1000/1000<<endl;
for (int i=1; i<=n; ++i) upd(rot[i], rot[i-1], 1, n, a[i]); //, cout<<"i: "<<i<<' '<<a[i]<<endl;
for (int i=1; i<=n; ++i) for (int j=i; j<=n; ++j) if (j-i+1>=k) ans+=query(rot[j], rot[i-1], 1, n, k);
printf("%lld\n", ans);
exit(0);
}
}
namespace task2{
ll siz[N], ans;
int ls[N], rs[N], top;
pair<int, int> sta[N];
void dfs(int u) {
siz[u]=1;
if (ls[u]) dfs(ls[u]), siz[u]+=siz[ls[u]];
if (rs[u]) dfs(rs[u]), siz[u]+=siz[rs[u]];
ans+=a[u];
if (ls[u]) ans+=a[u]*siz[ls[u]];
if (rs[u]) ans+=a[u]*siz[rs[u]];
if (ls[u]&&rs[u]) ans+=a[u]*siz[ls[u]]*siz[rs[u]];
}
void solve() {
for (int i=1; i<=n; ++i) {
int k=top;
while (k && sta[k].sec<a[i]) --k;
if (k) rs[sta[k].fir]=i;
if (k<top) ls[i]=sta[k+1].fir;
sta[++k]=make(i, a[i]);
top=k;
}
dfs(sta[1].fir);
printf("%lld\n", ans);
exit(0);
}
}
namespace task{
ll ans;
int pos[N], left[N], right[N], now, ltop, rtop;
struct node{int pre, nxt, val, id; node():pre(-1),nxt(-1){}}e[N];
void solve() {
e[0].nxt=1; e[0].val=INF; e[0].id=0;
e[n+1].pre=n; e[n+1].val=INF; e[n+1].id=n+1;
for (int i=1; i<=n; ++i) {e[i].pre=i-1; e[i].nxt=i+1; e[i].val=a[i]; e[i].id=i; pos[a[i]]=i;}
for (int i=1; i<=n; ++i) {
// cout<<"i: "<<i<<endl;
rtop=0; now=pos[i]; right[0]=pos[i];
for (int j=1; j<=k&&~e[now].nxt; ++j,now=e[now].nxt) right[++rtop]=e[e[now].nxt].id;
ltop=0; now=pos[i]; left[0]=pos[i];
for (int j=1; j<=k&&~e[now].pre; ++j,now=e[now].pre) left[++ltop]=e[e[now].pre].id;
// cout<<"top: "<<ltop<<' '<<rtop<<endl;
// cout<<"left: "; for (int i=0; i<=ltop; ++i) cout<<left[i]<<' '; cout<<endl;
// cout<<"right: "; for (int i=0; i<=ltop; ++i) cout<<right[i]<<' '; cout<<endl;
if (ltop-1+rtop-1>=k-1) {
int pos1=ltop, pos2=k-(ltop-1);
while (pos1>0 && pos2<=rtop) {
ans+=1ll*i*(left[pos1-1]-left[pos1])*(right[pos2]-right[pos2-1]);
// cout<<"+="<<i*(left[pos1-1]-left[pos1])*(right[pos2]-right[pos2-1])<<endl;
--pos1; ++pos2;
}
e[e[pos[i]].pre].nxt=e[e[pos[i]].nxt].id;
e[e[pos[i]].nxt].pre=e[e[pos[i]].pre].id;
}
}
printf("%lld\n", ans);
exit(0);
}
}
signed main()
{
freopen("kth.in", "r", stdin);
freopen("kth.out", "w", stdout);
n=read(); k=read();
for (int i=1; i<=n; ++i) a[i]=read();
// if (k==1) task2::solve();
// else task1::solve();
task::solve();
return 0;
}