AtCoder Grand Contest 036E
题目链接在这里
题目大意:给定一个长度\(\leq10^6\)的仅含\(A,B,C\)的字符串,求它的一个最长的子序列,使得\(A,B,C\)数量相同并且相邻的两个字符不同.给出一种方案.
好久不更了,就随便更一点
智商过低根本想不出来这种过于神仙的题...
首先有一个显然的结论就是我们把两个相邻的相同字符当做一个字符.
设\(A\)的数量为\(p\),\(B\)的数量为\(q\),\(C\)的数量为\(r\),\(p\leq q\leq r\)
如果现在\(q=r\),我们找出两个\(A\)中间夹着的串,然后不停地删去两个连续的\(BC\)(或\(CB\)),并保证\(A\)不相邻.正确性是显然的.
如果\(q<r\),定义\(2\)个变量\(b,c\),其中(这里我们假设开头之前与结尾之后都是\(A\)):
- \(b\)表示形如\(A..B..A\)的串的数量.
- \(c\)表示形如\(ACA\)的串的数量.
注意这些串除了开头结尾均不含\(A\)
假设我们已经找到了答案,并计算出了答案对应的\(2\)个变量.如果\(b<c\),显然不可能达到要求,因为所有\(b\)上\(q-r\)的值至多是\(1\),而所有\(c\)串上\(q-r\)的值必定是\(-1\).但是,如果\(c\leq b\),我们就可以仅删除\(C\)来达到要求.只要枚举每一段,如果\(r>q\)就让它都变成\(BC...BCB\)的形式,即这一段的\(B\)个数比\(C\)个数多\(1\)就可以了.我们可以容易证明这样\(q\)一定会和\(r\)相同.
但是如果我们算出初始串的\(b<c\)怎么办呢?这样所有的\(A\)就没办法用了.我们考虑枚举所有的\(c\)类型的子串,删去它右边的\(A\),那么\(c\)会减一,而其它两个个值都不会变.我们只要让\(c\)变成\(b\),也就是删除\(c-b\)个\(c\)串末尾的\(A\)就可以了.这样就再次转化为\(c\leq b\)的形式.
最后,它一定会变成一个\(p<q=r\)的形式,因此我们像最早所述删去连续的数个\(BC\)即可.
只需模拟上述过程.时间复杂度\(O(n)\).
代码如下:
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<vector>
#include<queue>
#include<map>
#include<bitset>
#include<set>
#define N (5000010)
#define P ()
#define M ()
#define inf (0x7f7f7f7f)
#define rg register int
#define Label puts("NAIVE")
#define spa print(' ')
#define ent print('\n')
#define rand() (((rand())<<(15))^(rand()))
#define file(s) freopen(s".in","r",stdin),freopen(s".out","w",stdout)
typedef long double ld;
typedef long long LL;
typedef unsigned long long ull;
using namespace std;
int n,c1,c2,c3,p,q,now,tar;
char s[N],str[N],w1,w2,w3;
vector<int> pos2; bool ban[N];
int gs(int x){
if(s[x]==w1)return 1;else if(s[x]==w2)return 2;else return 3;
}
void getstr(int L,int R){
if(L>R)return;
int ch[4]={0,0,0,0};
for(int i=L;i<=R;i++)ch[gs(i)]++;
if(!ch[2]&&L>1&&R<n)q++,pos2.push_back(L-1);
if(ch[2])p++;
}
void del(){
int tmp=0,lst=-1;
for(int i=1;i<=n;i++)
if(!ban[i]&&s[i]!=lst)lst=s[i],s[++tmp]=s[i];
n=tmp;
}
void cal(){
c1=c2=c3=0;
for(int i=1;i<=n;i++)
if(s[i]==w1)c1++;else if(s[i]==w2)c2++;else c3++;
memset(ban,0,sizeof(ban));
}
void solve(int L,int R){
if(L>R)return;
if(L==R){
if(L==1||L==n){if(s[L]==w3&&c3>c2)c3--,ban[L]=1;}
return;
}
if(s[L]==w3&&c3>c2)c3--,ban[L]=1;
if(s[R]==w3&&c3>c2)c3--,ban[R]=1;
}
void solve2(int L,int R){
if(L>R)return;
while(L<R&&now>tar){
if(L+1>=R&&L!=1&&R!=n)break;
now--,ban[L]=ban[L+1]=1,L+=2;
}
}
int main(){
scanf("%s",str+1),w1='A',w2='B',w3='C';
for(int i=1;i<=strlen(str+1);i++)
if(str[i]!=str[i-1]){
s[++n]=str[i];
if(s[n]==w1)c1++;else if(s[n]==w2)c2++;else c3++;
}
if(c1>c2)swap(c1,c2),swap(w1,w2);
if(c1>c3)swap(c1,c3),swap(w1,w3);
if(c2>c3)swap(c2,c3),swap(w2,w3);
int j=0;
for(int i=1;i<=n;i++)
if(s[i]==w1)getstr(j+1,i-1),j=i; getstr(j+1,n);
if(p<q)for(int i=0;i<q-p;i++)ban[pos2[i]]=1,c1--;
del(),cal(),j=0;
for(int i=1;i<=n;i++)
if(s[i]==w1)solve(j+1,i-1),j=i; solve(j+1,n);
del(),cal(),j=0,now=c2,tar=c1;
for(int i=1;i<=n;i++)
if(s[i]==w1)solve2(j+1,i-1),j=i; solve2(j+1,n);
for(int i=1;i<=n;i++)
if(!ban[i])printf("%c",s[i]);
}