「字符串算法」第4章 字典树课堂过关

「字符串算法」第4章 字典树课堂过关

YbtOJ又双叒叕炸掉了

前几分钟还好好的

由于YbtOJ已经炸裂,暂时无法测评,现采用与网络标程对拍的方式验证程序的正确性.

数据生成程序代码放在文章末尾

关于字典树

以前写的:

模板题

由于找不到最直接的模板,就拿了一个最裸的题权当模板

传送门

大体思路

应用&结构:用于实现字符串快速检索的多叉树结构

总思路其他博客已经很详尽,这里不再赘述(其实懒得画图)

定义&初始化

定义trie[SIZE][30](假设只有小写字母),trie[i][j]表示当前在i结点,编号为j的子结点所处的位置,(我们称字符'a'的编号为0,'b'为1,以此类推),即trie是一个用于模拟指针的数组,定义一个特殊的空结点(一般为0),所有的指针均指向空

定义end[SIZE],end[i]表示下标为i的结点是否为某个字符串的终点

插入

void insert(char *s , int siz) {
	static int top = 1;//trie的第一维最大下标,类似于链式前向星
	int p = 1;
	for(int i = 1 ; i <= siz ; i++) {
		int c = s[i] - 'a';
		if(trie[p][c] == 0)//如果指向空,则新建结点
			trie[p][c] = ++top;
		p = trie[p][c];
	}
	vis[p] = false;
	end[p] = true;
}

查找(以模板为例)

int search(char *s , int siz) {
	int p = 1;
	for(int i = 1 ; i <= siz ; i++) {
		p = trie[p][s[i] - 'a'];
		if(p == 0) return 0;//当前字符串在trie树中不存在
	}
	if(end[p] == false) return 0;//WRONG
	if(vis[p] == true)	return 2;//REPEAT
	vis[p] = true;//标记当前字符串已经访问过
	return 1;//OK
}

模板题完整代码

#include <iostream>
#include <cstdio>
#define nn 500010
using namespace std;
int sread(char *s) {
	int siz = 1;
	do
		s[siz] = getchar();
	while(s[siz] < 'a' || s[siz] > 'z');
	while(s[siz] >= 'a' && s[siz] <= 'z')
		s[++siz] = getchar();
	--siz;
	return siz;
}

bool vis[nn] , end[nn];
int trie[nn][30];
void insert(char *s , int siz) {
	static int top = 1;
	int p = 1;
	for(int i = 1 ; i <= siz ; i++) {
		int c = s[i] - 'a';
		if(trie[p][c] == 0)
			trie[p][c] = ++top;
		p = trie[p][c];
	}
	vis[p] = false;
	end[p] = true;
}
int search(char *s , int siz) {
	int p = 1;
	for(int i = 1 ; i <= siz ; i++) {
		p = trie[p][s[i] - 'a'];
		if(p == 0) return 0;
	}
	if(end[p] == false) return 0;//WRONG
	if(vis[p] == true)	return 2;//REPEAT
	vis[p] = true;
	return 1;//OK
}

int n , m;
char s[nn];
int main() {
	cin >> n;
	for(int i = 1 ; i <= n ; i++) {
		int siz = sread(s);
		insert(s , siz);
	}
	cin >> m;
	for(int i = 1 ; i <= m ; i++) {
		int siz = sread(s);
		int res = search(s , siz);
		if(res == 0)
			puts("WRONG");
		else if(res == 1)
			puts("OK");
		else
			puts("REPEAT");
			
	}
	return 0;
}

A. 【例题1】前缀统计

题目

代码

#include <iostream>
#include <cstdio>
#include <cstring>
#define nn 1000010
#define TrieRoot 1
using namespace std;
int trie[nn][30];
int end[nn];
void insert(char *s) {
	static int cnt = 2;
	int p = TrieRoot , len = strlen(s);
	for(int i = 0 ; i < len ; i++) {
		if(trie[p][s[i] - 'a'] == 0)
			trie[p][s[i] - 'a'] = cnt++;
		p = trie[p][s[i] - 'a'];
	}
	++end[p];
}
int solve(char *s) {
	int len = strlen(s);
	int p = TrieRoot;
	int ans = 0;
	for(int i = 0 ; i < len ; i++) {
		ans += end[p];
		p = trie[p][s[i] - 'a'];
	}
	ans += end[p];
	return ans;
}

