题解 字符串

传送门

剩 1h10min 的时候发现自己做法假了,疯狂修锅
剩 40min 的时候终于把问题转化成了一个自己不会的板子,于是就爆零了/kk

首先发现本质不同子串个数是 \(O(n^2)\) 级别的,所以这棵树并不能建出来
但发现只有黑色点是有用的,考虑建出一棵压缩后缀树
发现这个东西就是对正串建 SAM 得到的 fail 树,那么只需考虑怎么求区间本质不同子串数
然后发现可以做到 \(O(n\log^2 n)\),需要大力卡常

然后

然后你发现经过亿些优化,你已经可以在本机 1s 内跑完 1e6
但是在 TLEcoders 交上去 2s 的时限还是不能跑动
你奋力挣扎,使尽浑身解数优化操作次数
然而,摆在你面前的还是一个冰冷无情的
image
你终于放弃了
但是,在永远的关闭这个页面以前,你想看看卡过去的人是怎么写的
你发现了一份卡过的代码,复制到本地,编译运行
你发现这份码在随机数据下跑了 3s
你懵了。
你去要数据,并神奇的发现在官方数据下这份码只跑了 0.4s
你愈发疑惑,开始逐行阅读这份充满了谜团的代码
终于,你在最后一行看到了这样一个句子:

for (int i = 1, lst = 0; i <= m; ++i) {
	while (lst < q[i].r) ++lst, Lct.access(pos[lst], lst);
	res += que(q[i].l, q[i].r);
}

看起来这样插入的数有时可以少于 \(n\)
那么,在官方数据 \(n=1e6\) 下这份代码会进行多少次插入呢?
你打开了终端

g++ ./std.cpp -o std -O2 && time ./std && cat std.out

你看着面前的输出

lst: 131072
7822239348748996

你又看了一遍
然后发现你插入 \(5e5\) 个数的用时和这份码插入 \(1.3e5\) 个数的用时是相同的
但是,它过了,你 T 了
你陷入癫狂,你原样抄下这个剪枝,用颤抖的指尖按下了 submit 键
然后
再然后
image

点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f3f
#define N 2000010
#define fir first
#define sec second
#define pb push_back
#define ll long long
#define ull unsigned long long
//#define int long long

int n;
char s[N];
bool vis[N];
int tr[N][2], back[N], tot;
inline void chkmin(int& a, int b) {a=min(a, b);}
// int ins(char* c, int p=0) {
// 	for (int *t; ; p=*t) {
// 		t=&tr[p][*c-'0'];
// 		if (!*t) back[*t=++tot]=p;
// 		if (!*(++c)) {vis[*t]=1; return *t;}
// 	}
// }
// void dfs1(int u) {
// 	int cnt=0;
// 	if (tr[u][0]) dfs1(tr[u][0]), ++cnt;
// 	if (tr[u][1]) dfs1(tr[u][1]), ++cnt;
// 	if (cnt!=1) vis[u]=1;
// }

// namespace force{
// 	ll ans;
// 	int top;
// 	char sta[N];
// 	const ull base=13131;
// 	unordered_map<ull, bool> mp;
// 	ll calc() {
// 		// cout<<"calc: "; for (int i=1; i<=top; ++i) cout<<sta[i]; cout<<endl;
// 		mp.clear();
// 		for (int i=1; i<=top; ++i) {
// 			ull h=0;
// 			for (int j=i; j<=top; ++j) mp[h=h*base+sta[j]]=1;
// 		}
// 		return mp.size();
// 	}
// 	void dfs2(int u) {
// 		for (int i=0; i<2; ++i) if (tr[u][i]) {
// 			sta[++top]='0'+i;
// 			if (vis[tr[u][i]]) {ans+=calc(); top=0;}
// 			dfs2(tr[u][i]);
// 		}
// 	}
// 	void solve() {
// 		for (int i=1; i<=n; ++i) ins(s+i);
// 		vis[0]=1; dfs1(0);
// 		dfs2(0);
// 		printf("%lld\n", ans);
// 	}
// }

