题解 字符串
发现 check 是否合法可以用栈贪心
压入一个元素时若与栈顶相同就弹栈
合法条件是最后栈为空
先考虑怎么做到 \(O(n^2)\):
发现若维护出 \([1, i]\) 的栈(从左往右压)和 \([i+1, n]\) 的栈(从右往左压),则合法当且仅当两栈完全相同
那么枚举 \(i\) 以及它被换成的字符 \(c\),再从左往右枚举满足 \(s_j = c\) 的 \(j\):扫描过程中顺便维护出 \([1, j]\) 的栈,而 \([j+1, n]\) 的栈(的 hash)可以预处理,比较两者即可
然后正解:
- 统计序列上点对的贡献的一个常用方式是分治
考虑处理 \(i\in [l, mid], j\in [mid+1, r]\) 的合法方案数
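那么尝试处理出 \(i, j\) 替换后 \([1, mid]\) 和 \([mid+1, n]\) 的栈的 hash 值,两者相等即合法:左半边的结果按 \((s_i, \text{新字符})\) 分类存进 map,右半边查询即可。以左半边为例,替换后 \([1, mid]\) 的栈由 "\([1, i-1]\) 压入新字符后的栈" 与 "\([i+1, mid]\) 的栈" 拼接而成,拼接处两栈会相互抵消一段,这段的长度可以用二分 + hash 求出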
但二分需要能随时访问这两个栈,不用可持久化栈怎么做呢?
发现做到 \([l, r]\) 时已有了 \([1, l-1]\) 的栈
又发现这个栈很特殊:只要记录每次压入时是 "弹栈" 还是 "入栈",就可以 \(O(1)\) 撤销一次压入操作
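那么从 \(mid\) 往左压栈,预先建出 \([l+1, mid]\) 的栈(\(mid\) 在栈底),\(i\) 每右移一位就撤销掉一次压入;同时在 \([1, l-1]\) 的栈上依次压入 \(s_l, s_{l+1}, \dots\)(试完新字符后撤销),即可同时维护出 \([1, i-1]\) 和 \([i+1, mid]\) 两个栈了。右半边对称处理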
点击查看代码
#include <bits/stdc++.h>
using namespace std;
// Competitive-programming shorthands used throughout the file.
// INF is not referenced by any of the visible solutions.
#define INF 0x3f3f3f3f
// Capacity of the input buffer and of every per-position array.
#define N 100010
#define fir first
#define sec second
#define pb push_back
#define ll long long
#define ull unsigned long long
// Toggle for widening every `int` (main is declared `signed` so it still works).
//#define int long long
// Length of the input string (strlen(s+1), set in main).
int n;
// Input string, stored 1-indexed in s[1..n]; s[0] is never written.
char s[N];
namespace force{
// Brute-force reference: try every swap of two different characters and
// simulate the cancellation on the whole string (O(n^3)).
char sta[N];
int ans, top;

// True iff s[1..n] becomes empty when equal adjacent characters are
// repeatedly cancelled, simulated greedily with a stack.
bool check() {
    top=0;
    for (int i=1; i<=n; ++i) {
        const char ch=s[i];
        if (top>0 && sta[top]==ch) {
            --top;
            continue;
        }
        sta[++top]=ch;
    }
    return top==0;
}

// Counts the valid swaps (i < j, s[i] != s[j]) and prints the count.
void solve() {
    for (int i=1; i<=n; ++i) {
        for (int j=i+1; j<=n; ++j) {
            if (s[i]==s[j]) continue;
            swap(s[i], s[j]);
            if (check()) ++ans;
            swap(s[i], s[j]);   // restore the original string
        }
    }
    cout<<ans<<endl;
}
}
namespace task1{
int ans, top, btop;
char sta[N], bkp[N];
const ull base=13131;
ull suf[N], h[N], bkph[N];
inline void push(char c) {
if (top&&sta[top]==c) --top;
else sta[++top]=c, h[top]=h[top-1]*base+c;
}
void solve() {
for (int i=n; i; --i) push(s[i]), suf[i]=h[top];
top=0;
for (int i=1; i<=n; ++i) {
btop=top;
for (int j=1; j<=top; ++j) bkp[j]=sta[j], bkph[j]=h[j];
for (char c='a'; c<='c'; ++c) if (c!=s[i]) {
top=btop;
for (int j=1; j<=top; ++j) sta[j]=bkp[j], h[j]=bkph[j];
push(c);
for (int j=i+1; j<=n; ++j) {
if (s[j]==c) {
// tran s[j](s[j]=c) to s[i]
if (top&&sta[top]==s[i]) ans+=(h[top-1]==suf[j+1]);
else ans+=((h[top]*base+s[i])==suf[j+1]);
}
push(s[j]);
}
}
top=btop;
for (int j=1; j<=top; ++j) sta[j]=bkp[j], h[j]=bkph[j];
push(s[i]);
}
cout<<ans<<endl;
}
}
namespace task{
ll ans;
int tot;
ull pw[N];
const ull base=13131;
map<pair<int, int>, int> id;
map<ull, int> mp[N*2][3][3];
struct stack{
vector<char> sta;
vector<ull> h, rh;
vector<pair<bool, char>> rec;
stack(){h.pb(0); rh.pb(0); sta.pb(0); rec.pb({0, 0});}
inline ull hash(int l, int r) {return h[r]-h[l-1]*pw[r-l+1];}
inline void push(char c) {
if (sta.size()>1&&sta.back()==c) {sta.pop_back(); h.pop_back(); rh.pop_back(); rec.pb({1, c});}
else {sta.pb(c); h.pb(h.back()*base+c); rh.pb(c*pw[rh.size()-1]+rh.back()); rec.pb({0, c});}
}
inline void undo() {
if (rec.back().fir) {sta.pb(rec.back().sec); h.pb(h.back()*base+rec.back().sec); rh.pb(rec.back().sec*pw[rh.size()-1]+rh.back()); rec.pop_back();}
else {sta.pop_back(); h.pop_back(); rh.pop_back(); rec.pop_back();}
}
inline ull hash(int len) {return hash(sta.size()-len, sta.size()-1);}
inline ull rhash(int len) {return rh[len];}
}pre, suf;
void solve1(int l, int r) {
if (l==r) {pre.push(s[l]); return ;}
id[{l, r}]=++tot;
int mid=(l+r)>>1;
stack sta;
// cout<<"solve1: "<<l<<' '<<r<<' '<<pre.sta.size()<<endl;
// assert(pre.sta.size()==l);
for (int i=mid; i>l; --i) sta.push(s[i]);
for (int i=l; i<=mid; ++i) {
// cout<<"i: "<<i<<endl;
for (char j='a'; j<='c'; ++j) if (j!=s[i]) {
// cout<<"j: "<<int(j)<<endl;
pre.push(j);
int tl=0, tr=min(pre.sta.size()-1, sta.sta.size()-1), tmid;
while (tl<=tr) {
tmid=(tl+tr)>>1;
if (!tmid||pre.hash(tmid)==sta.hash(tmid)) tl=tmid+1;
else tr=tmid-1;
}
// acacbc
tmid=tl-1;
// cout<<"tmid: "<<tmid<<endl;
// cout<<"pre: "; for (auto it:pre.sta) cout<<int(it)<<' '; cout<<endl;
// cout<<"sta: "; for (auto it:sta.sta) cout<<int(it)<<' '; cout<<endl;
// cout<<"h: "; for (auto it:pre.h) cout<<it<<' '; cout<<endl;
// cout<<"rh: "; for (auto it:sta.rh) cout<<it<<' '; cout<<endl;
++mp[tot][s[i]-'a'][j-'a'][pre.h[pre.sta.size()-tmid-1]*pw[sta.sta.size()-tmid-1]+sta.rhash(sta.sta.size()-tmid-1)];
// cout<<sta.rhash(sta.sta.size()-tmid)<<endl;
// if (l==1&&r==6 && s[i]=='a' && j=='c') {
// cout<<"i: "<<i<<endl;
// cout<<"tmid: "<<tmid<<endl;
// cout<<"sta: "; for (auto it:sta.sta) cout<<int(it)<<' '; cout<<endl;
// cout<<"hash1: "<<tot<<' '<<s[i]<<' '<<j<<' '<<pre.h[pre.sta.size()-tmid-1]*pw[sta.sta.size()-tmid-1]+sta.rhash(sta.sta.size()-tmid-1)<<endl;
// cout<<"pos: "<<i<<endl;
// cout<<pre.h[pre.sta.size()-tmid-1]<<endl;
// }
pre.undo();
}
pre.push(s[i]); sta.undo();
}
// cout<<"lr: "<<l<<' '<<r<<' '<<pre.sta.size()<<endl;
for (int i=l; i<=mid; ++i) pre.undo();
solve1(l, mid); solve1(mid+1, r);
}
void solve2(int l, int r) {
if (l==r) {suf.push(s[l]); return ;}
tot=id[{l, r}];
int mid=(l+r)>>1;
stack sta;
// assert(suf.sta.size()==n-r+1);
for (int i=mid+1; i<r; ++i) sta.push(s[i]);
for (int i=r; i>mid; --i) {
// if (l==1&&r==n) cout<<"i: "<<i<<endl;
for (char j='a'; j<='c'; ++j) if (j!=s[i]) {
// if (l==1&&r==n) cout<<"j: "<<j<<endl;
suf.push(j);
int tl=0, tr=min(sta.sta.size()-1, suf.sta.size()-1), tmid;
while (tl<=tr) {
tmid=(tl+tr)>>1;
if (!tmid||suf.hash(tmid)==sta.hash(tmid)) tl=tmid+1;
else tr=tmid-1;
}
tmid=tl-1;
ans+=mp[tot][j-'a'][s[i]-'a'][suf.h[suf.sta.size()-tmid-1]*pw[sta.sta.size()-tmid-1]+sta.rhash(sta.sta.size()-tmid-1)];
// if (mp[tot][j-'a'][s[i]-'a'][suf.h[suf.sta.size()-tmid-1]*pw[sta.sta.size()-tmid-1]+sta.rhash(sta.sta.size()-tmid-1)]) cout<<"lr: "<<l<<' '<<r<<' '<<j<<' '<<s[i]<<endl;
// cout<<"hash2: "<<tot<<' '<<j<<' '<<s[i]<<' '<<suf.h[suf.sta.size()-tmid-1]*pw[sta.sta.size()-tmid-1]+sta.rhash(sta.sta.size()-tmid-1)<<endl;
// if (l==1&&r==6 && j=='a' && s[i]=='c') {
// cout<<"i: "<<i<<endl;
// cout<<"tmid: "<<tmid<<endl;
// cout<<"suf: "; for (auto it:suf.sta) cout<<int(it)<<' '; cout<<endl;
// cout<<"sta: "; for (auto it:sta.sta) cout<<int(it)<<' '; cout<<endl;
// cout<<"hash2: "<<tot<<' '<<j<<' '<<s[i]<<' '<<suf.h[suf.sta.size()-tmid-1]*pw[sta.sta.size()-tmid-1]+sta.rhash(sta.sta.size()-tmid-1)<<endl;
// cout<<"pos: "<<i<<endl;
// cout<<"val: "<<mp[tot][j-'a'][s[i]-'a'][suf.h[suf.sta.size()-tmid-1]*pw[sta.sta.size()-tmid-1]+sta.rhash(sta.sta.size()-tmid-1)]<<endl;
// }
suf.undo();
}
suf.push(s[i]); sta.undo();
}
for (int i=mid+1; i<=r; ++i) suf.undo();
solve2(mid+1, r); solve2(l, mid);
}
void solve() {
if (!n) {puts("0"); return ;}
pw[0]=1;
for (int i=1; i<=n; ++i) pw[i]=pw[i-1]*base;
solve1(1, n); solve2(1, n);
printf("%lld\n", ans);
}
}
signed main()
{
freopen("string.in", "r", stdin);
freopen("string.out", "w", stdout);
scanf("%s", s+1);
n=strlen(s+1);
// force::solve();
// task1::solve();
task::solve();
return 0;
}