题解 string

传送门

考试的时候只来得及糊了个\(n^4\)的暴力,结果考完发现\(n^2\)比\(n^4\)还好写

题意就是要求把一堆字符串的前后缀拼起来之后在原串中出现了多少次
然而前后缀可以有很多,再枚举组合就炸没了

先考虑\(n^2\) 写法:
可以先预处理出所有前后缀,分别扔到map里
枚举原串中的每一个位置作为连接点,向前/后分别枚举长度,累加匹配到的前/后缀个数,最后乘起来

发现瓶颈在于枚举长度,考虑优化(以后缀为例)
根据题解提示,首先发现所有匹配到的后缀都一定被最长的那个包含,而且能匹配上的长度还有单调性
所以如果能在每个后缀上挂载一个它包含的所有更短后缀的信息,就可以直接二分出能匹配的最长后缀,并利用挂载的信息完成统计
至于具体实现,因为还要统计它包含的后缀的信息,所以可以建棵trie树
统计的时候在trie树上跑遍dfs就可以了

有个小细节:这里二分最长匹配长度的时候是check这个hash值在不在unordered_map中,所以一定要用mp.find(),否则自动创建了就会炸锅

Code:

#include <bits/stdc++.h>
using namespace std;
// Large int sentinel value; not referenced by any code visible in this file.
#define INF 0x3f3f3f3f
// Capacity of the fixed-size global arrays (text buffer, hash tables, pattern pointers).
#define N 100010
// Shorthand aliases for the 64-bit integer types used by the hashes and counters.
#define ull unsigned long long 
#define ll long long 
// Disabled switch that would turn every `int` into long long
// (presumably the reason main is declared `signed`).
//#define int long long 

// Reads one decimal integer from stdin.
// Leading non-digit characters are skipped; every '-' met while skipping
// flips the sign. The first character after the digits is consumed as well.
// Returns 0 when the input ends before any digit is found.
inline int read() {
	int ans=0, f=1;
	// Keep the getchar() result in an int: EOF must stay distinguishable from
	// data bytes, otherwise the skip loop below never terminates at end of
	// input, and isdigit() must not be handed a negative char.
	int c=getchar();
	while (c!=EOF && !isdigit(c)) {if (c=='-') f=-f; c=getchar();}
	while (c!=EOF && isdigit(c)) {ans=ans*10+(c-'0'); c=getchar();}
	return ans*f;
}

// sl: length of the text s (set from strlen in main); m: number of patterns read in main.
int sl, m;
// s:  the text, stored 1-indexed (main reads it into s+1).
// a:  fixed-size pattern table; main() as shown never writes to it.
// st: scratch buffer main uses while reading one pattern.
char s[N], a[3005][5005], st[N];
// b[i]: heap-allocated copy of pattern i, 1-indexed (filled in main).
char* b[N];
// len[i]: length of pattern i.
int len[N];
// Parameters of the modular polynomial hash.
const ll base=13131, mod=1206927149;
// h:  prefix hashes of s; sh[i][j]: slot for the prefix hash of pattern i up to j;
// p:  powers of base. All three are filled by the modular solvers before use.
ll h[N], sh[3005][5005], p[N];

namespace force{
	ll ans;
	inline ll hashing(ll* h, int l, int r) {return ((h[r]-h[l-1]*p[r-l+1]%mod)%mod+mod)%mod;}
	void solve() {
		//cout<<double(sizeof(sh)*2)/1024/1024<<endl;
		p[0]=1;
		for (int i=1; i<N; ++i) p[i]=p[i-1]*base%mod;
		for (int i=1; i<=sl; ++i) h[i]=(h[i-1]*base%mod+s[i])%mod;
		for (int i=1; i<=m; ++i) 
			for (int j=1; j<=len[i]; ++j) 
				sh[i][j]=(sh[i][j-1]*base%mod+a[i][j])%mod;
		for (int i=1; i<=m; ++i) 
			for (int l1=1; l1<=len[i]; ++l1) {
				ll t = hashing(sh[i], len[i]-l1+1, len[i]);
				//cout<<"t: "<<t<<endl;
				//cout<<"try: "; for (int k=len[i]-l1+1; k<=len[i]; ++k) cout<<a[i][k]; cout<<endl;
				for (int pos=l1; pos<=sl; ++pos) {
					//for (int k=pos-l1+1; k<=pos; ++k) cout<<s[k]; cout<<' '<<hashing(h, pos-l1+1, pos)<<' '<<t<<' '; cout<<endl;
					if (hashing(h, pos-l1+1, pos)==t) {
						//cout<<"match1: "<<i<<' '<<l1<<' '; for (int k=len[i]-l1+1; k<=len[i]; ++k) cout<<a[i][k]; cout<<endl;
						for (int j=1; j<=m; ++j)
							for (int l2=1; l2<=min(len[j], sl-pos); ++l2) {
								ll t2=hashing(sh[j], 1, l2);
								if (hashing(h, pos+1, pos+l2)==t2) ++ans;
							}
					}
				}
			}
		printf("%lld\n", ans);
		exit(0);
	}
}