// namespace sufsort{
// 	vector<pair<string, int>> sta;
// 	void solve() {
// 		for (int i=1; i<=n; ++i) {
// 			string t;
// 			for (int j=i; j<=n; ++j) t.pb(s[j]);
// 			sta.pb({t, i});
// 		}
// 		sort(sta.begin(), sta.end());
// 		for (auto it:sta) cout<<it.sec<<' '; cout<<endl;
// 		cout<<0<<' ';
// 		for (int i=1; i<sta.size(); ++i) {
// 			string s=sta[i-1].fir, t=sta[i].fir;
// 			int h=0;
// 			while (h<s.length()&&h<t.length()&&s[h]==t[h]) ++h;
// 			cout<<h<<' ';
// 		} cout<<endl;
// 	}
// }

// namespace task1{
// 	ll ans;
// 	int top;
// 	char sta[N];
// 	ull h[N], pw[N];
// 	const ull base=13131;
// 	unordered_map<ull, bool> mp;
// 	int sa[N], id[N], rk[N], oldrk[N<<1], px[N], cnt[N], ht[N], m=256;
// 	inline bool cmp(int a, int b, int w) {return oldrk[a]==oldrk[b]&&oldrk[a+w]==oldrk[b+w];}
// 	inline ull hashing(int l, int r) {return h[r]-h[l-1]*pw[r-l+1];}
// 	namespace sam{
// 		ll ans;
// 		int len[N], fail[N], tr[N][2], now, tot;
// 		void init() {fail[now=tot=ans=0]=-1; tr[0][0]=tr[0][1]=0;  }
// 		void ins(char c) {
// 			c-='0';
// 			int cur=++tot;
// 			fail[cur]=tr[cur][0]=tr[cur][1]=0;
// 			len[cur]=len[now]+1;
// 			int p, q;
// 			for (p=now; ~p&&!tr[p][c]; tr[p][c]=cur,p=fail[p]);
// 			if (p==-1) fail[cur]=0;
// 			else if (len[q=tr[p][c]]==len[p]+1) fail[cur]=q;
// 			else {
// 				int cln=++tot;
// 				len[cln]=len[p]+1;
// 				fail[cln]=fail[q];
// 				tr[cln][0]=tr[q][0], tr[cln][1]=tr[q][1];
// 				for (; ~p&&tr[p][c]==q; tr[p][c]=cln,p=fail[p]);
// 				fail[cur]=fail[q]=cln;
// 			}
// 			now=cur;
// 			ans+=len[cur]-len[fail[cur]];
// 		}
// 	}
// 	ll calc() {
// 		// cout<<"calc: "; for (int i=1; i<=top; ++i) cout<<sta[i]; cout<<endl;
// 		sam::init();
// 		for (int i=1; i<=top; ++i) sam::ins(sta[i]);
// 		// cout<<"len: "; for (int i=1; i<=sam::tot; ++i) cout<<sam::len[i]<<' '; cout<<endl;
// 		// cout<<"fail: "; for (int i=1; i<=sam::tot; ++i) cout<<sam::fail[i]<<' '; cout<<endl;
// 		// cout<<"return: "<<sam::ans<<endl;
// 		return sam::ans;
// 	}
// 	void dfs2(int u) {
// 		for (int i=0; i<2; ++i) if (tr[u][i]) {
// 			sta[++top]='0'+i;
// 			if (vis[tr[u][i]]) {ans+=calc(); top=0;}
// 			dfs2(tr[u][i]);
// 		}
// 	}
// 	void solve() {
// 		pw[0]=1;
// 		for (int i=1; i<=n; ++i) pw[i]=pw[i-1]*base;
// 		for (int i=1; i<=n; ++i) h[i]=h[i-1]*base+s[i];
// 		for (int i=1; i<=n; ++i) ++cnt[rk[i]=s[i]];
// 		for (int i=1; i<=m; ++i) cnt[i]+=cnt[i-1];
// 		for (int i=1; i<=n; ++i) sa[cnt[rk[i]]--]=i;
// 		for (int w=1,p; ; w<<=1,m=p) {
// 			p=0;
// 			for (int i=n; i>n-w; --i) id[++p]=i;
// 			for (int i=1; i<=n; ++i) if (sa[i]>w) id[++p]=sa[i]-w;
// 			for (int i=1; i<=m; ++i) cnt[i]=0;
// 			for (int i=1; i<=n; ++i) ++cnt[px[i]=rk[id[i]]];
// 			for (int i=1; i<=m; ++i) cnt[i]+=cnt[i-1];
// 			for (int i=n; i; --i) sa[cnt[px[i]]--]=id[i];
// 			for (int i=1; i<=n; ++i) oldrk[i]=rk[i];
// 			p=0;
// 			for (int i=1; i<=n; ++i) rk[sa[i]]=cmp(sa[i-1], sa[i], w)?p:++p;
// 			if (p==n) break;
// 		}
// 		// for (int i=1; i<=n; ++i) printf("%d%c", sa[i], " \n"[i==n]);
// 		for (int i=2; i<=n; ++i) {
// 			int l=0, r=min(n-sa[i-1]+1, n-sa[i]+1), mid;
// 			while (l<=r) {
// 				mid=(l+r)>>1;
// 				if (!mid||hashing(sa[i-1], sa[i-1]+mid-1)==hashing(sa[i], sa[i]+mid-1)) l=mid+1;
// 				else r=mid-1;
// 			}
// 			ht[i]=l-1;
// 		}
// 		// for (int i=1; i<=n; ++i) cout<<ht[i]<<' '; cout<<endl;
// 		for (int i=2,now=ins(s+sa[1]); i<=n; ++i) {
// 			int len=n-sa[i-1]+1;
// 			for (int j=1; j<=len-ht[i]; ++j) now=back[now];
// 			now=ins(s+sa[i]+ht[i], now);
// 		}
// 		// cout<<"tot: "<<tot<<endl;
// 		vis[0]=1; dfs1(0);
// 		dfs2(0);
// 		printf("%lld\n", ans);
// 	}
// }

