【YBT2023寒假Day11 C】棕发少女(SA)(主席树)(二分)

棕发少女

题目链接:YBT2023寒假Day11 C

题目大意

给你三个字符串 a,b,s,我们设 F0=a,F1=b,Fi=Fi-1+Fi-2(加号是把两个字符串拼接起来)
然后多次询问每次给你 n,l,r,L,R,要你求 s 中 L~R 的子串在 Fn 中 l~r 的字符串出现了多少次。

思路

发现 \(F_n\) 长度达到了 \(1e9\),看起来直接搞不太能做。
那我们不妨假设 \(F_n\) 长度不大的时候要怎么做。

那这个其实是一个经典的问题,就是求有多少个 \(s\)\(L\sim |s|\)\(F_n\)\(i\sim |F_n|\) 的最长公共前缀 \(\geqslant R-L+1\),把两个字符串拼起来,中间加一个特殊字符 \(\#\),然后跑一个 SA,就是要 \(height\) 最小值大于等于 \(R-L+1\)
那看到满足这个条件的位置是一个区间,我们可以二分出这个区间,那问题就变成一个 \(l\sim r\) 的区间里有多少个在 SA 中的位置在我们二分出的区间中,这个用一个主席树可维护。

然后我们考虑 \(F_n\) 长度变得很大,那我们考虑利用上构造 \(F_n\) 的性质。
\(s\) 并不会有那么长,那我们考虑 \(s\) 子串会怎样出现在 \(F_n\) 中。
那可能是直接存在于每个 \(a/b\) 之中,也有可能是跨过了若干个 \(a/b\),那我们所需要做的就是求出它存在于的 \(ab\) 串的形式,算出对应的方案,再乘上这个 \(ab\) 串在总串中出现的概率。
(由于你算一个串的时候它子串都会被算上,所以你还要稍微容斥一下)

那这个显然是不对的,会有很多很多的串要你分别去算。
那为啥这么多呢?因为可能会跨越很多段,那我们能不能让它少跨越一点呢?
再看会我们的暴力,我们的暴力要做到多大,才能使得它少跨越一点呢。
我们看看如果要只跨越两段,要保证 \(F_n\) 至少是多大,那显然是 \(|F_n|>|s|\)

那我们小于的部分,就暴力,大于的部分,我们保证只会被跨越两段。
不过注意到,你这里的每一段不在是给出的 \(a,b\),而是当 \(|F_{k}|>|s|\) 时,\(F_{k-1}=A,F_k=B\)
每一段是这个 \(A,B\)
那我们思考一下最后的 \(F_n\) 被拆成的区间数,会发现如果拆成 \(A,B\) 还是不少,特别是当 \(|s|\) 很小的时候。
那我们可以不要让 \(|F_k|>|s|\),而是大于或大于等于一个固定的数(比如我是 \(|F_k|\geqslant 60000\),其中 \(\max|s|=30000\)
那最多的 \(AB\) 数量就是 \(q\dfrac{1e9}{6e4}\),在 \(q=1e4\) 的时候大概为 \(2e6\) 不到。

那考虑拆出来再看 \(l\sim r\) 的区间,发现是形如两边一个可能不完整的 \(A/B\) 加上中间一段完整的 \(AB\) 序列。
那我们要算的是两边自身的贡献,两边和旁边那个完整的贡献(如果只有两个不完整或者只有一个不完整的特判一下),完整序列里面每个自己的贡献,完整序列里面每个相邻对的贡献。
那自己的贡献就直接把 \(A/B\) 分别和 \(s\) 像暴力一样拼起来做就好。
那相邻对的我们就把 \(AA,AB,BA,BB\) 分别和 \(s\) 也拼起来做就好了。
(不过其实 \(AA\) 是不会存在的,可以去掉)
(原因是除了 \(F_0=a\),其它位置开头都是 \(b\),在拼接的地方不会 \(F_{i-1}\) 的尾部和 \(F_{i-2}\) 的头部都是 \(a\),唯一符合头部是 \(a\)\(F_0\)\(F_1\) 尾部是 \(b\)

不过要注意的是前面说过的,你算相邻的会把两个自身的都算进去,所以要减。
那相邻对肯定是枚举对(在完整段相邻对的时候),而是先枚举对统计上面每个对的个数,然后最后每种算一次贡献乘上它出现的次数。

然后就可以了。

代码

#include<cstdio>
#include<string>
#include<cstring>
#define ll long long


using namespace std;

const int N = 8e5 + 100;
//const int N = 1000;
int n, an, bn, fsz[N], An, Bn, log2_[N];
int sum[N];
char s[N], a[N], b[N], A[N], B[N];
string fab[5];

struct SA_work {
	char s[N];
	int n, rk[N << 1], sa[N << 1], height[N][21], tong[N << 1], x[N << 1], y[N << 1];
	int ls[N << 5], rs[N << 5], f[N << 5], rt[N << 1], tot, fir[N << 1];
	
	int add(int x, int l, int r, int pl) {
		int now = ++tot; ls[now] = ls[x]; rs[now] = rs[x]; f[now] = f[x];
		f[now]++; if (l == r) return now;
		int mid = (l + r) >> 1;
		if (pl <= mid) ls[now] = add(ls[now], l, mid, pl);
			else rs[now] = add(rs[now], mid + 1, r, pl);
		return now; 
	}
	
	int query(int now, int l, int r, int L, int R) {
		if (L > R) return 0;
		if (!now) return 0;
		if (L <= l && r <= R) return f[now];
		int mid = (l + r) >> 1, re = 0;
		if (L <= mid) re += query(ls[now], l, mid, L, R);
		if (mid < R) re += query(rs[now], mid + 1, r, L, R);
		return re;
	}
	
	void Sort(int m, int *x, int *y) {
		for (int i = 0; i <= m; i++) tong[i] = 0;
		for (int i = 1; i <= n; i++) tong[x[i]]++;
		for (int i = 1; i <= m; i++) tong[i] += tong[i - 1];
		for (int i = n; i >= 1; i--) sa[tong[fir[i]]--] = y[i];
	}
	
	void SA() {
		int m = max(n, 3000);
		for (int i = 1; i <= n; i++) x[i] = s[i], y[i] = i;
		for (int i = 1; i <= n; i++) fir[i] = x[y[i]];
		Sort(m, x, y);
		
		for (int w = 1; w <= n; w <<= 1) {
			int ynum = 0;
			for (int i = n - w + 1; i <= n; i++)
				y[++ynum] = i;
			for (int i = 1; i <= n; i++)
				if (sa[i] > w) y[++ynum] = sa[i] - w;
			for (int i = 1; i <= n; i++) fir[i] = x[y[i]];
			Sort(m, x, y);
			
			swap(x, y);
			int mm = 1; x[sa[1]] = 1;
			for (int i = 2; i <= n; i++)
				if (y[sa[i]] == y[sa[i - 1]] && y[sa[i] + w] == y[sa[i - 1] + w]) x[sa[i]] = mm;
					else x[sa[i]] = ++mm;
			if (mm == n) break;
			m = mm;
		}
		
		for (int i = 1; i <= n; i++) rk[sa[i]] = i;
		int k = 0;
		for (int i = 1; i <= n; i++) {
			if (k) k--;
			int j = sa[rk[i] - 1];
			while (i + k <= n && j + k <= n && s[i + k] == s[j + k]) k++;
			height[rk[i]][0] = k;
		}
		
		for (int i = 1; i <= 20; i++)
			for (int j = 1; j + (1 << i) - 1 <= n; j++)
				height[j][i] = min(height[j][i - 1], height[j + (1 << (i - 1))][i - 1]);
		
		for (int i = 1; i <= n; i++) rt[i] = add(rt[i - 1], 1, n, sa[i]);
	}
	
	int RMQ2(int l, int r) {
		l++; int k = log2_[r - l + 1];
		return min(height[l][k], height[r - (1 << k) + 1][k]);
	}
	
	int RMQ1(int l, int r) {
		if (!l || !r) return 0;
		if (l == r) return n - l + 1;
		l = rk[l]; r = rk[r]; if (l > r) swap(l, r);
		return RMQ2(l, r);
	}
	
	int ask(int l, int r, int L, int R) {
		if (L > R) return 0;
		int rnk = rk[l];
		int l_ = rnk + 1, r_ = n, re = rnk;
		while (l_ <= r_) {
			int mid = (l_ + r_) >> 1;
			if (RMQ2(rnk, mid) >= r - l + 1) re = mid, l_ = mid + 1;
				else r_ = mid - 1;
		}
		int ansr = re;
		l_ = 1; r_ = rnk - 1; re = rnk;
		while (l_ <= r_) {
			int mid = (l_ + r_) >> 1;
			if (RMQ2(mid, rnk) >= r - l + 1) re = mid, r_ = mid - 1;
				else l_ = mid + 1;
		}
		int ansl = re;
		return query(rt[ansr], 1, n, L, R) - query(rt[ansl - 1], 1, n, L, R);
	}
}sa, sb, sab, sba, sbb;

int get_pl(int x) {
	for (int i = 0; i < fab[1].size(); i++)
		if (sum[i + 1] >= x) return i;
}

int main() {
	freopen("string.in", "r", stdin);
	freopen("string.out", "w", stdout);
	
	log2_[0] = -1; for (int i = 1; i < N; i++) log2_[i] = log2_[i >> 1] + 1;
	
	scanf("%s %s %s", a + 1, b + 1, s + 1);
	an = strlen(a + 1); bn = strlen(b + 1); n = strlen(s + 1);
	
	fsz[0] = an; fsz[1] = bn; fab[0] = "A"; fab[1] = "B";
	while (fsz[0] < 60000 || fsz[1] < 60000) {
		fsz[2] = fsz[1] + fsz[0]; fab[2] = fab[1] + fab[0];
		fsz[0] = fsz[1]; fsz[1] = fsz[2]; fab[0] = fab[1]; fab[1] = fab[2];
	}
	
	for (int i = 0; i < fab[0].size(); i++)
		if (fab[0][i] == 'A') {
			for (int j = 1; j <= an; j++) A[++An] = a[j];
		}
		else {
			for (int j = 1; j <= bn; j++) A[++An] = b[j];
		}
	for (int i = 0; i < fab[1].size(); i++)
		if (fab[1][i] == 'A') {
			for (int j = 1; j <= an; j++) B[++Bn] = a[j];
		}
		else {
			for (int j = 1; j <= bn; j++) B[++Bn] = b[j];
		}
	
	for (int i = 1; i <= n; i++) sa.s[i] = s[i];
	sa.s[n + 1] = '#';
	for (int i = 1; i <= An; i++) sa.s[n + 1 + i] = A[i];
	sa.n = n + 1 + An; sa.SA();
	for (int i = 1; i <= n; i++) sb.s[i] = s[i];
	sb.s[n + 1] = '#';
	for (int i = 1; i <= Bn; i++) sb.s[n + 1 + i] = B[i];
	sb.n = n + 1 + Bn; sb.SA();
	for (int i = 1; i <= n; i++) sab.s[i] = s[i];
	sab.s[n + 1] = '#';
	for (int i = 1; i <= An; i++) sab.s[n + 1 + i] = A[i];
	for (int i = 1; i <= Bn; i++) sab.s[n + 1 + An + i] = B[i];
	sab.n = n + 1 + An + Bn; sab.SA();
	for (int i = 1; i <= n; i++) sba.s[i] = s[i];
	sba.s[n + 1] = '#';
	for (int i = 1; i <= Bn; i++) sba.s[n + 1 + i] = B[i];
	for (int i = 1; i <= An; i++) sba.s[n + 1 + Bn + i] = A[i];
	sba.n = n + 1 + Bn + An; sba.SA();
	for (int i = 1; i <= n; i++) sbb.s[i] = s[i];
	sbb.s[n + 1] = '#';
	for (int i = 1; i <= Bn; i++) sbb.s[n + 1 + i] = B[i];
	for (int i = 1; i <= Bn; i++) sbb.s[n + 1 + Bn + i] = B[i];
	sbb.n = n + 1 + Bn + Bn; sbb.SA();
	
	fsz[0] = fsz[1 - 1]; fsz[1] = fsz[1]; fab[0] = "A"; fab[1] = "B";
	while (fsz[1] < 1000000000) {
		fsz[2] = fsz[1] + fsz[0]; fab[2] = fab[1] + fab[0];
		fsz[0] = fsz[1]; fsz[1] = fsz[2]; fab[0] = fab[1]; fab[1] = fab[2];
//		fsz[++1] = fsz[1 - 1] + fsz[1 - 2], fab[1] = fab[1 - 1] + fab[1 - 2];
	}
	sum[1] = (fab[1][0] == 'A') ? An : Bn;
	for (int i = 1; i < fab[1].size(); i++)
		sum[i + 1] = sum[i] + ((fab[1][i] == 'A') ? An : Bn);
	
	int q; scanf("%d", &q);
	while (q--) {
		int n_, L, R, l, r; scanf("%d %d %d %d %d", &n_, &L, &R, &l, &r);
		if (n_ == 0) L += bn, R += bn;//一直拆前面最后是 ba
		int bL = get_pl(L), bR = get_pl(R); ll ans = 0;
		if (bL == bR) {
			if (fab[1][bL] == 'A') ans = sa.ask(l, r, n + 1 + L - sum[bL], n + 1 + R - sum[bL] - (r - l + 1) + 1);
				else ans = sb.ask(l, r, n + 1 + L - sum[bL], n + 1 + R - sum[bL] - (r - l + 1) + 1);
		}
		else if (bR - bL == 1) {
			if (fab[1][bL] == 'A' && fab[1][bR] == 'B') ans = sab.ask(l, r, n + 1 + L - sum[bL], n + 1 + R - sum[bL] - (r - l + 1) + 1);
			if (fab[1][bL] == 'B' && fab[1][bR] == 'A') ans = sba.ask(l, r, n + 1 + L - sum[bL], n + 1 + R - sum[bL] - (r - l + 1) + 1);
			if (fab[1][bL] == 'B' && fab[1][bR] == 'B') ans = sbb.ask(l, r, n + 1 + L - sum[bL], n + 1 + R - sum[bL] - (r - l + 1) + 1);
		}
		else {
			int anum = 0, bnum = 0, abnum = 0, banum = 0, bbnum = 0;
			for (int i = bL + 1; i <= bR - 1; i++) {
				if (fab[1][i] == 'A') anum++;
					else bnum++;
				if (i + 1 <= bR - 1) {
					if (fab[1][i] == 'A' && fab[1][i + 1] == 'B') abnum++;
					if (fab[1][i] == 'B' && fab[1][i + 1] == 'A') banum++;
					if (fab[1][i] == 'B' && fab[1][i + 1] == 'B') bbnum++;
				}
			}
			if (fab[1][bL + 1] == 'A') {
				ans += sba.ask(l, r, n + 1 + L - sum[bL], sba.n);
				anum--;
			}
			else {
				if (fab[1][bL] == 'A') ans += sab.ask(l, r, n + 1 + L - sum[bL], sab.n);
					else ans += sbb.ask(l, r, n + 1 + L - sum[bL], sbb.n);
				bnum--;
			}
			if (fab[1][bR - 1] == 'A') {
				ans += sab.ask(l, r, n + 1 + 1, n + 1 + R - sum[bR - 1] - (r - l + 1) + 1);
				anum--;
			}
			else {
				if (fab[1][bR] == 'A') ans += sba.ask(l, r, n + 1 + 1, n + 1 + R - sum[bR - 1] - (r - l + 1) + 1);
					else ans += sbb.ask(l, r, n + 1 + 1, n + 1 + R - sum[bR - 1] - (r - l + 1) + 1);
				bnum--;
			}
			ans += 1ll * abnum * sab.ask(l, r, n + 1 + 1, sab.n);
			ans += 1ll * banum * sba.ask(l, r, n + 1 + 1, sba.n);
			ans += 1ll * bbnum * sbb.ask(l, r, n + 1 + 1, sbb.n);
			ans += 1ll * (anum - abnum - banum) * sa.ask(l, r, n + 1 + 1, sa.n);
			ans += 1ll * (bnum - abnum - banum - 2 * bbnum) * sb.ask(l, r, n + 1 + 1, sb.n);
		}
		printf("%lld\n", ans);
	}
	
	return 0;
}
posted @ 2023-02-22 08:33  あおいSakura  阅读(8)  评论(0编辑  收藏  举报