[BZOJ3605] 括号序列计数

[BZOJ3605] 括号序列计数

题意

全网第一篇题解qaq

给定一个自动机,询问从节点 \(s\) 走到节点 \(t\)、长度为 \(l\) 的合法括号串有多少个。

节点数 \(V\) 不超过 \(2\)

题解

我们直接用一个矩阵描述一个状态即可。这里 \(mat[s][t]\) 表示从 \(s\) 到 \(t\) 的合法括号串的方案数。

我们考虑这玩意怎么转移?矩阵乘法。没错。

那么我们令 \(ans[T][V][V]\) 表示长度为 \(T\) 的答案矩阵。预处理出所有长度的矩阵后,每次询问 \(O(1)\) 回答即可。

那么考虑怎么处理啊。我不会啊。

\(ansL[T][V][V]\) 表示在 \(ans[T-1]\) 右边添加一个左括号后的方案数,即 \(ansL[T]=ans[T-1]\times L\)。同理定义 \(ansR[T]=ans[T-1]\times R\)(右边添加一个右括号)。代码中下标整体平移了一位:ansL[i] 存的是 \(ans[i]\times L\)。

那么假如我们知道了 \(ans[1\sim i][V][V]\)

\[ans[i+1]=ans[i]\times B+\sum_{j=1}^{i}ansL[j]\times ansR[i+1-j] \]

这玩意

相当于是

\[ans[i+1]=ans[i]\times B+\sum_{j=1}^ians[j-1]\times L\times ans[i-j]\times R \]

发现后面是个卷积形式,那么就是半在线卷积的板子了(似乎也不太算板子)。

啥叫半在线卷积呢

蒟蒻解释一下

就是这种 \(f[i]=\sum_{j=1}^{i-1}f[j]\times g[i-j]\)

其中 \(g\) 要在求出 \(f\) 之后才能知道。

对于这类问题,采用 CDQ 分治即可。但直接套用会发现有一部分贡献没有算到,需要分两种情况处理:

  1. \(l=0\) 的时候 \([l,mid]\times [l,mid]\)
  2. \(l\not = 0\) 的时候 \([l,mid]\times [1,r-l]\) (两个都要)