namespace task1{
	unordered_map<ull, ll> mp;
	ll ans;
	inline ll hashing(ll* h, int l, int r) {return ((h[r]-h[l-1]*p[r-l+1]%mod)%mod+mod)%mod;}
	void solve() {
		//cout<<double(sizeof(sh)*2)/1024/1024<<endl;
		p[0]=1;
		for (int i=1; i<N; ++i) p[i]=p[i-1]*base%mod;
		for (int i=1; i<=sl; ++i) h[i]=(h[i-1]*base%mod+s[i])%mod;
		for (int i=1; i<=m; ++i) 
			for (int j=1; j<=len[i]; ++j) 
				sh[i][j]=(sh[i][j-1]*base%mod+a[i][j])%mod;
		
		for (int i=1; i<=m; ++i) 
			for (int j=1; j<=len[i]; ++j) 
				++mp[hashing(sh[i], 1, j)];
		
		for (int i=1; i<=m; ++i) 
			for (int l1=1; l1<=len[i]; ++l1) {
				ll t = hashing(sh[i], len[i]-l1+1, len[i]), t2;
				for (int pos=l1; pos<=sl; ++pos) {
					if (hashing(h, pos-l1+1, pos)==t) {
						for (int j=pos+1; j<=sl; ++j) {
							t2 = hashing(h, pos+1, j);
							if (mp.find(t2)!=mp.end()) 
								ans+=mp[t2];
						}
					}
				}
			}
		printf("%lld\n", ans);
		exit(0);
	}
}

namespace task2{
	ull h[N], p[N], sh[3005][5005];
	unordered_map<ull, ll> mp1, mp2;
	ll ans;
	inline ull hashing(ull* h, int l, int r) {return h[r]-h[l-1]*p[r-l+1];}
	void solve() {
		p[0]=1;
		for (int i=1; i<N; ++i) p[i]=p[i-1]*base;
		for (int i=1; i<=sl; ++i) h[i]=h[i-1]*base+s[i];
		for (int i=1; i<=m; ++i) 
			for (int j=1; j<=len[i]; ++j) 
				sh[i][j]=sh[i][j-1]*base+a[i][j];
		
		for (int i=1; i<=m; ++i) 
			for (int j=len[i]; j; --j) 
				++mp1[hashing(sh[i], j, len[i])];
				
		for (int i=1; i<=m; ++i) 
			for (int j=1; j<=len[i]; ++j) 
				++mp2[hashing(sh[i], 1, j)];
		
		ll t1, t2, t;
		for (int i=2; i<=sl; ++i) {
			t1=0, t2=0;
			for (int j=i-1; j; --j) {
				t=hashing(h, j, i-1);
				if (mp1.find(t)!=mp1.end()) t1+=mp1[t];
			}
			for (int j=i; j<=sl; ++j) {
				t=hashing(h, i, j);
				if (mp2.find(t)!=mp2.end()) t2+=mp2[t];
			}
			ans+=t1*t2;
		}
		
		printf("%lld\n", ans);
		exit(0);
	}
}

