题目链接

https://atcoder.jp/contests/agc036/tasks/agc036_e

题解

看了题解第一句话之后意识到这题是sb题以及我又双叒叕智障了……

首先去掉串中相邻的相同字符,不妨设A,B,C出现次数依次递增。用\(cnt\)来表示一个字符的出现次数。
\(cnt(B)=cnt(C)\), 那么可以删去所有BC子串直至\(cnt(A)=cnt(B)=cnt(C)\),同时保证任何两个A之间不能删空。可以证明一定能做到,因为最坏情况删完后是ABCABC...ABCA这样。长度上界\(3\cdot cnt(A)\)可以达到。
\(cnt(B)<cnt(C)\), 首先我们尽量不改变AB而去减少C. 考虑相邻两个A之间一定是BC交替,有以下几种类型:
(1) 开头结尾都是B. 这种情况下我们什么都不能做。
(2) 开头结尾一BC. 这种情况下我们可以删去开头或结尾的C,把\(cnt(C)\)减少\(1\).
(3) 开头结尾都是C. 如果这一段只有一个字符C且不在整个串的首尾, 我们什么都不能做,否则删去首尾的C,把\(cnt(C)\)减少\(2\).
这样做完后,有可能依然\(cnt(B)<cnt(C)\). 这时任何相邻两个A之间要么是C, 要么是BCBCBC...BCB.
我们的目标是把\(cnt(C)-cnt(B)\)缩小到\(0\), 删后面那种显然是不优的,于是只能删去\((cnt(C)-cnt(B))\)AC. 这样显然是合法且最优的。
时间复杂度\(O(n)\).

写起来细节还挺多……

代码

#include<bits/stdc++.h>
#define llong long long
#define mkpr make_pair
#define riterator reverse_iterator
using namespace std;

inline int read()
{
	int x = 0,f = 1; char ch = getchar();
	for(;!isdigit(ch);ch=getchar()) {if(ch=='-') f = -1;}
	for(; isdigit(ch);ch=getchar()) {x = x*10+ch-48;}
	return x*f;
}

const int N = 1e6;
char a[N+3],b[N+3]; bool f[N+3];
char permu[3],permui[3]; int cnt[3],cnt2[3];
int n,n0;

bool cmp_cnt(int x,int y) {return cnt[x]<cnt[y];}

void reset()
{
	n0 = 0; for(int i=1; i<=n; i++) {if(!f[i]) b[++n0] = a[i];}
	n = n0; for(int i=1; i<=n; i++) a[i] = b[i];
}

int main()
{
	scanf("%s",a+1); n = strlen(a+1);
	for(int i=1; i<=n; i++) a[i] -= 'A';
	memset(f,false,sizeof(f));
	for(int i=2; i<=n; i++) if(a[i]==a[i-1]) f[i] = true;
	reset();
	
	for(int i=1; i<=n; i++) cnt[a[i]]++;
	for(int i=0; i<3; i++) permu[i] = i;
	sort(permu,permu+3,cmp_cnt); sort(cnt,cnt+3);
	for(int i=0; i<3; i++) permui[permu[i]] = i;
	for(int i=1; i<=n; i++) a[i] = permui[a[i]];
	
	int dif = cnt[2]-cnt[1];
	memset(f,false,sizeof(f));
	for(int l=(a[1]==0?2:1); l<=n;)
	{
		int r = l;
		while(r<n&&a[r+1]!=0) {r++;}
		if(dif>0&&!(l==r&&l>1&&r<n)&&!(a[l]==1&&a[r]==1))
		{
			if(a[l]==2&&dif>0) {f[l] = true; dif--;}
			if(r>l&&a[r]==2&&dif>0) {f[r] = true; dif--;}
		}
		l = r+2;
	}
	reset();
	
	if(dif>0)
	{
		memset(f,false,sizeof(f));
		for(int l=(a[1]==0?2:1); l<=n;)
		{
			int r = l;
			while(r<n&&a[r+1]!=0) {r++;}
			if(l==r&&a[l]==2&&l>1&&r<n&&dif>0) {f[l] = f[l-1] = true; dif--;}
			l = r+2;
		}
		reset();
	}
	
	memset(cnt,0,sizeof(cnt));
	for(int i=1; i<=n; i++) cnt[a[i]]++;
	dif = cnt[1]-cnt[0];
	memset(f,false,sizeof(f));
	for(int l=(a[1]==0?2:1); l<=n;)
	{
		int r = l;
		while(r<n&&a[r+1]!=0) {r++;}
		int rst = r-l+1;
		for(int i=l+1; i<=r; i++)
		{
			if(a[i]==2&&a[i-1]==1&&dif>0&&(l==1||r==n||rst>2)) {f[i] = f[i-1] = true; dif--; rst-=2;}
		}
		l = r+2;
	}
	reset();
	
	for(int i=1; i<=n; i++) printf("%c",permu[a[i]]+'A'); puts("");
	return 0;
}