// namespace task2{
// 	vector<int> to[N];
// 	ll suf[N], ans, val;
// 	int len[N], fail[N], tr[N][2], endpos[N], now, tot;
// 	void init() {fail[now=tot=val=0]=-1; tr[0][0]=tr[0][1]=0;}
// 	void ins(char c) {
// 		c-='0';
// 		int cur=++tot;
// 		fail[cur]=tr[cur][0]=tr[cur][1]=0;
// 		len[cur]=len[now]+1;
// 		int p, q;
// 		for (p=now; ~p&&!tr[p][c]; tr[p][c]=cur,p=fail[p]);
// 		if (p==-1) fail[cur]=0;
// 		else if (len[q=tr[p][c]]==len[p]+1) fail[cur]=q;
// 		else {
// 			int cln=++tot;
// 			len[cln]=len[p]+1;
// 			fail[cln]=fail[q];
// 			tr[cln][0]=tr[q][0], tr[cln][1]=tr[q][1];
// 			for (; ~p&&tr[p][c]==q; tr[p][c]=cln,p=fail[p]);
// 			fail[cur]=fail[q]=cln;
// 		}
// 		vis[now=cur]=1;
// 		endpos[cur]=len[cur];
// 		val+=len[cur]-len[fail[cur]];
// 	}
// 	namespace sam{
// 		ll ans;
// 		int len[N], fail[N], tr[N][2], now, tot;
// 		void init() {fail[now=tot=ans=0]=-1; tr[0][0]=tr[0][1]=0;  }
// 		void ins(char c) {
// 			c-='0';
// 			int cur=++tot;
// 			fail[cur]=tr[cur][0]=tr[cur][1]=0;
// 			len[cur]=len[now]+1;
// 			int p, q;
// 			for (p=now; ~p&&!tr[p][c]; tr[p][c]=cur,p=fail[p]);
// 			if (p==-1) fail[cur]=0;
// 			else if (len[q=tr[p][c]]==len[p]+1) fail[cur]=q;
// 			else {
// 				int cln=++tot;
// 				len[cln]=len[p]+1;
// 				fail[cln]=fail[q];
// 				tr[cln][0]=tr[q][0], tr[cln][1]=tr[q][1];
// 				for (; ~p&&tr[p][c]==q; tr[p][c]=cln,p=fail[p]);
// 				fail[cur]=fail[q]=cln;
// 			}
// 			now=cur;
// 			ans+=len[cur]-len[fail[cur]];
// 		}
// 	}
// 	void calc(int l, int r) {
// 		sam::init();
// 		for (int i=l; i<=r; ++i) sam::ins(s[i]);
// 		ans+=sam::ans;
// 	}
// 	void dfs1(int u) {for (auto v:to[u]) dfs1(v), endpos[u]=endpos[v];}
// 	void dfs2(int u, int lst) {
// 		if (vis[u]||to[u].size()!=1) {
// 			if (~lst) {
// 				if (len[u]==endpos[u]) ans+=suf[len[u]-len[lst]]; //, cout<<"at: "<<u<<' '<<len[u]-len[lst]<<endl;
// 				else calc(endpos[u]-len[u]+1, endpos[u]-len[lst]);
// 			}
// 			lst=u;
// 		}
// 		for (auto v:to[u]) dfs2(v, lst);
// 	}
// 	void solve() {
// 		init();
// 		reverse(s+1, s+n+1);
// 		// cout<<"s: "; for (int i=1; i<=n; ++i) cout<<s[i]; cout<<endl;
// 		for (int i=1; i<=n; ++i) ins(s[i]), suf[i]=val;
// 		// cout<<"len: "; for (int i=1; i<=tot; ++i) cout<<len[i]<<' '; cout<<endl;
// 		// cout<<"fail: "; for (int i=1; i<=tot; ++i) cout<<fail[i]<<' '; cout<<endl;
// 		// cout<<"vis: "; for (int i=1; i<=tot; ++i) cout<<vis[i]<<' '; cout<<endl;
// 		for (int i=1; i<=tot; ++i) to[fail[i]].pb(i);
// 		vis[0]=1; dfs1(0); dfs2(0, -1);
// 		printf("%lld\n", ans);
// 	}
// }

