字符串

题目

Description

Input

Output

题解

显然用是AC自动机来解决

先说一下没人写的正解
二进制分组, 建\(O(lgm)\)个AC自动机。 定义AC自动机的size为这个AC自动机中的字符串个数。 当两个AC自动机size相等时合并这两个AC自动机。时间复杂度\(O(mlgm)\)

下面是比较好想, 也比较常规的做法。
题目中说强制在线, 但是却没有给字符串加密。 于是我们不一定要真的在线做。 我们可以用输入中出现的所有字符串建一个AC自动机, 虽然这样会把询问串也包含在内但对结果没有影响。
每次查询还是正常的在AC自动机上走。但由于是动态修改, 我们无法像正常的AC自动机一样预处理, 于是我们需要维护fail树。
把所有fail边拿出来建建一颗树。 然后问题转化成了: 修改一个点的权值, 询问一个点到根路径的权值和。 然后随便用数据结构维护就好了。

代码

#include <iostream>
#include <cstdlib>
#include <cstdio>
#include <cstring>
#include <algorithm>

#include <queue>

using namespace std;


typedef long long LL;


const int N = 1000010, M = 2000010;


int ed[N];


struct AC
{
	static const int SIZE = 1000010;
	
	int nxt[SIZE][26], fail[SIZE];
	LL val[SIZE];
	int root;
	int sz;

	AC() { }

	void init()
	{
		sz = 0;
		root = newnode();
	}

	int newnode()
	{
		memset(nxt[sz], 0, sizeof(nxt[sz]));
		fail[sz] = 0;
		val[sz] = 0;
		return sz++;
	}
	
	void insert(char * str, int len, int id)
	{
		int u = root;
		
		for (int i = 0; i < len; i++)
		{
			int c = str[i] - 'a';
			if (!nxt[u][c])
				nxt[u][c] = newnode();
			u = nxt[u][c];
		}
		
		ed[id] = u;
	}

	void getFail()
	{
		queue <int> q;
		
		for (int i = 0; i < 26; i++)
			if (nxt[root][i])
			{
				fail[nxt[root][i]] = root;
				q.push(nxt[root][i]);
			}
			else nxt[root][i] = root;
		
		while (!q.empty())
		{
			int x = q.front(); q.pop();
			for (int i = 0; i < 26; i++)
			{
				if (nxt[x][i])
				{
					fail[nxt[x][i]] = nxt[fail[x]][i];
					q.push(nxt[x][i]);
				}
				else nxt[x][i] = nxt[fail[x]][i];
			}
		}
	}
} ac;


struct edge
{
	int from, to;
	edge() { }
	edge(int _1, int _2) : from(_1), to(_2) { }
} edges[M];

int head[N], nxt[M], tot;

inline void init()
{
	memset(head, -1, sizeof(head));
	tot = 0;
}

inline void add_edge(int x, int y)
{
	edges[tot] = edge(x, y);
	nxt[tot] = head[x];
	head[x] = tot++;
	edges[tot] = edge(y, x);
	nxt[tot] = head[y];
	head[y] = tot++;
}


int dfn[N], idf[N], siz[N], dfs_clock;

void dfs(int x, int fa)
{
	dfn[x] = ++dfs_clock;
	idf[dfn[x]] = x;
	siz[x] = 1;
	
	for (int i = head[x]; ~i; i = nxt[i])
	{
		edge & e = edges[i];
		if (e.to != fa)
		{
			dfs(e.to, x);
			siz[x] += siz[e.to];
		}
	}
}


struct Calc
{
	LL val[N];

	void upd(int x, int v)
	{
		for (int i = x; i <= dfs_clock; i += (i & -i))
			val[i] += v;
	}

	LL qry(int x)
	{
		LL Ans = 0;
		for (int i = x; i > 0; i -= (i & -i))
			Ans += val[i];
		return Ans;
	}
} Calc;


int m;


int opt[N];


char str[N];
int L[N], R[N];


int main()
{
	scanf("%d", &m);
	
	ac.init();
	
	for (int i = 1; i <= m; i++)
	{
		scanf("%d", &opt[i]);
		L[i] = R[i-1] + 1;
		scanf("%s", str + L[i]);
		R[i] = L[i] + strlen(str + L[i]) - 1;
		ac.insert(str + L[i], R[i] - L[i] + 1, i);
	}
	
	ac.getFail();
	
	init();
	
	for (int i = 1; i < ac.sz; i++)
		add_edge(ac.fail[i], i);
	
	dfs(0, -1);
	
	int mask = 0;
	
	for (int i = 1; i <= m; i++)
	{
		opt[i] ^= mask;
		if (opt[i] == 1)
		{
			int x = ed[i];
			int l = dfn[x], r = dfn[x] + siz[x] - 1;
			Calc.upd(l, 1);
			Calc.upd(r + 1, -1);
		}
		else if (opt[i] == 2)
		{
			int x = ed[i];
			int l = dfn[x], r = dfn[x] + siz[x] - 1;
			Calc.upd(l, -1);
			Calc.upd(r + 1, 1);
		}
		else
		{
			LL Ans = 0;
			int u = ac.root;
			for (int j = L[i]; j <= R[i]; j++)
			{
				u = ac.nxt[u][str[j] - 'a'];
				Ans += Calc.qry(dfn[u]);
			}
			mask ^= abs(Ans);
			printf("%lld\n", Ans);
		}
	}
	
	return 0;
}