#include <bits/stdc++.h>
// NTT-friendly prime modulus: 998244353 = 119 * 2^23 + 1.
const int MOD = 998244353;
// Capacity of the generic helper containers (stack / graph / BIT).
const int N = 100010;
const int INF = 0x3f3f3f3f;
// Capacity of the polynomial (NTT) arrays: 2^18 entries.
const int _ = 1 << 18;
// Root used to build the NTT twiddle factors, and its inverse mod MOD
// ((MOD + 1) / 3 times 3 is MOD + 1, i.e. 1 mod MOD).
const int G = 3;
const int IG = (MOD + 1) / 3;
// Tolerance used by sign(), plus a spare tuning constant.
const double eps = 1e-9;
const double alpha = 0.7;
/// Reads one integer of type T from stdin via getchar().
/// Non-digit characters are skipped. A '-' counts as the sign only when it is the
/// last character skipped before the first digit. If EOF is reached before any
/// digit, returns 0 instead of spinning forever.
template<typename T>inline T read() {
	T x = 0;
	bool negative = false;
	int ch = getchar();  // int, not char: it must be able to hold EOF
	while (ch != EOF && !isdigit(ch)) {
		negative = ch == '-';
		ch = getchar();
	}
	while (isdigit(ch)) {  // isdigit(EOF) is false, so this also stops at EOF
		x = x * 10 + (ch - '0');
		ch = getchar();
	}
	return negative ? -x : x;
}
/// Returns x when x > y, otherwise y.
template<typename T>
inline T max(const T &x, const T &y) {
	if (x > y) {
		return x;
	}
	return y;
}
/// Returns x when x < y, otherwise y.
template<typename T>
inline T min(const T &x, const T &y) {
	if (x < y) {
		return x;
	}
	return y;
}
/// Returns x when x > 0, otherwise -x.
template<typename T>
inline T abs(const T &x) {
	if (x > 0) {
		return x;
	}
	return -x;
}
/// Brings x back into [0, MOD) with at most one correction.
/// Callers in this file pass sums/differences of two reduced values,
/// i.e. inputs in [-MOD, 2*MOD).
inline int Mod(int x) {
	if (x < 0) return x + MOD;
	if (x >= MOD) return x - MOD;
	return x;
}
template<typename T1, typename T2>struct pair;
template<typename T1, typename T2>
inline pair<T1, T2> make_pair(const T1 &x, const T2 &y) { return pair<T1, T2>(x, y); }
template<typename T1, typename T2>
struct pair {
T1 first;
T2 second;
pair(const T1 &x = 0, const T2 &y = 0) { first = x; second = y; }
friend pair operator + (const pair<T1, T2> &x, const pair<T1, T2> &y) { return pair(x.first + y.first, x.second + y.second); }
template<typename T>
friend pair operator * (const pair<T1, T2> &x, const T &y) { return make_pair(x.first * y, x.second * y); }
friend pair operator - (const pair<T1, T2> &x, const pair<T1, T2> &y) { return pair(x.first - y.first, x.second - y.second); }
friend bool operator < (const pair<T1, T2> &x, const pair<T1, T2> &y) { return x.first < y.first || (x.first == y.first && x.second < y.second); }
friend bool operator <= (const pair<T1, T2> &x, const pair<T1, T2> &y) { return x.first < y.first || (x.first == y.first && x.second <= y.second); }
friend bool operator > (const pair<T1, T2> &x, const pair<T1, T2> &y) { return x.first > y.first || (x.first == y.first && x.second > y.second); }
friend bool operator >= (const pair<T1, T2> &x, const pair<T1, T2> &y) { return x.first > y.first || (x.first == y.first && x.second >= y.second); }
friend bool operator == (const pair<T1, T2> &x, const pair<T1, T2> &y) { return x.first == y.first && x.second == y.second; }
friend bool operator != (const pair<T1, T2> &x, const pair<T1, T2> &y) { return x.first != y.first || x.second != y.second; }
};
// Fixed-capacity stack over a 1-based array: vals[1..top] are the live elements,
// so at most N - 1 pushes fit. `top` has no constructor initialization; call
// clear() first unless the object has static storage duration.
template<typename T>
struct stack {
	int top;
	T vals[N];
	bool empty() { return top == 0; }
	void push(T x) {
		++top;
		vals[top] = x;
	}
	void pop() {
		if (top > 0) --top;
	}
	void clear() { top = 0; }
	// Top element, or T(-1) when the stack is empty.
	T TOP() {
		if (!top) {
			return T(-1);
		}
		return vals[top];
	}
};
/// Modular exponentiation by repeated squaring: returns x^y mod MOD.
/// Assumes y >= 0 (TODO confirm for all callers; here y is a transform length
/// exponent or MOD - 2 for a modular inverse).
inline int ksm(int x, int y) {
	long long base = x;
	long long result = 1;
	while (y) {
		if (y & 1) result = result * base % MOD;
		base = base * base % MOD;
		y /= 2;
	}
	return static_cast<int>(result);
}
/// Multiplication mod MOD by doubling-and-adding: returns x * y mod MOD using
/// only additions, each reduced with Mod(). Expects x in [0, MOD).
inline int ksc(int x, int y) {
	int acc = 0;
	while (y != 0) {
		if (y & 1) acc = Mod(acc + x);
		x = Mod(x + x);
		y /= 2;
	}
	return acc;
}
struct graph { int cnt, h[N]; pair<int, int> edge[N * 2]; void add_edge(int x, int y) { edge[cnt].first = y; edge[cnt].second = h[x]; h[x] = cnt++; } void clear() { memset(h, -1, sizeof h); cnt = 0; } };
/// Left child index 2k (heap-style numbering with the root at 1).
inline int ls(int k) { return 2 * k; }
/// Right child index 2k + 1.
inline int rs(int k) { return 2 * k + 1; }
/// Three-way sign with tolerance: 0 when |x| < eps, otherwise -1 for negative and 1 otherwise.
inline int sign(double x) {
	if (fabs(x) < eps) return 0;
	return x < 0 ? -1 : 1;
}
// Fenwick (binary indexed) tree of prefix sums over positions 1..limit.
// Call resize(n) with n < N before use. vals must start zeroed, which holds for
// objects with static storage duration. T only needs binary `+` and construction
// from 0, so the file's own pair<> works as well as plain integers.
template<typename T>
struct BIT {
	int limit;
	T vals[N];
	void resize(int n) { limit = n; }
	// Adds y at position x. Positions x <= 0 are ignored: from them the update
	// loop would index out of bounds and then stall at 0 (x & -x == 0).
	void add(int x, T y) {
		if (x <= 0) { return; }
		for ( ; x <= limit; x += x & -x) { vals[x] = vals[x] + y; }
	}
	// Sum of positions 1..x; 0 for x <= 0 (the loop previously ran away on negatives).
	T sum(int x) {
		T ret = 0;
		for ( ; x > 0; x -= x & -x) { ret = ret + vals[x]; }
		return ret;
	}
};
/// Writes a newline to std::cout and flushes it (exactly what std::endl does).
inline void flush() {
	std::cout.put('\n');
	std::cout.flush();
}
/// Reads `size` values from std::cin into elem[1..size] (1-based; elem[0] untouched).
template<typename T>
inline void readin(T *elem, int size) {
	int idx = 0;
	while (idx < size) {
		++idx;
		std::cin >> elem[idx];
	}
}
using std::set;
using std::map;
using std::vector;
using std::ios;
using std::cin;
using std::cout;
using std::endl;
using std::queue;
using std::cerr;
// Character-literal shorthands, e.g. `cout << x << space` / `cout << x << enter`.
#define space ' '
#define enter '\n'
// Placeholder macros; both expand to nothing.
#define orz_1
#define orz_2
/// 2x2 matrix with arithmetic mod MOD. Entry [s][t] relates automaton node s to
/// node t, which is why the size is fixed at 2 (the write-up bounds V by 2).
struct mat {
	int arr[2][2];

