P5284 [十二省联考2019]字符串问题 题解

SAM + 拓扑排序

主要说一下特殊的建图方式,自己脑补了一下,发现和题解不一样

首先,建出反串 SAM,然后把每个 \(a_i\) 挂到对应的 SAM 节点上,按照长度从小到大排序

然后把每个 SAM 节点拆成 2 个 \(s, t\)\(s\) 向当前节点上挂的第一个 \(a_i\) 连边,边权是 \(|a_i|\)(长度),\(a_i\) 向后面的节点连边,边权是 \(|a_{i + 1}| - |a_{i}|\)(长度差),最后一个节点 \(a_k\)\(t\) 连边权是 \(-|a_k|\) 的边,然后 \(t\) 向原来的这个节点的儿子连边权是 0 的边

如果 \(a_i\) 能控制 \(b_j\) 直接从 \(a_i\)\(b_j\) 连边,权值为 0

找出每个 \(b_i\) 对应的节点,二分一下,找到长度 \(\ge |b_i|\) 的第一个点 \(a_i\),向它连边权为 \(|a_i|\) 的边

最后拓扑排序找最长路就行了

#include <bits/stdc++.h>
using namespace std;
#define gc getchar
#define rg register
#define I inline
#define rep(i, a, b) for(int i = a; i <= b; ++i)
#define per(i, a, b) for(int i = a; i >= b; --i)
I int read(){
	rg char ch = gc();
	rg int x = 0, f = 0;
	while(!isdigit(ch)) f |= (ch == '-'), ch = gc();
	while(isdigit(ch)) x = (x << 1) + (x << 3) + (ch ^ 48), ch = gc();
	return f ? -x : x;
}
const int N = 4e5 + 5, M = N << 2;
int tr[N][26], len[N], lst, link[N], cnt, star[N >> 1], dfn[N], bk[N];
int fa[N][19], n, na, nb, pos[N];
int head[M], ver[M << 1], nxt[M << 1], tot, edge[M << 1], deg[M];
long long f[M];
int base;
vector<int> g[N << 1];
void add(int x, int y, int z){
	swap(x, y);
	ver[++tot] = y;
	nxt[tot] = head[x];
	head[x] = tot;
	edge[tot] = z;
	++deg[y];
	//cout << y << " " << x << " " << z << endl;
}
char s[N];
struct node{
	int l, r, id;
	inline bool operator < (const node &rhs) const {
		return r - l < rhs.r - rhs.l;
	}
	node(){}
	node(int l, int r, int id): l(l), r(r), id(id) {}
};
vector<node> h[N];
I void prepro(){
	rep(i, 0, base) head[i] = 0, deg[i] = 0, f[i] = 0;
	tot = 0; base = 0;
	rep(i, 0, cnt){
		vector<node>().swap(h[i]);
		memset(tr[i], 0, sizeof tr[i]);
		vector<int>().swap(g[i]);
	}
	cnt = lst = 1;
}
I void insert(int c, int id){
	int p = lst, cur = ++cnt; lst = cur;
	len[cur] = len[p] + 1; star[id] = cur;
	while(p && !tr[p][c]) tr[p][c] = cur, p = link[p];
	if(!p) link[cur] = 1;
	else{
		int q = tr[p][c];
		if(len[q] == len[p] + 1) link[cur] = q;
		else{
			int clone = ++cnt;
			len[clone] = len[p] + 1; link[clone] = link[q]; memcpy(tr[clone], tr[q], sizeof tr[q]);
			while(p && tr[p][c] == q) tr[p][c] = clone, p = link[p];
			link[cur] = link[q] = clone;
		}
	}
}
void dfs(int x){
	rep(i, 1, 18) fa[x][i] = fa[fa[x][i - 1]][i - 1];
	for(int y : g[x]) fa[y][0] = x, dfs(y);
}
I void build(){
	per(i, n, 1) insert(s[i] - 'a', i);
	//rep(i, 2, cnt) cout << link[i] << i << endl; 
	rep(i, 2, cnt) g[link[i]].push_back(i);
	dfs(1);
}
I int getfa(int l, int r){
	int L = r - l + 1;
	int x = star[l];
	for(int i = 18; ~i; --i){
		if(L <= len[fa[x][i]]) x = fa[x][i];	
	}
	return x;
}
I long long topu(){
	queue<int> q;
	rep(i, 1, base) if(!deg[i]) q.push(i);
	int num = 0;
	while(!q.empty()){
		int x = q.front(); q.pop();
		++num;
		for(int i = head[x]; i; i = nxt[i]){
			int y = ver[i];
			f[y] = max(f[y], f[x] + edge[i]);
			--deg[y];
			if(!deg[y]) q.push(y);
		}
	}
	if(num != base) return -1;
	return f[1];
}
I void Main(){
	prepro();
	scanf("%s", s + 1); n = strlen(s + 1);
	build();
	na = read();
	int l, r;
	rep(i, 1, na){
		l = read(), r = read();
		int t = getfa(l, r);
		h[t].push_back(node(l, r, i));
	}
	base = cnt;
	rep(i, 1, cnt){
		bk[i] = base + na + i;
		for(int y : g[i]) add(bk[i], y, 0);
		sort(h[i].begin(), h[i].end());
		if(h[i].begin() != h[i].end()){
			add(i, base + h[i].begin() -> id, h[i].begin() -> r - h[i].begin() -> l + 1);
			add(base + (h[i].end() - 1) -> id, bk[i], (h[i].end() - 1) -> l - (h[i].end() - 1) -> r - 1);
			for(vector<node>::iterator it = h[i].begin(); it != h[i].end(); ++it)
				if((it + 1) != h[i].end()) add(it -> id + base, (it + 1) -> id + base, (it + 1) -> r - (it + 1) -> l - it -> r + it -> l);
		}else add(i, bk[i], 0);
	}
	nb = read(); 
	rep(i, 1, nb){
		l = read(), r = read();
		pos[i] = getfa(l, r); vector<node>::iterator it = lower_bound(h[pos[i]].begin(), h[pos[i]].end(), (node){l, r, i});
		if(it == h[pos[i]].end())  add(base + na + cnt + i, bk[pos[i]], 0);
		else add(base + na + cnt + i, base + it -> id, it -> r - it -> l + 1);
	}
	base += na + nb + cnt;
	int m = read();
	rep(i, 1, m){
		l = read(), r = read();
		add(cnt + l, cnt + na + cnt + r, 0);
	}
	printf("%lld\n", topu());
}
signed main(){
	int T = read();
	while(T--) Main();
	return 0;
}
posted @ 2020-06-06 11:06  __int256  阅读(98)  评论(0编辑  收藏  举报