namespace task{
	vector<int> to[N];
	ll suf[N], ans, val;
	vector<pair<int, int>> que;
	int endpos[N], pos[N], bel[N], tem[N], cnt[N], lst[N], sqr;
	namespace sam{
		int len[N], fail[N], tr[N][2], now, tot;
		void init() {fail[now=tot=ans=0]=-1; tr[0][0]=tr[0][1]=0;}
		int ins(char c) {
			c-='0';
			int cur=++tot;
			fail[cur]=tr[cur][0]=tr[cur][1]=0;
			len[cur]=len[now]+1;
			int p, q;
			for (p=now; ~p&&!tr[p][c]; tr[p][c]=cur,p=fail[p]);
			if (p==-1) fail[cur]=0;
			else if (len[q=tr[p][c]]==len[p]+1) fail[cur]=q;
			else {
				int cln=++tot;
				len[cln]=len[p]+1;
				fail[cln]=fail[q];
				tr[cln][0]=tr[q][0], tr[cln][1]=tr[q][1];
				for (; ~p&&tr[p][c]==q; tr[p][c]=cln,p=fail[p]);
				fail[cur]=fail[q]=cln;
			}
			vis[now=cur]=1;
			endpos[cur]=len[cur];
			val+=len[cur]-len[fail[cur]];
			return cur;
		}
	}
	namespace bit{
		ll bit1[N], bit2[N];
		inline void add(ll* bit, int i, ll dat) {for (; i<=n; i+=i&-i) bit[i]+=dat;}
		inline ll query(ll* bit, int i) {ll ans=0; for (; i; i-=i&-i) ans+=bit[i]; return ans;}
		inline void add(int l, int r, ll dat) {add(bit1, l, dat), add(bit2, l, dat*l); add(bit1, r+1, -dat), add(bit2, r+1, -dat*(r+1));}
		// inline ll query(int i) {return (i+1)*query(bit1, i)-query(bit2, i);}
		inline ll query(int l, int r) {return ((r+1)*query(bit1, r)-query(bit2, r))-(l*query(bit1, l-1)-query(bit2, l-1));}
	}
	// namespace bit{
	// 	ll bit1[N], bit2[N], sum1[N], sum2[N];
	// 	inline void add(ll* bit, ll* sum, int i, ll dat) {bit[i]+=dat; sum[bel[i]]+=dat;}
	// 	inline ll query(ll* bit, ll* sum, int i) {
	// 		int id=bel[i]; ll ans=0;
	// 		for (int i=1; i<id; ++i) ans+=sum[i];
	// 		for (; bel[i]==id; --i) ans+=bit[i];
	// 		return ans;
	// 	}
	// 	inline void add(int l, int r, ll dat) {add(bit1, sum1, l, dat), add(bit2, sum2, l, dat*l); add(bit1, sum1, r+1, -dat), add(bit2, sum2, r+1, -dat*(r+1));}
	// 	inline ll query(int l, int r) {return ((r+1)*query(bit1, sum1, r)-query(bit2, sum2, r))-(l*query(bit1, sum1, l-1)-query(bit2, sum2, l-1));}
	// }
	namespace lct{
		int rev[N], fa[N], val[N], son[N][2], tag[N], mlen[N], sta[N], top;
		#define son(a, b) son[a][b]
		#define loc(a) (son(fa[a], 1)==a)
		#define isrot(a) (son(fa[a], 0)!=a&&son(fa[a], 1)!=a)
		void spread(int a) {
			if (rev[a]) {
				if (son(a, 0)) swap(son(son(a, 0), 0), son(son(a, 0), 1)), rev[son(a, 0)]^=1;
				if (son(a, 1)) swap(son(son(a, 1), 0), son(son(a, 1), 1)), rev[son(a, 1)]^=1;
				rev[a]=0;
			}
			if (tag[a]) {
				if (son(a, 0)) val[son(a, 0)]=tag[son(a, 0)]=tag[a];
				if (son(a, 1)) val[son(a, 1)]=tag[son(a, 1)]=tag[a];
				tag[a]=0;
			}
		}
		void ror(int x) {
			int y=fa[x], z=fa[y], k=loc(x);
			if (!isrot(y)) son(z, loc(y))=x; fa[x]=z;
			son(y, k)=son(x, k^1); fa[son(x, k^1)]=y;
			son(x, k^1)=y; fa[y]=x;
		}
		void upd(int x) {if (!isrot(x)) upd(fa[x]); spread(x);}
		void splay(int x) {
			upd(x);
			for (int f; f=fa[x],!isrot(x); ror(x))
				if (!isrot(f)) loc(x)^loc(f)?ror(x):ror(f);
		}
		int findrt(int x) {
			spread(x);
			while (son(x, 0)) x=son(x, 0), spread(x);
			splay(x);
			return x;
		}
		int access(int x, int now) {
			// cout<<"access: "<<x<<endl;
			int lst;
			// for (lst=0; x; lst=x,x=fa[x]) {
			// 	// cout<<"x: "<<x<<endl;
			// 	splay(x), son(x, 1)=lst;
			// 	x=findrt(x);
			// 	// cout<<"val: "<<lst<<' '<<val[x]<<endl;
			// 	if (lst&&val[x]) bit::add(val[x]-(mlen[lst]-1)+1, val[x]-mlen[x]+1, -1); //, cout<<"add: "<<val[x]-(mlen[lst]-1)+1<<' '<<val[x]-mlen[x]+1<<' '<<-1<<endl;
			// }
			for (lst=0; x; lst=x,x=fa[x]) {
				splay(x);
				if (val[x]) bit::add(val[x]-sam::len[x]+1, val[x]-sam::len[fa[x]], -1);
				son(x, 0)=lst;
			}
			// splay(x=pos[now]);
			val[lst]=tag[lst]=now;
			bit::add(1, now, 1);
			return lst;
		}
	}
	// void dfs1(int u) {for (auto& v:to[u]) dfs1(v), endpos[u]=endpos[v];}
	// void dfs2(int u, int lst) {
	// 	if (vis[u]||to[u].size()!=1) {
	// 		if (~lst) {
	// 			if (sam::len[u]==endpos[u]) ans+=suf[sam::len[u]-sam::len[lst]]; //, cout<<"at: "<<u<<' '<<len[u]-len[lst]<<endl;
	// 			else que[endpos[u]-sam::len[lst]].pb(endpos[u]-sam::len[u]+1);
	// 		}
	// 		lst=u;
	// 	}
	// 	for (auto& v:to[u]) dfs2(v, lst);
	// }
	void solve() {
		// sqr=pow(n, 0.5);
		// for (int i=1; i<=n; ++i) bel[i]=(i-1)/sqr+1;
		sam::init();
		memset(endpos, 0x3f, sizeof(endpos));
		reverse(s+1, s+n+1);
		// cout<<"s: "; for (int i=1; i<=n; ++i) cout<<s[i]; cout<<endl;
		for (int i=1; i<=n; ++i) pos[i]=sam::ins(s[i]), suf[i]=val;
		// cout<<"len: "; for (int i=1; i<=tot; ++i) cout<<len[i]<<' '; cout<<endl;
		// cout<<"fail: "; for (int i=1; i<=tot; ++i) cout<<fail[i]<<' '; cout<<endl;
		// cout<<"vis: "; for (int i=1; i<=tot; ++i) cout<<vis[i]<<' '; cout<<endl;
		for (int i=1; i<=sam::tot; ++i) to[sam::fail[i]].pb(i);
		vis[0]=1;
		// dfs1(0); dfs2(0, -1);
		for (int i=1; i<=sam::tot; ++i) ++cnt[sam::len[i]];
		for (int i=1; i<=n; ++i) cnt[i]+=cnt[i-1];
		for (int i=1; i<=sam::tot; ++i) tem[cnt[sam::len[i]]--]=i;
		for (int i=sam::tot; i; --i) chkmin(endpos[sam::fail[tem[i]]], endpos[tem[i]]);
		// for (int i=sam::tot; i; --i) endpos[sam::fail[tem[i]]]=endpos[tem[i]];
		for (int i=1,u; i<=sam::tot; ++i) {
			u=tem[i];
			lst[u]=lst[sam::fail[u]];
			if (vis[u]||to[u].size()!=1) {
				if (~lst[u]) {
					if (sam::len[u]==endpos[u]) ans+=suf[sam::len[u]-sam::len[lst[u]]]; //, cout<<"at: "<<u<<' '<<len[u]-len[lst]<<endl;
					else que.pb({endpos[u]-sam::len[u]+1, endpos[u]-sam::len[lst[u]]});
				}
				lst[u]=u;
			}
		}
		for (int i=1; i<=sam::tot; ++i) lct::mlen[i]=sam::len[lct::fa[i]=sam::fail[i]]+1;
		sort(que.begin(), que.end(), [](pair<int, int> a, pair<int, int> b){return a.sec<b.sec;});
		int lst=0;
		for (auto it:que) {
			while (lst<it.sec) ++lst, lct::access(pos[lst], lst);
			ans+=bit::query(it.fir, it.sec);
		}
		// cout<<"lst: "<<lst<<endl;
		printf("%lld\n", ans);
	}
}

signed main()
{
	freopen("string.in", "r", stdin);
	freopen("string.out", "w", stdout);

	scanf("%s", s+1);
	n=strlen(s+1);
	// force::solve();
	// task1::solve();
	// task2::solve();
	// sufsort::solve();
	task::solve();
	
	return 0;
}
posted @ 2022-05-27 20:59  Administrator-09  阅读(1)  评论(0编辑  收藏  举报