int n , m;
char s[nn];
int main() {
	scanf("%d %d" , &n , &m);
	for(int i = 1 ; i <= n ; i++) {
		scanf("%s" , s);
		insert(s);
	}
	for(int i = 1 ; i <= m ; i++) {
		scanf("%s" , s);
		printf("%d\n" , solve(s));
	}
	return 0;
}

B. 【例题2】最大异或对

题目

代码

#include <iostream>
#include <cstdio>
#define int unsigned
#define nn 100010 * 30
#define TrieRoot 1
using namespace std;
int read() {
	int re = 0;
	char c = getchar();
	while(c < '0' || c > '9')
		c = getchar();
	while(c >= '0' && c <= '9')
		re = (re << 1) + (re << 3) + c - '0',
		c = getchar();
	return re;
}

int trie[nn][2];
int rev(int x) {
	int res = 0;
	for(int i = 1 ; i <= 31 ; i++) {
		res = (res << 1) + (x & 1);
		x >>= 1;
	}
	return res;
}
void insert(int x) {
	static int cnt = TrieRoot + 1;
	int p = TrieRoot;
	for(int i = 1 ; i <= 31 ; i++) {
		int tmp = (x & 1);
		if(trie[p][tmp] == 0)
			trie[p][tmp] = cnt++;
		p = trie[p][tmp];
		x >>= 1;
	}
}

int n;
int a[100010];
signed main() {
	n = read();
	for(int i = 1 ; i <= n ; i++) {
		a[i] = rev(read());
		insert(a[i]);
	}
	
	int ans = 0;
	for(int i = 1 ; i <= n ; i++) {
		int res = 0;
		int tmp = a[i];
		int p = TrieRoot;
		for(int j = 1 ; j <= 31 ; j++) {
			if(trie[p][(tmp & 1) ^ 1] != 0)
				p = trie[p][(tmp & 1) ^ 1] , res = (res << 1) + 1;
			else
				p = trie[p][tmp & 1] , res = (res << 1);
			tmp >>= 1;
		}
		if(res > ans)
			ans = res;
	}
	cout << ans;
	return 0;
}

C. 【例题3】最长异或路径

题目

思路 & 代码

以前写的一篇博客

题目

传送门

思路

别在意这是一道紫题,其实还是能做的

首先要知道:异或运算满足交换律,结合律,\(a\ xor\ a = 0\),一个点A到另一个点B的异或路径长度等于(A到C的异或路径长度 xor B到C的异或路径长度),其中C为任一点

为什么?

假设C是树的根,后者只是比前者多跑了2遍C到lca(A,B)的路径,也就是这条路径上的边会被异或两边,又因为同一个数异或的结果为0,所以这多跑的2遍对结果无影响

所以,我们随便选一个点作为根(这里就用1号点),求出所有点到1号点的异或路径长度,存在dis[]中,这样,我们就能\(O(1)\)求出两个点之间的异或路径长度

到此,原问题转化为:

找一对i,j,使dis[i] ^ dis[j]最大(" ^ "表示异或)

01trie是解决这种异或问题的利器,但是,怎么找呢?

先说01trie:按照dis[i]从二进制下高位到低位,从根到叶子的顺序建树(懒得画图了,自己看代码理解下)

然后?

我在没看题解时的思路:

  1. 从trie的根结点开始向下找,直到遇到分支(因为此时高位是1,高位大的一定大)
  2. 找到分支后,用BFS+贪心查找最优解(尽量让两个数异或后高位为1)
  3. 但是,最坏情况下,时间复杂度是可以去到\(2^{30}\)

因此,看了一波题解

正解:

  1. \(O(n)\)枚举每一个\(dis_i\)
  2. \(O(30)\)在trie中贪心查找另一个\(trie_j\),使trie[i] ^ trie[j]最大(这里的贪心其实就是让异或出来的结果高位更大,这也就决定了如何建trie树)

