题解 神牛养成计划

传送门

考场上看错题当成子串,建了棵后缀树然后复杂度炸了

对原串和反串分别建出trie树,即为求同时在两棵树的给定子树内的点权和
于是dfs一棵树,建出主席树查询另一棵树dfs序范围即可

题解有另一种做法:
首先我们把N个字符串按前缀的字典序排序。然后将这个顺序下N个字符串的后缀建成一棵可持久化tire树。
对于每个询问,我们只需根据s1求出符合前缀的区间,再根据s2在这个区间后缀所组成的tire树上走,最后便可统计出答案。

点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 2000010
#define ll long long
#define fir first
#define sec second
#define ull unsigned long long
//#define int long long

inline int read() {
	int ans=0, f=1; char c=getchar();
	while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
	while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
	return ans*f;
}

int n, m;

// namespace force{
// 	int mlen;
// 	string s[N], s1, s2;
// 	vector<ull> h[N];
// 	ull p[N], h1, h2;
// 	const ull base=13131;
// 	inline ull hashing(vector<ull>& h, int l, int r) {return h[r]-h[l-1]*p[r-l+1];}
// 	void solve() {
// 		cin>>n;
// 		for (int i=1; i<=n; ++i) {
// 			cin>>s[i];
// 			mlen=max(mlen, int(s[i].length()));
// 			ull lst=0;
// 			h[i].push_back(0);
// 			for (int j=0; j<s[i].length(); ++j) {
// 				lst=lst*base+s[i][j];
// 				h[i].push_back(lst);
// 			}
// 		}
// 		p[0]=1;
// 		for (int i=1; i<=mlen; ++i) p[i]=p[i-1]*base;
// 		cin>>m;
// 		for (int i=1,ans=0; i<=m; ++i) {
// 			cin>>s1>>s2;
// 			h1=h2=0;
// 			// cout<<"len: "<<s1.length()<<' '<<s2.length()<<endl;
// 			for (int j=0; j<s1.length(); ++j) h1=h1*base+(s1[j]-'a'+ans)%26+'a';
// 			for (int j=0; j<s2.length(); ++j) h2=h2*base+(s2[j]-'a'+ans)%26+'a';
// 			// cout<<h1<<endl;
// 			ans=0;
// 			for (int j=1; j<=n; ++j) {
// 				// cout<<"h: "<<hashing(h[j], 1, s1.length())<<' '<<hashing(h[j], s[j].length()-s2.length()+1, s[j].length())<<endl;
// 				if (s[j].length()>=max(s1.length(), s2.length()) && hashing(h[j], 1, s1.length())==h1 && hashing(h[j], s[j].length()-s2.length()+1, s[j].length())==h2) ++ans;
// 			}
// 			printf("%d\n", ans);
// 		}
// 	}
// }
// 
// namespace task1{
// 	char s[N], s1[N], s2[N];
// 	queue<int> q2;
// 	queue<pair<int, int>> q;
// 	int lson[N], rson[N], sum[N], rot[N], now, tot2;
// 	int head[N], dep[N], id[N], in[N], siz[N], in2[N], siz2[N], ord, ord2, size;
// 	struct edge{int to, next;}e[N];
// 	int len[N], fail[N], cnt[N], deg[N], tot, tal;
// 	struct node{int to, next; char ch;}nod[N<<2];
// 	struct turn{
// 		int head;
// 		turn():head(-1){};
// 		inline int operator [] (int t) {
// 			for (int i=head; ~i; i=nod[i].next)
// 				if (nod[i].ch==t) return nod[i].to;
// 			return 0;
// 		}
// 		inline void insert(char c, int t) {nod[++tal]={t, head, c}; head=tal;}
// 		inline void upd(int c, int t) {
// 			for (int i=head; ~i; i=nod[i].next)
// 				if (nod[i].ch==c) {nod[i].to=t; break;}
// 		}
// 	}tr[N];
// 	inline void add(int s, int t) {e[++size]={t, head[s]}; head[s]=size;}
// 	void init() {tot=1; memset(head, -1, sizeof(head)); fail[0]=-1;}
// 	#define ls(p) lson[p]
// 	#define rs(p) rson[p]
// 	#define sum(p) sum[p]
// 	#define pushup(p) sum(p)=sum(ls(p))+sum(rs(p))
// 	void upd(int& p1, int p2, int tl, int tr, int pos, int val) {
// 		p1=++tot2;
// 		if (tl==tr) {sum(p1)=sum(p2)+val; return ;}
// 		int mid=(tl+tr)>>1;
// 		if (pos<=mid) upd(ls(p1), ls(p2), tl, mid, pos, val), rs(p1)=rs(p2);
// 		else upd(rs(p1), rs(p2), mid+1, tr, pos, val), ls(p1)=ls(p2);
// 		pushup(p1);
// 	}
// 	int query(int p1, int p2, int tl, int tr, int ql, int qr) {
// 		if (!p1) return 0;
// 		if (ql<=tl&&qr>=tr) return sum(p1)-sum(p2);
// 		int mid=(tl+tr)>>1, ans=0;
// 		if (ql<=mid) ans+=query(ls(p1), ls(p2), tl, mid, ql, qr);
// 		if (qr>mid) ans+=query(rs(p1), rs(p2), mid+1, tr, ql, qr);
// 		return ans;
// 	}
// 	int ins(int c, int now) {
// 		int cur=tr[now][c];
// 		len[cur]=len[now]+1;
// 		int p, q;
// 		for (p=fail[now]; ~p&&!tr[p][c]; tr[p].insert(c, cur),p=fail[p]);
// 		if (p==-1) fail[cur]=0;
// 		else if (len[q=tr[p][c]]==len[p]+1) fail[cur]=q;
// 		else {
// 			int cln=++tot;
// 			// cout<<"pos1"<<endl;
// 			len[cln]=len[p]+1;
// 			fail[cln]=fail[q];
// 			for (int i=tr[q].head; ~i; i=nod[i].next) if (len[nod[i].to]) tr[cln].insert(nod[i].ch, nod[i].to);
// 			for (; ~p&&tr[p][c]==q; tr[p].upd(c, cln),p=fail[p]);
// 			fail[cur]=fail[q]=cln;
// 		}
// 		return cur;
// 	}
// 	void build() {
// 		// for (int i=0; i<26; ++i) if (tr[0][i]) q.push({i, 0});
// 		for (int i=tr[0].head; ~i; i=nod[i].next) q.push({nod[i].ch, 0});
// 		pair<int, int> u;
// 		while (q.size()) {
// 			u=q.front(); q.pop();
// 			int now=ins(u.fir, u.sec);
// 			// for (int i=0; i<26; ++i) if (tr[now][i]) q.push({i, now});
// 			for (int i=tr[now].head; ~i; i=nod[i].next) q.push({nod[i].ch, now});
// 		}
// 	}
// 	void dfs1(int u) {
// 		// cout<<"dfs1: "<<u<<endl;
// 		siz[u]=1; in[u]=++ord;
// 		if (cnt[u]) upd(rot[ord], rot[ord-1], 1, tot, in2[u], cnt[u]); //, cout<<"upd: "<<ord<<' '<<in2[u]<<endl;
// 		else rot[ord]=rot[ord-1];
// 		for (int i=head[u],v; ~i; i=e[i].next) {
// 			v = e[i].to;
// 			dfs1(v);
// 			siz[u]+=siz[v];
// 		}
// 	}
// 	void dfs2(int u) {
// 		// cout<<"dfs2: "<<u<<endl;
// 		siz2[u]=1; in2[u]=++ord2;
// 		for (int i=tr[u].head,v; ~i; i=nod[i].next) if (len[v=nod[i].to]==len[u]+1) {
// 			dfs2(v);
// 			siz2[u]+=siz2[v];
// 		}
// 	}
// 	void solve() {
// 		// cout<<double(sizeof(lson)*3+sizeof(rot)*14+sizeof(nod)+sizeof(tr)+sizeof(s)*3)/1000/1000<<endl;
// 		scanf("%d", &n);
// 		init();
// 		for (int i=1,k; i<=n; ++i) {
// 			scanf("%s", s+1);
// 			k=strlen(s+1);
// 			int u=0;
// 			for (int j=1,v; j<=k; ++j) {
// 				if (!(v=tr[u][s[j]-'a'])) tr[u].insert(s[j]-'a', v=++tot);
// 				u=v;
// 			}
// 			// cout<<"add: "<<u<<endl;
// 			++cnt[u];
// 		}
// 		build();
// 		for (int i=2; i<=tot; ++i) add(fail[i], i); //, cout<<"add: "<<fail[i]<<' '<<i<<' '<<in2[i]<<endl;
// 		dfs2(0); dfs1(0);
// 		scanf("%d", &m);
// 		for (int i=1,lst=0,t1,t2,k1,k2,l; i<=m; ++i) {
// 			scanf("%s%s", s1+1, s2+1);
// 			k1=strlen(s1+1); k2=strlen(s2+1);
// 			for (int j=1; j<=k1; ++j) s1[j]=(s1[j]-'a'+lst)%26+'a';
// 			for (int j=1; j<=k2; ++j) s2[j]=(s2[j]-'a'+lst)%26+'a';
// 			t1=t2=l=0;
// 			for (int j=1,v; j<=k1; ++j) {
// 				// cout<<"t1: "<<t1<<endl;
// 				v=tr[t1][s1[j]-'a'];
// 				if (!v || len[v]!=len[t1]+1) {printf("%d\n", lst=0); goto jump;}
// 				t1=v;
// 				// cout<<"turn: "<<t1<<endl;
// 			}
// 			for (int j=1,v; j<=k2; ++j) {
// 				v=tr[t2][s2[j]-'a'];
// 				while (~t2&&!v) t2=fail[t2], l=len[t2];
// 				if (t2==-1) t2=l=0;
// 				else t2=v, ++l;
// 			}
// 			// cout<<l<<endl;
// 			if (l<k2) {printf("%d\n", lst=0); goto jump;}
// 			// cout<<"t: "<<t1<<' '<<t2<<endl;
// 			// cout<<"rg1: "<<in[t2]+siz[t2]-1<<' '<<in[t2]-1<<endl;
// 			// cout<<in2[t1]<<' '<<siz2[t1]<<endl;
// 			// cout<<in[t2]+siz[t2]-1<<endl;
// 			// cout<<"rg2: "<<in2[t1]<<' '<<in2[t1]+siz2[t1]-1<<endl;
// 			printf("%d\n", lst=query(rot[in[t2]+siz[t2]-1], rot[in[t2]-1], 1, tot, in2[t1], in2[t1]+siz2[t1]-1));
// 			// cout<<query(6, 0, 1, 8, 1, 8)<<endl;
// 			// cout<<tot<<' '<<ord2<<endl;
// 			// cout<<"tot: "<<tot<<endl;
// 			// cout<<in[6]<<endl;
// 			jump: ;
// 		}
// 	}
// }

