BZOJ 4477: [JSOI2015] 字符串树 (String Tree) — 可持久化字典树 (persistent trie)

pass

 1 #include <cstdio>
 2 #include <cstring>
 3 #include <cmath>
 4 #include <algorithm>
 5 using namespace std;
 6 const int MAXN = 1100000;
 7 int cnt,n,tot,q;
 8 int ch[MAXN][26],head[MAXN],rt[MAXN],dep[MAXN],siz[MAXN],to[MAXN],nxt[MAXN];
 9 int p[MAXN][30];
10 char str[MAXN][15];
11 void add(int x,int y,char *s)
12 {
13     nxt[++cnt] = head[x];
14     to[cnt] = y;
15     head[x] = cnt;
16     memcpy(str[cnt],s,sizeof(str[cnt]));
17 }
18 void dfs(int x,int frm)
19 {
20     for (int i = head[x];i;i = nxt[i])
21     {
22         if (to[i] == frm)
23             continue;
24         p[to[i]][0] = x;
25         dep[to[i]] = dep[x] + 1;
26         int u = rt[x],v = rt[to[i]] = ++tot,lenn = strlen(str[i] + 1);
27         for (int j = 1;j <= lenn;j++)
28         {
29             for (int o = 0;o <= 25;o++)
30             {
31                 siz[v] += siz[ch[u][o]];
32                 ch[v][o] = ch[u][o];
33             }
34             siz[v]++;
35             u = ch[u][str[i][j] - 'a'];
36             ch[v][str[i][j] - 'a'] = ++tot;
37             v = tot;
38         }
39         siz[v]++;
40         dfs(to[i],x);
41     }
42 }
43 int lca(int x,int y)
44 {
45     int t = log2(n);
46     if (dep[x] < dep[y])
47         swap(x,y);
48     for (int i = t;i >= 0;i--)
49         if (dep[x] - (1 << i) >= dep[y])
50             x = p[x][i];
51     if (x == y)
52         return x;
53     for (int i = t;i >= 0;i--)
54         if (p[x][i] != p[y][i])
55         {
56             x = p[x][i];
57             y = p[y][i];
58         }
59     return p[x][0];
60 }
61 int lca_init()
62 {
63     int t = log2(n);
64     for (int i = 1;i <= t;i++)
65         for (int j = 1;j <= n;j++)
66             p[j][i] = p[p[j][i - 1]][i - 1];
67 }
68 int solve(int x,char *s)
69 {
70     int u = rt[x],lenn = strlen(s + 1);
71     for (int i = 1;i <= lenn;i++)
72         u = ch[u][s[i] - 'a'];
73     return siz[u];
74 }
75 int main()
76 {
77     scanf("%d",&n);
78     int tu,tv;
79     char ts[15];
80     for (int i = 1;i <= n - 1;i++)
81     {
82         scanf("%d%d%s",&tu,&tv,ts + 1);
83         add(tu,tv,ts);
84         add(tv,tu,ts);
85     }
86     dep[1] = 1;
87     dfs(1,0);
88     lca_init();
89     scanf("%d",&q);
90     for (int i = 1;i <= q;i++)
91     {
92         scanf("%d%d%s",&tu,&tv,ts + 1);
93         int t = lca(tu,tv);
94         printf("%d\n",solve(tu,ts) + solve(tv,ts) - 2 * solve(t,ts));
95     }
96 }

 

posted @ 2020-02-23 14:36  IAT14  阅读(199)  评论(0)  编辑  收藏  举报