反思

其实我的思路离正解已经很近了,可以说只差了最后一步,但是失之毫厘差之千里,复杂的几乎就是\(O(n^2)\)的纯暴力和正解的区别,应该从多方面思考问题的解,优化程序中复杂度最高的地方

代码

#include <iostream>
#include <cstdio>
#define nn 100010
using namespace std;
int read() {
	int re = 0 , sig = 1;
	char c = getchar();
	while(c < '0' || c > '9') {
		if(c == '-')sig = -1;
		c = getchar();
	}
	while(c >= '0' && c <= '9')
		re = (re << 1) + (re << 3) + c - '0',
		c = getchar();
	return re * sig;
}
struct ednode{//链式前向星
	int nxt , w , to;
}ed[nn * 2];
int head[nn];
inline void addedge(int u , int v , int w) {
	static int top = 1;
	ed[top].to = v , ed[top].w = w , ed[top].nxt = head[u] , head[u] = top;
	++top;
}

int dis[nn];
int n;
int trie[nn * 30][3];


void dfs(int x , int pre) {//处理出dis数组
	for(int i = head[x] ; i ; i = ed[i].nxt) {
		if(ed[i].to == pre)continue;
		dis[ed[i].to] = dis[x] ^ ed[i].w;
		dfs(ed[i].to , x);
	}
}
void build() {//建trie树
	int top = 1;
	for(int i = 1 ; i <= n ; i++) {
		int tmp = dis[i];
		int p = 1;
		for(int j = 30 ; j >= 0 ; j--) {
			int x = (tmp >> j) & 1;
			if(trie[p][x] == 0)
				trie[p][x] = ++top;
			p = trie[p][x];
		}
	}
}
int GetAns() {
	int ans = 0;
	for(int i = 1 ; i <= n ; i++) {//枚举每一个dis
		int tmp = dis[i];
		int res = 0;
		int p = 1;
		for(int j = 30 ; j >= 0 ; j--) {//找到最优的另一个dis,满足它和dis[i]的异或值最大
			if(trie[p][!((tmp >> j) & 1)] != 0) {
				res += (1 << j);
				p = trie[p][!((tmp >> j) & 1)];
			}
			else
				p = trie[p][(tmp >> j) & 1];
		}
		if(res > ans)
			ans = res;
	}
	return ans;
}

int main() {
	n = read();
	for(int i = 1 ; i < n ; i++) {
		int u , v , w;
		u = read();	v = read();	w = read();
		addedge(u , v , w);
		addedge(v , u , w);
	}
	dfs(1 , 0);
	build();
	cout << GetAns();
	return 0;
}
/*洛谷样例2
10
1 2 12188248
2 3 2060207469
1 4 960096258
1 5 681126748
3 6 719580677
6 7 2084644229
4 8 730246277
1 9 668729523
9 10 1055107866

2084644229

*/

D. 【例题4】阅读理解

题目

传送门(洛谷)

思路

很简单的一道题(别看是蓝的)

对于字典树的每一个节点,捆绑一个\(head\)指针,用类似链式前向星的方式存储该单词所在的文章,如代码:

int trie[nn][30];
int head[nn] , nxt[nn] , dat[nn];
inline void insert(char *s , int article) {
	static int cnt = TrieRoot + 1;
	int len = strlen(s);
	int p = TrieRoot;
	for(int i = 0 ; i < len ; i++) {
		if(trie[p][s[i] - 'a'] == 0)
			trie[p][s[i] - 'a'] = cnt++;
		p = trie[p][s[i] - 'a'];
	}
	//单词插入完毕
	static int cnt2 = 1;
	for(int i = head[p] ; i ; i = nxt[i])//检查是否重复,其实不用循环好像也可以(个人没有验证)
		if(dat[i] == article)
			return;
	dat[cnt2] = article , nxt[cnt2] = head[p] , head[p] = cnt2;//将当前article插入到链中
	++cnt2;
}

由于我们遍历文章的顺序是从1到\(n\),所以链中的文章一定是倒序的,用递归输出即可(详见完整代码)

代码

