[LOJ3123] CTSC2019重复

Description

给定一个⻓为 n 的字符串 s , 问有多少个⻓为 m 的字符串 t 满足:
将 t 无限重复后,可以从中截出一个⻓度为 n 且字典序比 s 小的串。

m ≤ 2000 n ≤ 2000

Solution

正难则反,补集转换,用 \(26^m\) 减去“无法从中截出字典序比 s 小的串”的方案数。

方便表述,称字符串t具有特征 \(A\) 当且仅当无法从无限重复的t中截出一段长度为m且字典序比s小的字段即A为任意无限重复的t中长度为m的字典序都比s大

考虑构造一个有限状态自动机能接受所有满足特征A的串,然后在上面计数,那么我们要统计对于每个节点开头走m条边后回到它自己的方案数(t串是无限长的)。

由于需要满足特征A,所以一个点的出边只有最大的边是有用的,因为满足A的字符串一定不会走更小的边,(要么比s大,要么目前和s一样,比s大对应的是已经接受了一个满足A的串,直接跳到根,和s一样说明要继续走下去)。

于是这就是一个只保留最大转移边的kmp自动机。

并且一个节点只有一条出边,还有许多边指向根,后者之间本质是一样的我们只要记个数即可(代码实现中是edge[i],表示i点指向根的边数)。

现在考虑如何在上面dp,不难发现这个图很特殊是一个rho,图上的路径只有两种:

  • 在环上走m步回到自己,只有当环的大小为m的约数时存在。
  • 从自己走若步(比如j步)到根,再从根走m-j步回到自己。

前者直接找环算,后者设 \(f[i][u]\) 表示从根走i步到u的方案数, \(g[i][u]\) 为从u走i步到根的方案数,dp出来后枚举j即可。

\[f[i + 1][v] \leftarrow f[i][u] \\ f[i + 1][0] \leftarrow f[i][u]\times edge[u]\\ g[i + 1][u] \leftarrow g[i][v] \]

Code

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <fstream>

typedef long long LL;
typedef unsigned long long uLL;

#define SZ(x) ((int)x.size())
#define ALL(x) (x).begin(), (x).end()
#define MP(x, y) std::make_pair(x, y)
#define DE(x) cerr << x << endl;
#define debug(...) fprintf(stderr, __VA_ARGS__)
#define GO cerr << "GO" << endl;

using namespace std;

inline void proc_status()
{
	ifstream t("/proc/self/status");
	cerr << string(istreambuf_iterator<char>(t), istreambuf_iterator<char>()) << endl;
}
inline int read() 
{
	register int x = 0; register int f = 1; register char c;
	while (!isdigit(c = getchar())) if (c == '-') f = -1;
	while (x = (x << 1) + (x << 3) + (c xor 48), isdigit(c = getchar()));
	return x * f;
}
template<class T> inline void write(T x) 
{
	static char stk[30]; static int top = 0;
	if (x < 0) { x = -x, putchar('-'); }
	while (stk[++top] = x % 10 xor 48, x /= 10, x);
	while (putchar(stk[top--]), top);
}
template<typename T> inline bool chkmin(T &a, T b) { return a > b ? a = b, 1 : 0; }
template<typename T> inline bool chkmax(T &a, T b) { return a < b ? a = b, 1 : 0; }

const int maxN = 2e3;
const int mod = 998244353;

namespace math
{
	void pls(int &x, int y)
	{
		x += y;
		if (x >= mod) x -= mod;
		if (x < 0) x += mod;
	}
	LL qpow(LL a, LL b)
	{
		LL ans(1);
		while (b)
		{
			if (b & 1) 
				ans = ans * a % mod;
			a = a * a % mod;
			b >>= 1;
		}
		return ans;
	}
}
using math::pls;
using math::qpow;

int n, m; //n字符串长度,m走m步
char str[maxN + 2];
int fail[maxN + 2], ver[maxN + 2], edge[maxN + 2];

void insert()
{
	fail[1] = 0;
	for (int i = 2, j = 0; i <= n; ++i)
	{
		while (j and str[j + 1] != str[i]) j = fail[j];
		j += str[j + 1] == str[i];
		fail[i] = j;
	}
}

void build()
{
	for (int i = 0; i <= n; ++i)
	{
		for (int j = 25; j >= 0; --j)
		{
			int p = i;
			if (p == n) p = fail[p];
			while (p and str[p + 1] != j + 'a') p = fail[p];
			p += (str[p + 1] == (j + 'a'));
			if (p) 
			{
				ver[i] = p;
				edge[i] = 25 - j;
				break;
			}
		}
	}
}

int size;
int f[maxN + 2][maxN + 2], g[maxN + 2][maxN + 2]; // f[i][u] : root -> u cost i ; g[i][u] : u -> root cost i

void DP()
{
	f[0][0] = 1;
	for (int i = 0; i < m; ++i)
		for (int j = 0; j <= n; ++j)
		{
			pls(f[i + 1][ver[j]], f[i][j]);
			pls(f[i + 1][0], 1ll * f[i][j] * edge[j] % mod);
		}
	for (int i = 0; i <= n; ++i)
		g[1][i] = edge[i];
	for (int i = 2; i <= m; ++i)
		for (int j = 0; j <= n; ++j)
			g[i][j] = g[i - 1][ver[j]];
}

int key;
bool vis[maxN + 2];

bool dfs(int u)
{
	if (!u) return 0;
	if (vis[u]) { key = u; return 1; }
	vis[u] = 1;
	if (dfs(ver[u])) { size++; return key != u; }
	return 0;
}

int main() 
{
#ifndef ONLINE_JUDGE
	freopen("xhc.in", "r", stdin);
	freopen("xhc.out", "w", stdout);
#endif
	scanf("%d %s", &m, str + 1);
	n = strlen(str + 1);

	insert();
	build();
	DP();
	int ans = 0;
	dfs(1);
	if (m % size == 0) 
		ans = size;
	for (int i = 0; i <= n; ++i)
	{
		int sum = 0;
		for (int j = 0; j <= m; ++j)
			pls(sum, 1ll * f[j][i] * g[m - j][i] % mod);
		pls(ans, sum);
	}
	cout << ((qpow(26, m) - ans) % mod + mod) % mod << endl;
	return 0;
}
posted @ 2019-08-27 20:29  茶Tea  阅读(195)  评论(0编辑  收藏  举报