题解 正睿1333 【省选特训19】交换
题目大意
我们按以下方式定义一个字符串是否合法:
- 空串是合法的;
- 如果 S 是合法的,那么 aSa,bSb,cSc 是合法的;
- 如果 S 和 T 都是合法的,那么 ST 是合法的;
- 不能被以上方式定义为合法的字符串都是不合法的。
给定一个只包含 a,b,c 的字符串 S,求有多少种方案交换两个不同字符使得交换之后 S 合法。
数据范围:\(|S|\leq 10^5\)。
本题题解
以下令\(n=|S|\)。
首先,判断一个串是否合法,按照定义暴力判断(搜索),如果加上记忆化,时间复杂度是\(O(n^3)\)的,相当于一个区间DP。
我们还有更好的方法。考虑每次删掉两个相邻且相同的字符,直到不能删为止。如果剩下的是空串,则原串合法,否则原串不合法。形式化地说,假设对串\(s\)执行这样的删除操作后(直到不能再删为止),得到的串为\(f(s)\)。那么串\(s\)合法当且仅当\(f(s)\)为空串。并且显然,操作的顺序不会影响最终的结果,所以每个\(s\)唯一对应一个\(f(s)\)。
求\(f(s)\),我们可以依次扫描\(s\)的每一位,维护一个栈,如果栈非空且当前位等于栈顶,则把栈顶弹出,否则把当前位入栈。扫描完后,栈里剩下的,就是\(f(s)\)了。用这种方法,可以\(O(n)\)实现“判断一个串是否合法”。
还有一个很奇妙的性质,后面会用到。假设我们已经加入了\(s[1\dots i]\),我们现在要把\(s[i]\)删掉,也就是恢复到\(s[1\dots i-1]\)的栈(\(f(s[1\dots i-1])\)),这也可以实现,并且实现的方法和加入字符一模一样!也就是说:如果栈顶等于\(s[i]\),则把栈顶弹出,否则就把\(s[i]\)入栈(因为这种情况说明\(s[i]\)已经和前面另一个相同的字符一起消掉了,此时如果删除\(s[i]\),另一个被消掉的字符就会重新出现)。
这个栈的性质简洁优美,回到本题,让我们用它来解题。
考虑分治。假设当前分治的区间为\([l,r]\) (\(l<r\))。取中点\(m=\lfloor\frac{l+r}{2}\rfloor\)。那么我们此时只考虑,要交换的两个位置,一个在\([l,m]\),另一个在\([m+1,r]\)的情况。
假设要交换的两个字符是\(\text{a}\)和\(\text{b}\),那么相当于把前半段某个\(\text{a}\)改成\(\text{b}\),把后半段某个\(\text{b}\)改成\(\text{a}\)。
我们现在要对前半段的每个位置\(i\) (\(l\leq i\leq m,S[i]=\text{a}\)),求出:把\(S[i]\)改成\(\text{b}\)后(新串记为\(S'_i\))的\(f(S'_i[1\dots m])\)。同理,再考虑后半段的一个位置\(j\) (\(m<j\leq r,S[j]=\text{b}\)),求出把\(S[j]\)改成\(\text{a}\)后的\(f(S'_j[m+1\dots n])\)。我们知道,交换\(S[i],S[j]\)后得到的新串合法,当且仅当\(f(f(S'_i[1\dots m])+f(S'_j[m+1\dots n]))=\text{空串}\)。这等价于:\(f(S'_i[1\dots m])=\text{reverse}(f(S'_j[m+1\dots n]))\)。所以我们只需要对每个\(i\),求出\(f(S'_i[1\dots m])\)的哈希值;对每个\(j\),求出\(f(S'_j[m+1\dots n])\)的反串的哈希值即可。
依次考虑每个\(i\)。发现:\(f(S'_i[1\dots m])=f(f(S[1\dots i-1])+\text{b}+f(S[i+1\dots m]))\)。根据前面的讨论,\(f(s)\),在\(s\)后面添加/删除一个字符时,都可以\(O(1)\)维护。所以\(f(S[1\dots i-1])\)和\(f(S[i+1\dots m])\)都可以维护出来。那么难点就在于:怎么合并两个\(f\)。
假设我们要合并\(f(A),f(B)\),也就是已知\(f(A),f(B)\),求\(f(f(A)+f(B))\)。因为我们前面提到过维护哈希值,所以此时\(f(A),f(B)\)的哈希值(以及反串的哈希值)也被维护好了(在入栈、出栈时顺便维护一下即可)。那么,利用哈希做判断,我们可以二分出一个最大的\(k\),满足:\(f(A)\)的长度为\(k\)的后缀,它倒过来等于\(f(B)\)的长度为\(k\)的前缀。然后把\(f(A)\)的这段后缀、\(f(B)\)的这段前缀都删掉,再拼起来即可。
于是我们就能对每个\(i\), \(j\),都求出其修改后的哈希值。要判断是否相等,可以用一个\(\texttt{std::map}\),记录所有\(i\)的哈希值,然后拿每个\(j\)的哈希值去\(\texttt{map}\)里查表。因为用到二分和\(\texttt{map}\),所以一次分治的时间复杂度是\(O(\text{len}\log \text{len})\)的 (\(\text{len}=r-l+1\))。那么总时间复杂度就是\(O(n\log^2 n)\)。
当然,以上只讨论了左边的\(\text{a}\)与右边的\(\text{b}\)交换的情况。实际上共有\(3\times2=6\)种等价的情况。也就是说我们大力枚举要交换的两个字符分别是什么。如果设字符集为\(z\)的话,则实际上时间复杂度应该是\(O(n\log^2n\cdot |z|^2)\)。
还有一个要注意的点是,因为是要维护\(f(S'_i[1\dots m])\)(注意,不是\([l\dots m]\)),所以进入一个分治区间时,要维护好一个全局的\(f(S[1\dots l-1])\)。回溯时,注意恢复这个全局的\(f(S[1\dots l-1])\)。右边也是同理。这样才能使单次调用的复杂度只和\(\text{len}\)相关,而不退化到和\(n\)相关。因为我们能支持\(O(1)\)插入和删除,所以这两个全局的东西还是很好维护的。
参考代码:
//problem:ZR1333(C)
#include <bits/stdc++.h>
using namespace std;
#define pb push_back
#define mk make_pair
#define lob lower_bound
#define upb upper_bound
#define fi first
#define se second
#define SZ(x) ((int)(x).size())
typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
const int MAXN=1e5;
const ull BASE=131;
ull pw[MAXN+5];
int n;
char s[MAXN+5];
struct Stack{
int s[MAXN+5];
int top;
ull hash[MAXN+5],hash_rev[MAXN+5];
void push(int x){
s[++top]=x;
hash[top]=hash[top-1]*BASE+(ull)s[top];
hash_rev[top]=hash_rev[top-1]+pw[top-1]*(ull)s[top];
}
void pop(){assert(top>0);--top;}
void clear(){top=0;}
void ins(int x){
if(top && s[top]==x)
pop();
else
push(x);
}
void del(int x){ins(x);}
};
Stack pre,suf,tmp_suf,tmp_pre;
ull f[MAXN+5][3],f_rev[MAXN+5][3];
void calc_left(int l,int r){
tmp_suf.clear();
for(int i=r;i>=l;--i)
tmp_suf.ins(s[i]);
for(int i=l;i<=r;++i){
tmp_suf.del(s[i]);
for(int j=0;j<=2;++j)if(j!=s[i]-'a'){
pre.ins(j+'a');
int l=0,r=min(pre.top,tmp_suf.top);
while(l<r){
int mid=(l+r+1)>>1;
ull prehash=pre.hash[pre.top]-pre.hash[pre.top-mid]*pw[mid];
ull sufhash=tmp_suf.hash[tmp_suf.top]-tmp_suf.hash[tmp_suf.top-mid]*pw[mid];
if(prehash==sufhash)l=mid;
else r=mid-1;
}
f[i][j]=pre.hash[pre.top-l]*pw[tmp_suf.top-l]
+tmp_suf.hash_rev[tmp_suf.top-l];
pre.del(j+'a');
}
pre.ins(s[i]);
}
}
void calc_right(int l,int r){
tmp_pre.clear();
for(int i=l;i<=r;++i){
suf.del(s[i]);
for(int j=0;j<=2;++j)if(j!=s[i]-'a'){
tmp_pre.ins(j+'a');
int l=0,r=min(tmp_pre.top,suf.top);
while(l<r){
int mid=(l+r+1)>>1;
ull prehash=tmp_pre.hash[tmp_pre.top]-tmp_pre.hash[tmp_pre.top-mid]*pw[mid];
ull sufhash=suf.hash[suf.top]-suf.hash[suf.top-mid]*pw[mid];
if(prehash==sufhash)l=mid;
else r=mid-1;
}
f_rev[i][j]=tmp_pre.hash_rev[tmp_pre.top-l]
+suf.hash[suf.top-l]*pw[tmp_pre.top-l];
tmp_pre.del(j+'a');
}
tmp_pre.ins(s[i]);
}
}
ll ans;
void solve(int l,int r){
if(l==r)return;
int mid=(l+r)>>1;
//现在考虑要交换的两位置一个在[l,mid],一个在[mid+1,r]的情况
calc_left(l,mid);
for(int i=r;i>mid;--i)suf.ins(s[i]);
calc_right(mid+1,r);
for(int x=0;x<=2;++x){
for(int y=0;y<=2;++y)if(x!=y){
map<ull,int>mp;
for(int i=l;i<=mid;++i)if(s[i]=='a'+x){
mp[f[i][y]]++;
}
for(int i=mid+1;i<=r;++i)if(s[i]=='a'+y){
ans+=(!mp.count(f_rev[i][x])?0:mp[f_rev[i][x]]);
}
}
}
solve(mid+1,r);
for(int i=mid;i>=l;--i)pre.del(s[i]);
for(int i=r;i>mid;--i)suf.ins(s[i]);
solve(l,mid);
for(int i=mid+1;i<=r;++i)suf.del(s[i]);
}
int main() {
cin>>(s+1);n=strlen(s+1);
pw[0]=1;
for(int i=1;i<=n;++i)pw[i]=pw[i-1]*BASE;
solve(1,n);
cout<<ans<<endl;
return 0;
}