#include <iostream>
#include <cstdio>
#include <cstring>
#define nn 1000010
#define TrieRoot 1
using namespace std;
int trie[nn][30];
int head[nn] , nxt[nn] , dat[nn];


inline void insert(char *s , int article) {
	static int cnt = TrieRoot + 1;
	int len = strlen(s);
	int p = TrieRoot;
	for(int i = 0 ; i < len ; i++) {
		if(trie[p][s[i] - 'a'] == 0)
			trie[p][s[i] - 'a'] = cnt++;
		p = trie[p][s[i] - 'a'];
	}
	
	static int cnt2 = 1;
	for(int i = head[p] ; i ; i = nxt[i])
		if(dat[i] == article)
			return;
	dat[cnt2] = article , nxt[cnt2] = head[p] , head[p] = cnt2;
	++cnt2;
}
inline void print(int p) {
	if(p == 0)	return;
	print(nxt[p]);
	printf("%d " , dat[p]);
}
inline void solve(char *s) {
	int p = TrieRoot;
	int len = strlen(s);
	for(int i = 0 ; i < len ; i++)
		p = trie[p][s[i] - 'a'];
	print(head[p]);
}

int n , m;
char s[110];
int main() {
	scanf("%d" , &n);
	for(int i = 1 ; i <= n ; i++) {
		int L;
		scanf("%d" , &L);
		for(int j = 1 ; j <= L ; j++) {
			scanf("%s" , s);
			insert(s , i);
		}
	}
	scanf("%d" , &m);
	for(int i = 1 ; i <= m ; i++) {
		scanf("%s" , s);
		solve(s);
		putchar('\n');
	}
	return 0;
}

随机数据生成

A. 【例题1】前缀统计

#include <bits/stdc++.h>
using namespace std;
int random(int r , int l = 1) {
	return (long long)rand() * rand() * rand() % (r - l + 1) + l;
}
int main() {
	unsigned seed;
	cin >> seed;
	seed *= time(0);
	srand(seed);
	
	int n = random(1000) , m = random(1000);
	printf("%d %d\n" , n , m);
	for(int i = 1 ; i <= n ; i++) {
		int len = random(10);
		while(len--)
			putchar(random('z' , 'a'));
		putchar('\n');
	}
	for(int i = 1 ; i <= m ; i++) {
		int len = random(1e6 / n);
		while(len--)
			putchar(random('z' , 'a'));
		putchar('\n');
	}
	return 0;
}

B. 【例题2】最大异或对

#include <bits/stdc++.h>
using namespace std;
int random(int r , int l = 1) {
	return (long long)rand() * rand() * rand() % (r - l + 1) + l;
}
int main() {
	unsigned seed;
	cin >> seed;
	seed *= time(0);
	srand(seed);
	
	int n = random(1e5);
	cout << n << '\n';
	for(int i = 1 ; i <= n ; i++) {
		printf("%d " , random((1u << 31) - 1) );
	}
	return 0;
}

D. 【例题4】阅读理解

#include <bits/stdc++.h>
using namespace std;
int random(int r , int l = 1) {
	return (long long)rand() * rand() * rand() % (r - l + 1) + l;
}
char s[100010][30];
int main() {
	unsigned seed;
	cin >> seed;
	seed *= time(0);
	srand(seed);
	
	int wordnum = 1e5;
	for(int i = 1 ; i <= wordnum ; i++) {
		int len = random(20);
		for(int j = 0 ; j < len ; j++)
			s[i][j] = random('z' , 'a');
	}
	
	int n = random(1e3) , m = random( min(wordnum , (int)1e4));
	printf("%d\n" , n);
	for(int i = 1 ; i <= n ; i++) {
		int len = random(100);
		printf("%d " , len);
		for(int j =1 ; j <= len ; j++) {
			printf("%s " , s[random(wordnum)]);
		}
	}
	printf("%d\n" , m);
	for(int i = 1 ; i <= m ; i++) {
		puts(s[i]);
	}
	return 0;
}
posted @ 2021-04-03 10:54  追梦人1024  阅读(70)  评论(0编辑  收藏  举报