BZOJ 4477: [Jsoi2015]字符串树 可持久化字典树

我们考虑如果我们能快速的得出一条路线上的字符串组成的字典树,那么问题就迎刃而解了。开太多的字典树,开不下,我们可持久化以下就好了。 持久化出根到每个结点的字典树,然后ans(a) + ans(b) - 2 * ans(lca)即可。

可持久化字典树应该如何操作呢,我们考虑,对于一个字典树x,我们向其加入一个字符串s1,得到一棵新的字典树y,两个字典树绝大多数部分均相同,只有遍历s1的这条路径上,会有所差异。所以我们考虑,对于字典树y,我们直接把和x相同的部分,直接指向x,不另行新建。这样子,每次版本更新,我们只会新建len(加入字符串长度)个的节点

 1 #include <cstdio>
 2 #include <cstring>
 3 #include <cmath>
 4 #include <algorithm>
 5 #include <stack>
 6 using namespace std;
 7 const int MAXN = 1000005,MAXM = 1000005;
 8 int cnt,n,tot,q;
 9 int head[MAXN],rt[MAXN],dep[MAXN],siz[MAXN],to[MAXM],nxt[MAXM];
10 int p[MAXN][30];
11 int ch[MAXN][26];
12 char str[MAXM][15];
13 void add(int x,int y,char *s)
14 {
15     nxt[++cnt] = head[x];
16     to[cnt] = y;
17     head[x] = cnt;
18     memcpy(str[cnt],s,sizeof(str[cnt]));
19 }
20 void dfs(int x,int frm)
21 {
22     for (int i = head[x];i;i = nxt[i])
23     {
24         if (to[i] == frm) continue;
25         p[to[i]][0] = x;
26         dep[to[i]] = dep[x] + 1;
27         int u = rt[x],v = rt[to[i]] = ++tot,lenn = strlen(str[i] + 1);
28         for (int j = 1;j <= lenn;j++)
29         {
30             for (int o = 0;o < 26;o++)
31             {
32                 siz[v] += siz[ch[u][o]];
33                 ch[v][o] = ch[u][o];
34             }
35             siz[v]++;
36             u = ch[u][str[i][j] - 'a'];
37             ch[v][str[i][j] - 'a'] = ++tot;
38             v = tot;
39         }
40         siz[v]++;
41         dfs(to[i],x);
42     }
43 }
44 int lca(int x,int y)
45 {
46     int jqe = log2(n);
47     if (dep[x] < dep[y]) swap(x,y);
48     for (int i = jqe;i >= 0;i--)
49         if (dep[x] - (1 << i) >= dep[y])
50             x = p[x][i];
51     if (x == y) return x;
52     for (int i = jqe;i >= 0;i--)
53         if (p[x][i] != p[y][i]) x = p[x][i],y = p[y][i];
54     return p[x][0];
55 }
56 void lca_init()
57 {
58     int jqe = log2(n);
59     for (int i = 1;i <= jqe;i++)
60         for (int j = 1;j <= n;j++)
61             p[j][i] = p[p[j][i - 1]][i - 1];
62 }
63 int solve(int x,char *s)
64 {
65     int u = rt[x],lenn = strlen(s + 1);
66     for (int i = 1;i <= lenn;i++)
67         u = ch[u][s[i] - 'a']; 
68     return siz[u];
69 }
70 int main()
71 {
72     scanf("%d",&n);
73     for (int i = 1;i <= n - 1;i++)
74     {
75         int u,v;
76         char s[15];
77         scanf("%d%d%s",&u,&v,s + 1);
78         add(u,v,s);
79         add(v,u,s);
80     }
81     dep[1] = 1;
82     dfs(1,0);
83     lca_init();
84     scanf("%d",&q);
85     for (int i = 1;i <= q;i++)
86     {
87         int u,v;    
88         char s[15];
89         scanf("%d%d%s",&u,&v,s + 1);
90         int t = lca(u,v);
91         printf("%d\n",solve(u,s) + solve(v,s) - 2 * solve(t,s));
92     }
93     return 0;
94 }
95 ?

 

posted @ 2019-03-22 00:00  IAT14  阅读(168)  评论(0编辑  收藏  举报