NOIP模拟 三元子序列计数
题意
给一个长度为 \(n\) 的排列 \(a\),和一个 \(3\) 的排列 \(p\)。求问 \(a\) 有多少长度为 \(3\) 的子序列,满足将其中的元素从小到大编号后为 \(p\)。
思路
仔细手玩一下会发现很难找到一个对于任意 \(p\) 的通解,实际上 \(p\) 的情况可以做一些合并:
原 \(p\) | 归约方法(对于 \(a\) 的变换) | 归约至 \(p'\) |
---|---|---|
\(1,2,3\) | 不需要归约 | \(1,2,3\) |
\(1,3,2\) | 不需要归约 | \(1,3,2\) |
\(2,1,3\) | \(a_i\) 变为 \((n+1)-a_i\),\(a\) 再倒序 | \(1,3,2\) |
\(2,3,1\) | \(a\) 变为倒序 | \(1,3,2\) |
\(3,1,2\) | \(a_i\) 变为 \((n+1)-a_i\) | \(1,3,2\) |
\(3,2,1\) | \(a\) 变为倒序 | \(1,2,3\) |
因此,问题变为了只需要针对 \((1,2,3)\) 和 \((1,3,2)\) 的情况求解。
- 求数列中大小顺序为 \((1,2,3)\) 的子序列,这是一个很典型的思路,我们遍历子序列的中心点,并用树状数组 \(O(\log n)\) 动态维护该点左右边的数,每次遍历到一个点 \(O(\log n)\) 查询该点左边比它小的和右边比他大的数有多少个,相乘即可得到以该点为中心点的满足大小顺序 \((1,2,3)\) 的子序列数。最后相加即可,整体时间复杂度 \(O(n\log n)\)。
- 求数列中大小顺序为 \((1,3,2)\) 的子序列,经尝试后会发现很难用如上遍历中心点的思想来维护,此时其实可以考虑“减法原理”,即可以先求出 \((1,2,3)\) 和 \((1,3,2)\) 的子序列数之和之后再减去 \((1,2,3)\) 的子序列数。而求 \((1,2,3)\) 和 \((1,3,2)\) 的子序列数之和即是求中间右边大于左边的三元子序列数,此时只需要遍历最左侧点,树状数组动态维护右侧比该点大的数的个数,记为 \(k\),则符合要求的三元子序列显然为 \(\frac{k(k-1)}{2}\) 个。
代码
#include<bits/stdc++.h>
using namespace std;
const int MAXN=2e5+5;
typedef long long LL;
int n,p[4],a[MAXN];
class Fenwick{
private:
int t[MAXN];
inline int lowbit(int x){
return x&(-x);
}
public:
inline void init(){
memset(t,0,sizeof(t));
}
Fenwick(){
init();
}
inline void add(int pos,int val){
while(pos<=n){
t[pos]+=val;
pos+=lowbit(pos);
}
}
inline int ask(int pos){
int res=0;
while(pos>0){
res+=t[pos];
pos-=lowbit(pos);
}
return res;
}
inline int ask_range(int l,int r){
if(l>r || r==0) return 0;
if(l==0) return ask(r);
return ask(r)-ask(l-1);
}
};
inline LL solve_123(){
LL res=0;
static Fenwick pre,sub;
for(int i=3;i<=n;i++){
sub.add(a[i],1);
}
pre.add(a[1],1);
for(int i=2;i<=n-1;i++){
res+=1LL*pre.ask_range(1,a[i]-1)*sub.ask_range(a[i]+1,n);
pre.add(a[i],1);
sub.add(a[i+1],-1);
}
return res;
}
inline LL solve_132(){
LL res=0;
static Fenwick sub;
for(int i=2;i<=n;i++){
sub.add(a[i],1);
}
for(int i=1;i<=n-1;i++){
LL ret=sub.ask_range(a[i]+1,n);
if(ret>=2) res+=ret*(ret-1)/2;
sub.add(a[i+1],-1);
}
return res-solve_123();
}
int main(){
cin>>n;
for(int i=1;i<=3;i++){
cin>>p[i];
}
for(int i=1;i<=n;i++){
cin>>a[i];
}
if(n<3){
puts("0");
return 0;
}
if(p[1]==1 && p[2]==2 && p[3]==3){
cout<<solve_123()<<endl;
}
else if(p[1]==1 && p[2]==3 && p[3]==2){
cout<<solve_132()<<endl;
}
if(p[1]==2 && p[2]==1 && p[3]==3){
for(int i=1;i<=n;i++){
a[i]=n+1-a[i];
}
reverse(a+1,a+1+n);
cout<<solve_132()<<endl;
}
else if(p[1]==2 && p[2]==3 && p[3]==1){
reverse(a+1,a+1+n);
cout<<solve_132()<<endl;
}
else if(p[1]==3 && p[2]==1 && p[3]==2){
for(int i=1;i<=n;i++){
a[i]=n+1-a[i];
}
cout<<solve_132()<<endl;
}
else if(p[1]==3 && p[2]==2 && p[3]==1){
reverse(a+1,a+1+n);
cout<<solve_123()<<endl;
}
return 0;
}