BZOJ4598 [Sdoi2016]模式字符串 【点分治 + hash】
题目
给出n个结点的树结构T,其中每一个结点上有一个字符,这里我们所说的字符只考虑大写字母A到Z,再给出长度为m
的模式串s,其中每一位仍然是A到z的大写字母。Alice希望知道,有多少对结点<u,v>满足T上从u到V的最短路径
形成的字符串可以由模式串S重复若干次得到?这里结点对<u,v>是有序的,也就是说<u,v>和<v,u>需要被区分.
所谓模式串的重复,是将若干个模式串S依次相接(不能重叠).例如当S=PLUS的时候,重复两次会得到PLUSPLUS,
重复三次会得到PLUSPLUSPLUS,同时要注恿,重复必须是整数次的。例如当S=XYXY时,因为必须重复整数次,所以X
YXYXY不能看作是S重复若干次得到的。
输入格式
每一个数据有多组测试,
第一行输入一个整数C,表示总的测试个数。
对于每一组测试来说:
第一行输入两个整数,分别表示树T的结点个数n与模式长度m。结点被依次编号为1到n,
之后一行,依次给出了n个大写字母(以一个长度为n的字符串的形式给出),依次对应树上每一个结点上的字符(
第i个字符对应了第i个结点).
之后n-1行,每行有两个整数u和v表示树上的一条无向边,之后一行给定一个长度为m的由大写字母组成的字符串,
为模式串S。
1<=C<=10,3<=N<=10000003<=M<=1000000
输出格式
给出C行,对应C组测试。每一行输出一个整数,表示有多少对节点<u,v>满足从u到v的路径形成的字符串恰好是模
式串的若干次重复.
输入样例
1
11 4
IODSSDSOIOI
1 2
2 3
3 4
1 5
5 6
6 7
3 8
8 9
6 10
10 11
SDOI
输出样例
5
提示
数据文件太过巨大,仅提供前三组数据测试.
题解
BZOJ数据较小,卡过了
但洛谷似乎T得不行
我们预处理出字符串前i个和后i个的hash值【这里\(i<=n\)处理的字符串由原字符串复制多次形成】
然后点分
对于每棵子树,进行遍历,记录当前到根的hash值,如果匹配上了前缀或者后缀,查找f[i]或者g[i]表示长度对m取模后为i的到根路径为原字符串前缀或后缀的路径数,更新答案
常熟略大,,弱弱卡过
#include<iostream>
#include<cstdio>
#include<cmath>
#include<cstring>
#include<algorithm>
#define LL long long int
#define Redge(u) for (int k = h[u],to; k; k = ed[k].nxt)
#define REP(i,n) for (int i = 1; i <= (n); i++)
#define ULL unsigned long long int
#define cls(s) memset(s,0,sizeof(s))
#define BUG(s,n) for (int i = 1; i <= (n); i++) cout<<s[i]<<' '; puts("");
using namespace std;
const int maxn = 1000005,maxm = 2000005,INF = 1000000000;
inline int read(){
int out = 0,flag = 1; char c = getchar();
while (c < 48 || c > 57){if (c == '-') flag = -1; c = getchar();}
while (c >= 48 && c <= 57){out = (out << 3) + (out << 1) + c - 48; c = getchar();}
return out * flag;
}
ULL Hl[maxn],Hr[maxn];
char s[maxn],val[maxn];
int n,m;
int h[maxn],ne = 2;
int F[maxn],Siz[maxn],fa[maxn],vis[maxn],sum,rt;
LL ans;
struct EDGE{int to,nxt;}ed[maxm];
void build(int u,int v){
ed[ne] = (EDGE){v,h[u]}; h[u] = ne++;
ed[ne] = (EDGE){u,h[v]}; h[v] = ne++;
}
void init(){
for (int i = 1; i <= n; i++) vis[i] = h[i] = fa[i] = 0;
ne = 2; ans = 0;
}
void getrt(int u){
Siz[u] = 1; F[u] = 0;
Redge(u) if (!vis[to = ed[k].to] && to != fa[u]){
fa[to] = u; getrt(to);
Siz[u] += Siz[to];
F[u] = max(F[u],Siz[to]);
}
F[u] = max(F[u],sum - Siz[u]);
if (F[u] < F[rt]) rt = u;
}
int pre[maxn],post[maxn],dep[maxn];
ULL V[maxn],P[maxn];
void DFS(int u){
Siz[u] = 1;
Redge(u) if (!vis[to = ed[k].to] && to != fa[u]){
fa[to] = u; DFS(to);
Siz[u] += Siz[to];
}
}
void dfs1(int u){
V[u] = V[fa[u]] * 107 + val[u];
int d = (dep[u] - 1) % m + 1;
if (V[u] == Hl[dep[u]] && s[d % m + 1] == val[rt]){
//printf("find at %d\n",u);
ans += post[((m - d - 1) % m + m) % m];
}
if (V[u] == Hr[dep[u]] && s[m - d % m] == val[rt]){
//printf("rfind at %d\n",u);
ans += pre[((m - d - 1) % m + m) % m];
}
Redge(u) if (!vis[to = ed[k].to] && to != fa[u]){
fa[to] = u; dep[to] = dep[u] + 1;
dfs1(to);
}
}
void dfs2(int u){
int d = dep[u] % m;
if (V[u] == Hr[dep[u]]) post[d]++;
if (V[u] == Hl[dep[u]]) pre[d]++;
Redge(u) if (!vis[to = ed[k].to] && to != fa[u]){
fa[to] = u; dep[to] = dep[u] + 1;
dfs2(to);
}
}
void solve(int u){
vis[u] = true;
fa[u] = 0; DFS(u);
if (Siz[u] < m) return;
for (int i = min(Siz[u],m); i >= 0; i--) pre[i] = post[i] = 0;
pre[0] = post[0] = 1;
V[u] = 0;
Redge(u) if (!vis[to = ed[k].to]){
dep[to] = 1; fa[to] = u; dfs1(to);
dep[to] = 1; fa[to] = u; dfs2(to);
}
Redge(u) if (!vis[to = ed[k].to]){
sum = Siz[to]; F[rt = 0] = INF;
getrt(to); solve(rt);
}
}
int main(){
P[0] = 1;
for (int i = 1; i <= 1000000; i++) P[i] = P[i - 1] * 107;
int T = read();
while (T--){
init();
n = read(); m = read();
scanf("%s",s + 1);
for (int i = 1; i <= n; i++) val[i] = s[i];
for (int i = 1; i < n; i++) build(read(),read());
scanf("%s",s + 1);
for (int i = 1; i <= n; i++)
Hl[i] = Hl[i - 1] + P[i - 1] * s[(i - 1) % m + 1];
for (int i = 1; i <= n; i++)
Hr[i] = Hr[i - 1] + P[i - 1] * s[m - (i - 1) % m];
F[rt = 0] = INF; sum = n;
getrt(1); solve(rt);
printf("%lld\n",ans);
}
return 0;
}