题目链接
https://atcoder.jp/contests/agc036/tasks/agc036_e
题解
看了题解第一句话之后意识到这题是sb题以及我又双叒叕智障了……
首先去掉串中相邻的相同字符,不妨设A,B,C
出现次数依次递增。用\(cnt\)来表示一个字符的出现次数。
若\(cnt(B)=cnt(C)\), 那么可以删去所有BC
子串直至\(cnt(A)=cnt(B)=cnt(C)\),同时保证任何两个A
之间不能删空。可以证明一定能做到,因为最坏情况删完后是ABCABC...ABCA
这样。长度上界\(3\cdot cnt(A)\)可以达到。
若\(cnt(B)<cnt(C)\), 首先我们尽量不改变A
和B
而去减少C
. 考虑相邻两个A
之间一定是B
、C
交替,有以下几种类型:
(1) 开头结尾都是B
. 这种情况下我们什么都不能做。
(2) 开头结尾一B
一C
. 这种情况下我们可以删去开头或结尾的C
,把\(cnt(C)\)减少\(1\).
(3) 开头结尾都是C
. 如果这一段只有一个字符C
且不在整个串的首尾, 我们什么都不能做,否则删去首尾的C
,把\(cnt(C)\)减少\(2\).
这样做完后,有可能依然\(cnt(B)<cnt(C)\). 这时任何相邻两个A
之间要么是C
, 要么是BCBCBC...BCB
.
我们的目标是把\(cnt(C)-cnt(B)\)缩小到\(0\), 删后面那种显然是不优的,于是只能删去\((cnt(C)-cnt(B))\)个AC
. 这样显然是合法且最优的。
时间复杂度\(O(n)\).
写起来细节还挺多……
代码
#include<bits/stdc++.h>
#define llong long long
#define mkpr make_pair
#define riterator reverse_iterator
using namespace std;
inline int read()
{
int x = 0,f = 1; char ch = getchar();
for(;!isdigit(ch);ch=getchar()) {if(ch=='-') f = -1;}
for(; isdigit(ch);ch=getchar()) {x = x*10+ch-48;}
return x*f;
}
const int N = 1e6;
char a[N+3],b[N+3]; bool f[N+3];
char permu[3],permui[3]; int cnt[3],cnt2[3];
int n,n0;
bool cmp_cnt(int x,int y) {return cnt[x]<cnt[y];}
void reset()
{
n0 = 0; for(int i=1; i<=n; i++) {if(!f[i]) b[++n0] = a[i];}
n = n0; for(int i=1; i<=n; i++) a[i] = b[i];
}
int main()
{
scanf("%s",a+1); n = strlen(a+1);
for(int i=1; i<=n; i++) a[i] -= 'A';
memset(f,false,sizeof(f));
for(int i=2; i<=n; i++) if(a[i]==a[i-1]) f[i] = true;
reset();
for(int i=1; i<=n; i++) cnt[a[i]]++;
for(int i=0; i<3; i++) permu[i] = i;
sort(permu,permu+3,cmp_cnt); sort(cnt,cnt+3);
for(int i=0; i<3; i++) permui[permu[i]] = i;
for(int i=1; i<=n; i++) a[i] = permui[a[i]];
int dif = cnt[2]-cnt[1];
memset(f,false,sizeof(f));
for(int l=(a[1]==0?2:1); l<=n;)
{
int r = l;
while(r<n&&a[r+1]!=0) {r++;}
if(dif>0&&!(l==r&&l>1&&r<n)&&!(a[l]==1&&a[r]==1))
{
if(a[l]==2&&dif>0) {f[l] = true; dif--;}
if(r>l&&a[r]==2&&dif>0) {f[r] = true; dif--;}
}
l = r+2;
}
reset();
if(dif>0)
{
memset(f,false,sizeof(f));
for(int l=(a[1]==0?2:1); l<=n;)
{
int r = l;
while(r<n&&a[r+1]!=0) {r++;}
if(l==r&&a[l]==2&&l>1&&r<n&&dif>0) {f[l] = f[l-1] = true; dif--;}
l = r+2;
}
reset();
}
memset(cnt,0,sizeof(cnt));
for(int i=1; i<=n; i++) cnt[a[i]]++;
dif = cnt[1]-cnt[0];
memset(f,false,sizeof(f));
for(int l=(a[1]==0?2:1); l<=n;)
{
int r = l;
while(r<n&&a[r+1]!=0) {r++;}
int rst = r-l+1;
for(int i=l+1; i<=r; i++)
{
if(a[i]==2&&a[i-1]==1&&dif>0&&(l==1||r==n||rst>2)) {f[i] = f[i-1] = true; dif--; rst-=2;}
}
l = r+2;
}
reset();
for(int i=1; i<=n; i++) printf("%c",permu[a[i]]+'A'); puts("");
return 0;
}