下面是难写还慢的要命的二进制分组

#include <iostream>
#include <cstring>
#include <cstdio>
#include <cstdlib>
#include <algorithm>
 
#include <queue>
 
using namespace std;
 
 
typedef long long LL;
 
 
const int N = 1000010;
 
 
char str[N]; int L[N], R[N];
 
 
namespace mempool
{
	const int SIZE = 5000010;
 
	int nxt[SIZE][26], fail[SIZE];
	LL val[SIZE];
	int sz;
 
	int newNode() { return ++sz; }
	
	void deleteNode(int x)
	{
		memset(nxt[x], 0, sizeof(nxt[x]));
		fail[x] = 0;
		val[x] = 0;
	}
}
 
 
struct AC
{
	vector <pair <int, int>> vec;
	int size;
	int root;
 
	void remove(int u)
	{
		using namespace mempool;
		if (!u) return;
		for (int i = 0; i < 26; i++)
			if (nxt[u][i]) remove(nxt[u][i]);
		mempool :: deleteNode(u);
	}
 
	inline void clear()
	{
		vec.clear();
		remove(root);
		size = 0;
	}
 
	void Insert(int l, int r)
	{
		using namespace mempool;
		
		if (root == 0) { puts("error"); exit(0); }
		
		int u = root;
		for (int i = l; i <= r; i++)
		{
			int c = str[i] - 'a';
			if (!nxt[u][c]) nxt[u][c] = newNode();
			u = nxt[u][c];
		}
		val[u]++;
	}
 
	void insert(int l, int r) { size++; vec.push_back(make_pair(l, r)); }
 
	void build()
	{
		using namespace mempool;
		
		remove(root);
		root = newNode();
		for (auto v : vec) Insert(v.first, v.second);
		
		queue <int> q;
		
		fail[root] = root;
		for (int i = 0; i < 26; i++)
			if (nxt[root][i])
			{
				fail[nxt[root][i]] = root;
				q.push(nxt[root][i]);
			}
		
		while (!q.empty())
		{
			int x = q.front(); q.pop();
			for (int i = 0; i < 26; i++)
				if (nxt[x][i])
				{
					int p = fail[x];
					while (p != root && !nxt[p][i]) p = fail[p];
					fail[nxt[x][i]] = nxt[p][i] ? nxt[p][i] : root;
					val[nxt[x][i]] += val[fail[nxt[x][i]]];
					q.push(nxt[x][i]);
				}
		}
	}
 
	LL query(int l, int r)
	{
		using namespace mempool;
		
		int u = root;
		LL Ans = 0;
		for (int i = l; i <= r; i++)
		{
			int c = str[i] - 'a';
			while (u != root && !nxt[u][c]) u = fail[u];
			u = nxt[u][c] ? nxt[u][c] : root;
			Ans += val[u];
		}
		return Ans;
	}
};
 
AC add[N], del[N]; int top1, top2;
 
void Add(int l, int r)
{
	add[++top1].clear();
	add[top1].insert(l, r);
	while (top1 > 1 && add[top1-1].size == add[top1].size)
	{
		for (auto v : add[top1].vec) add[top1-1].insert(v.first, v.second);
		add[top1--].clear();
	}
	add[top1].build();
}

void Del(int l, int r)
{
	del[++top2].clear();
	del[top2].insert(l, r);
	while (top2 > 1 && del[top2-1].size == del[top2].size)
	{
		for (auto v : del[top2].vec) del[top2-1].insert(v.first, v.second);
		del[top2--].clear();
	}
	del[top2].build();
}
 
LL Qry(int l, int r)
{
	LL Ans = 0;
	for (int i = 1; i <= top1; i++)
		Ans += add[i].query(l, r);
	for (int i = 1; i <= top2; i++)
		Ans -= del[i].query(l, r);
	return Ans;
}


int n;
 
 
int main()
{
	scanf("%d", &n);
	
	LL mask = 0;
	
	for (int i = 1; i <= n; i++)
	{
		int opt;
		scanf("%d", &opt);
		opt ^= mask;
		scanf("%s", str + R[i-1] + 1);
		L[i] = R[i-1] + 1;
		R[i] = L[i] + strlen(str + L[i]) - 1;
		
		if (opt == 1) Add(L[i], R[i]);
		else if (opt == 2) Del(L[i], R[i]);
		else
		{
			LL Ans = Qry(L[i], R[i]);
			mask ^= abs(Ans);
			printf("%lld\n", Ans);
		}
	}
	
	return 0;
}
posted @ 2019-08-07 08:56  EZ_WYC  阅读(129)  评论(0编辑  收藏  举报