namespace task{
	unordered_map<ull, ll> mp1, mp2;
	ll ans;
	ull h1[N], h2[N], p[N];
	inline ull hashing(ull* h, int l, int r) {return h[r]-h[l-1]*p[r-l+1];}
	inline ull hash2(int r, int l) {return h2[r]-h2[l+1]*p[l-r+1];}
	const int SIZE=N*50;
	int tot, son[SIZE][26], cnt[SIZE], size[SIZE];
	ull sh[SIZE];
	#define son(a, b) son[a][b]
	struct trie1{
		int rot;
		trie1(){rot=++tot;}
		void ins(char* s, int len) {
			int p=rot, u;
			ll h=0;
			for (int dep=1; dep<=len; ++dep) {
				u=son(p, s[dep]-'a');
				h=h*base+s[dep];
				if (!u) {son(p, s[dep]-'a')=u=++tot; sh[u]=h;}
				++cnt[u];
				p=u;
			}
		}
		void dfs(int u) {
			//cout<<"dfs "<<u<<endl;
			size[u]+=cnt[u];
			//if (mp1.find(sh[u])!=mp1.end()) puts("same hashval");
			mp1[sh[u]]=size[u];
			//cout<<sh[u]<<' '<<size[u]<<endl;
			for (int i=0; i<26; ++i) 
				if (son(u, i)) {
					size[son(u, i)]+=size[u];
					dfs(son(u, i));
				}
		}
	}tr1;
	struct trie2{
		int rot;
		trie2(){rot=++tot;}
		void ins(char* s, int len) {
			int p=rot, u;
			ll h=0;
			for (int dep=len; dep; --dep) {
				u=son(p, s[dep]-'a');
				h=h*base+s[dep];
				if (!u) {son(p, s[dep]-'a')=u=++tot; sh[u]=h;}
				++cnt[u];
				p=u;
			}
		}
		void dfs(int u) {
			//cout<<"dfs "<<u<<endl;
			size[u]+=cnt[u];
			//if (mp2.find(sh[u])!=mp2.end()) puts("same hashval");
			mp2[sh[u]]=size[u];
			for (int i=0; i<26; ++i) 
				if (son(u, i)) {
					size[son(u, i)]+=size[u];
					dfs(son(u, i));
				}
		}
	}tr2;
	void solve() {
		for (int i=1; i<=m; ++i) tr1.ins(b[i], len[i]), tr2.ins(b[i], len[i]);
		//cout<<"tot: "<<tot<<endl;
		p[0]=1;
		for (int i=1; i<N; ++i) p[i]=p[i-1]*base;
		for (int i=1; i<=sl; ++i) h1[i]=h1[i-1]*base+s[i];
		for (int i=sl; i; --i) h2[i]=h2[i+1]*base+s[i];
		tr1.dfs(tr1.rot); tr2.dfs(tr2.rot);
		ll t1, t2;
		int l, r, mid;
		for (int i=2; i<=sl; ++i) {
			l=0, r=i;
			while (l<=r) {
				mid=(l+r)>>1;
				if (mp2.find(hash2(i-mid+1, i-1))!=mp2.end()) l=mid+1;
				else r=mid-1;
			}
			if (mp2.find(hash2(i-l+2, i-1))==mp2.end()) continue;
			t1=mp2[hash2(i-l+2, i-1)]; //, cout<<"test: "<<i-l+2<<' '<<i-1<<endl;
			//cout<<"l: "<<l-1<<endl;
			//cout<<t1<<endl;
			
			l=0, r=sl-i;
			while (l<=r) {
				mid=(l+r)>>1;
				if (mp1.find(hashing(h1, i, i+mid))!=mp1.end()) l=mid+1;
				else r=mid-1;
			}
			if (mp1.find(hashing(h1, i, i+l-1))==mp1.end()) continue;
			t2=mp1[hashing(h1, i, i+l-1)];
			//cout<<"l: "<<l-1<<endl;
			//cout<<t2<<endl;
			
			ans+=t1*t2;
			//cout<<i<<" += "<<t1*t2<<' '<<t1<<' '<<t2<<endl;
		}
		printf("%lld\n", ans);
		exit(0);
	}
}

signed main()
{
	scanf("%s%d", s+1, &m); sl=strlen(s+1);
	for (int i=1; i<=m; ++i) {
		scanf("%s", st+1);
		len[i]=strlen(st+1);
		b[i]=new char[len[i]+3];
		memcpy(b[i]+1, st+1, sizeof(char)*(len[i]+1));
	}
	//if (m<=50) force::solve();
	//else task1::solve();
	//task2::solve();
	task::solve();
	
	return 0;
}
posted @ 2021-07-27 21:44  Administrator-09  阅读(14)  评论(0)  编辑  收藏  举报