namespace task{
	char s[N], s1[N], s2[N];
	vector<int> buc[N];
	int lson[N], rson[N], sum[N], rot[N], pos[N], now;
	int in[N], siz[N], in2[N], siz2[N], cnt[N], ord, ord2, tot1, tot2, tot3, tal;
	struct node{int to, next; char ch;}nod[N<<2];
	struct turn{
		int head;
		turn():head(-1){}
		inline int operator [] (int t) {
			for (int i=head; ~i; i=nod[i].next)
				if (nod[i].ch==t) return nod[i].to;
			return 0;
		}
		inline void insert(char c, int t) {nod[++tal]={t, head, c}; head=tal;}
		inline void upd(int c, int t) {
			for (int i=head; ~i; i=nod[i].next)
				if (nod[i].ch==c) {nod[i].to=t; break;}
		}
	}tr1[N], tr2[N];
	#define ls(p) lson[p]
	#define rs(p) rson[p]
	#define sum(p) sum[p]
	#define pushup(p) sum(p)=sum(ls(p))+sum(rs(p))
	void upd(int& p1, int p2, int tl, int tr, int pos, int val) {
		p1=++tot3;
		if (tl==tr) {sum(p1)=sum(p2)+val; return ;}
		int mid=(tl+tr)>>1;
		if (pos<=mid) upd(ls(p1), ls(p2), tl, mid, pos, val), rs(p1)=rs(p2);
		else upd(rs(p1), rs(p2), mid+1, tr, pos, val), ls(p1)=ls(p2);
		pushup(p1);
	}
	int query(int p1, int p2, int tl, int tr, int ql, int qr) {
		if (!p1) return 0;
		if (ql<=tl&&qr>=tr) return sum(p1)-sum(p2);
		int mid=(tl+tr)>>1, ans=0;
		if (ql<=mid) ans+=query(ls(p1), ls(p2), tl, mid, ql, qr);
		if (qr>mid) ans+=query(rs(p1), rs(p2), mid+1, tr, ql, qr);
		return ans;
	}
	int ins1(char* c) {
		int u=0, v;
		for (; *c; u=v,++c) if (!(v=tr1[u][*c-'a'])) tr1[u].insert(*c-'a', v=++tot1);
		return u;
	}
	int ins2(char* c) {
		int u=0, v;
		for (; *c; u=v,++c) if (!(v=tr2[u][*c-'a'])) tr2[u].insert(*c-'a', v=++tot2);
		return u;
	}
	void dfs2(int u) {
		// cout<<"dfs2: "<<u<<endl;
		siz2[u]=1; in2[u]=++ord2;
		for (int i=tr2[u].head,v; ~i; i=nod[i].next) {
			v = nod[i].to;
			dfs2(v);
			siz2[u]+=siz2[v];
		}
	}
	void dfs1(int u) {
		siz[u]=1; in[u]=++ord;
		if (!buc[u].size()) pos[ord]=now;
		else {
			for (auto it:buc[u]) ++now, upd(rot[now], rot[now-1], 1, ord2, in2[it], 1); //, cout<<"add: "<<in2[u]<<endl;
			pos[ord]=now;
		}
		for (int i=tr1[u].head,v; ~i; i=nod[i].next) {
			v = nod[i].to;
			dfs1(v);
			siz[u]+=siz[v];
		}
	}
	void solve() {
		// cout<<double(sizeof(lson)*3+sizeof(rot)*14+sizeof(nod)+sizeof(tr)+sizeof(s)*3)/1000/1000<<endl;
		scanf("%d", &n);
		for (int i=1,k,t1,t2; i<=n; ++i) {
			scanf("%s", s+1);
			k=strlen(s+1);
			t2=ins2(s+1);
			reverse(s+1, s+k+1);
			t1=ins1(s+1);
			// cout<<"t: "<<t1<<' '<<t2<<endl;
			buc[t1].push_back(t2);
		}
		dfs2(0); dfs1(0);
		scanf("%d", &m);
		for (int i=1,lst=0,t1,t2,k1,k2,l; i<=m; ++i) {
			scanf("%s%s", s1+1, s2+1);
			k1=strlen(s1+1); k2=strlen(s2+1);
			for (int j=1; j<=k1; ++j) s1[j]=(s1[j]-'a'+lst)%26+'a';
			for (int j=1; j<=k2; ++j) s2[j]=(s2[j]-'a'+lst)%26+'a';
			reverse(s2+1, s2+k2+1);
			t1=t2=0;
			for (int j=1,v; j<=k1; ++j) {
				// cout<<"t1: "<<t1<<endl;
				v=tr2[t1][s1[j]-'a'];
				if (!v) {printf("%d\n", lst=0); goto jump;}
				t1=v;
				// cout<<"turn: "<<t1<<endl;
			}
			for (int j=1,v; j<=k2; ++j) {
				// cout<<"t2: "<<t2<<endl;
				v=tr1[t2][s2[j]-'a'];
				if (!v) {printf("%d\n", lst=0); goto jump;}
				t2=v;
				// cout<<"turn: "<<t2<<endl;
			}
			// cout<<"t: "<<t1<<' '<<t2<<endl;
			// cout<<"rg: "<<in2[t1]<<' '<<in2[t1]+siz2[t1]-1<<endl;
			printf("%d\n", lst=query(rot[pos[in[t2]+siz[t2]-1]], rot[pos[in[t2]-1]], 1, ord2, in2[t1], in2[t1]+siz2[t1]-1));
			jump: ;
		}
	}
}

signed main()
{
	// force::solve();
	task::solve();

	return 0;
}
posted @ 2021-12-18 20:19  Administrator-09  阅读(2)  评论(0编辑  收藏  举报