	/// Row access so entries read as m[i][j].
	int* operator [](int row) { return arr[row]; }

	/// Constructs the zero matrix.
	mat() {
		for (int i = 0; i < 2; ++i) {
			for (int j = 0; j < 2; ++j) arr[i][j] = 0;
		}
	}

	/// Matrix product lhs * rhs. The order matters: matrix products do not commute.
	friend mat operator *(mat lhs, mat rhs) {
		mat out;
		for (int i = 0; i < 2; ++i) {
			for (int k = 0; k < 2; ++k) {
				int acc = 0;
				for (int j = 0; j < 2; ++j) {
					acc = Mod(acc + static_cast<int>(1LL * lhs[i][j] * rhs[j][k] % MOD));
				}
				out[i][k] = acc;
			}
		}
		return out;
	}

	/// Multiplies every entry by the scalar s (mod MOD).
	friend mat operator *(mat m, int s) {
		for (auto &row : m.arr) {
			for (int &entry : row) entry = static_cast<int>(1LL * entry * s % MOD);
		}
		return m;
	}

	/// In-place entry-wise addition mod MOD.
	friend void operator +=(mat &acc, mat rhs) {
		for (int i = 0; i < 2; ++i) {
			for (int j = 0; j < 2; ++j) acc.arr[i][j] = Mod(acc.arr[i][j] + rhs.arr[i][j]);
		}
	}
};
// Transition matrices filled in main(): L and R are applied as ans * L / ans * R in
// solve() (left / right bracket per the write-up); B is the matrix used by the
// plain step ans[n] += ans[n - 1] * B.
mat L, R, B;
// V: number of automaton nodes; Q: number of queries.
int V, Q;
// NTT twiddle-factor table, filled by init().
int w[_];

// dir[i]: i with its low log2(need) bits reversed; need: current transform length.
int dir[_] , need;
/// Prepares an NTT whose length `need` is the smallest power of two >= len.
/// Rebuilds the bit-reversal table dir[0..need) (dir[0] relies on zero-initialization).
/// Extends the twiddle table so that, for every half-length h < need,
/// w[h + j] = r^j (0 <= j < h) with r = G^((MOD - 1) / (2h)).
void init(int len){
	need = 1;
	while (need < len) need *= 2;

	const int half = need >> 1;
	for (int i = 1; i < need; ++i) {
		int rev = dir[i >> 1] >> 1;
		if (i & 1) rev |= half;
		dir[i] = rev;
	}

	// The twiddle rows never change, so they are built only once. `built` is the
	// first half-length whose row has not been filled yet.
	static int built = 1;
	for (; built < need; built *= 2) {
		const int h = built;
		const int step = ksm(G, (MOD - 1) / (h * 2));
		w[h] = 1;
		for (int j = 1; j < h; ++j) {
			w[h + j] = static_cast<int>(1LL * w[h + j - 1] * step % MOD);
		}
	}
}

/// In-place number-theoretic transform of arr[0..need), where each element is a
/// 2x2 matrix whose four entries are transformed independently.
/// tp == -1 performs the inverse transform, including the 1/need scaling; any
/// other tp performs the forward transform. init() must have set up `need`,
/// `dir` and `w` for the current length.
void DFT(mat *arr , int tp){
	// Bit-reversal permutation: swap each mismatched pair exactly once.
	for (int i = 1; i < need; ++i) {
		if (dir[i] > i) std::swap(arr[i], arr[dir[i]]);
	}
	// Iterative butterflies; `half` is half the length of the blocks being merged.
	for (int half = 1; half < need; half *= 2) {
		const int *tw = w + half;
		for (int start = 0; start < need; start += 2 * half) {
			mat *lo = arr + start;
			mat *hi = lo + half;
			for (int k = 0; k < half; ++k) {
				const mat a = lo[k];
				const mat b = hi[k] * tw[k];
				for (int p = 0; p < 2; ++p) {
					for (int q = 0; q < 2; ++q) {
						const int u = a.arr[p][q];
						const int v = b.arr[p][q];
						lo[k].arr[p][q] = Mod(u + v);
						hi[k].arr[p][q] = Mod(u - v);
					}
				}
			}
		}
	}
	if (tp != -1) return;
	// Reversing entries 1..need-1 of a forward transform yields need times the
	// inverse transform; divide that factor out.
	std::reverse(arr + 1, arr + need);
	const int inv = ksm(need, MOD - 2);
	for (int i = 0; i < need; ++i) arr[i] = arr[i] * inv;
}

// ans[n][s][t]: number (mod MOD) of valid strings of length n leading from node s
// to node t. ansL[n] = ans[n] * L and ansR[n] = ans[n] * R. tmp1/tmp2 are NTT
// scratch buffers.
mat ans[_] , ansL[_] , ansR[_] , tmp1[_] , tmp2[_];

// For every i in [mid + 1, r], adds to ans[i] the sum of ansL[a] * ansR[b] over
// a in [lLo, lHi] and b in [rLo, rHi] with a + b + 2 == i. The ansL factor always
// stays on the left, because matrix products do not commute.
// `len` is passed to init(); callers choose it so that cyclic wrap-around never
// lands in the read range [mid + 1 - (lLo + rLo), r - (lLo + rLo)].
static void add_cross_terms(int lLo, int lHi, int rLo, int rHi, int mid, int r, int len){
	int cnt1 = 0, cnt2 = 0;
	tmp1[0] = tmp2[0] = mat();
	// Index shift by 1 on each side: one bracket each, so a + b + 2 == total length.
	for (int i = lLo; i <= lHi; ++i) tmp1[++cnt1] = ansL[i];
	for (int i = rLo; i <= rHi; ++i) tmp2[++cnt2] = ansR[i];
	init(len);
	for (int i = cnt1 + 1; i < need; ++i) tmp1[i] = mat();
	for (int i = cnt2 + 1; i < need; ++i) tmp2[i] = mat();
	DFT(tmp1, 1);
	DFT(tmp2, 1);
	for (int i = 0; i < need; ++i) tmp1[i] = tmp1[i] * tmp2[i];
	DFT(tmp1, -1);
	const int base = lLo + rLo;
	for (int i = mid + 1; i <= r; ++i) ans[i] += tmp1[i - base];
}

/// CDQ divide and conquer ("semi-online convolution") over lengths [l, r].
/// Precondition: for every n in [l, r], ans[n] already contains the terms
/// ansL[a] * ansR[b] with a + b + 2 == n and a, b < l.
/// On return, ans, ansL and ansR are final for every length in [l, r].
void solve(int l , int r){
	if (l == r) {
		if (!l) ans[0][0][0] = ans[0][1][1] = 1;  // empty string: identity matrix
		else ans[l] += ans[l - 1] * B;            // append one B step
		ansL[l] = ans[l] * L;
		ansR[l] = ans[l] * R;
		return;
	}
	const int mid = (l + r) >> 1;
	solve(l, mid);
	if (!l) {
		// Both factors come from [0, mid]. Pairs with a factor beyond mid are
		// handled deeper in the right-hand recursion.
		add_cross_terms(0, mid, 0, mid, mid, r, r + 1);
	} else {
		// For every segment reached from solve(0, mx) with l > 0, r - l < l holds.
		// Hence lengths [0, r - l) are final and the two pair sets below are disjoint.
		// [l, mid] is counted once as the L side and once as the R side.
		add_cross_terms(l, mid, 0, r - l - 1, mid, r, r - l + 1);
		add_cross_terms(0, r - l - 1, l, mid, mid, r, r - l + 1);
	}
	solve(mid + 1, r);
}

int S[15] , T[15] , K[15];

int main(){
	freopen("dfa.in", "r", stdin);
	freopen("dfa.out", "w", stdout);
	cin >> V;
	for(int i = 0 ; i < V ; ++i){
		int x , y; cin >> x >> y; L[i][x] = y % MOD;
		cin >> x >> y; R[i][x] = y % MOD;
		cin >> x >> y; B[i][x] = y % MOD;
	}
	int mx = 0;
	cin >> Q; for(int i = 1 ; i <= Q ; ++i){cin >> S[i] >> T[i] >> K[i]; mx = max(mx , K[i]);}
	solve(0 , mx);
	for(int i = 1 ; i <= Q ; ++i) cout << ans[K[i]][S[i]][T[i]] << endl;
	return 0;
}

http://121.17.168.211:30000/contest/62

https://darkbzoj.tk/problem/3605

posted @ 2021-07-09 23:18  siriehn_nx  阅读(129)  评论(0编辑  收藏  举报