题目的意思是叶子不超过20个……听说当初zjoi不少人被坑

分别对每个叶子以它为根dfs出20个dfs树,这样整个树的任何一个子串,都是某个dfs树上一个点到它的一个子孙的路径

每个dfs树,根到叶子相当于一个串,这样相当于统计20*19个串的不同字串数目

这样我们很容易想到建立多串SAM(广义后缀树)统计不同的子串数即可

  1 type node=record
  2        po,next:longint;
  3      end;
  4 
  5 var go:array[0..4000010,0..10] of longint;
  6     fa,mx:array[0..4000010] of longint;
  7     c,d,p:array[0..100010] of longint;
  8     e:array[0..200010] of node;
  9     v:array[0..100010] of boolean;
 10     t,i,len,x,y,n,m:longint;
 11     ans:int64;
 12 
 13 procedure add(x,y:longint);
 14   begin
 15     inc(len);
 16     e[len].po:=y;
 17     e[len].next:=p[x];
 18     p[x]:=len;
 19     inc(d[y]);
 20   end;
 21 
 22 function work(c,last:longint):longint;
 23   var np,q,p,nq:longint;
 24   begin
 25     p:=last;
 26     if go[p,c]=0 then
 27     begin
 28       inc(t); work:=t; np:=t;
 29       mx[np]:=mx[p]+1;
 30       while (p>0) and (go[p,c]=0) do
 31       begin
 32         go[p,c]:=np;
 33         p:=fa[p];
 34       end;
 35       if p=0 then fa[np]:=1
 36       else begin
 37         q:=go[p,c];
 38         if mx[q]=mx[p]+1 then fa[np]:=q
 39         else begin
 40           inc(t); nq:=t;
 41           mx[nq]:=mx[p]+1;
 42           go[nq]:=go[q];
 43           fa[nq]:=fa[q];
 44           fa[q]:=nq; fa[np]:=nq;
 45           while go[p,c]=q do
 46           begin
 47             go[p,c]:=nq;
 48             p:=fa[p];
 49           end;
 50         end;
 51       end;
 52     end
 53     else begin
 54       q:=go[p,c];
 55       if mx[q]=mx[p]+1 then exit(q)
 56       else begin
 57         inc(t); nq:=t;
 58         mx[nq]:=mx[p]+1;
 59         go[nq]:=go[q];
 60         fa[nq]:=fa[q];
 61         fa[q]:=nq;
 62         while go[p,c]=q do
 63         begin
 64           go[p,c]:=nq;
 65           p:=fa[p];
 66         end;
 67       end;
 68       exit(go[last,c]);
 69     end;
 70   end;
 71 
 72 procedure dfs(x,last:longint);
 73   var i,y:longint;
 74   begin
 75     last:=work(c[x],last);
 76     v[x]:=true;
 77     i:=p[x];
 78     while i<>0 do
 79     begin
 80       y:=e[i].po;
 81       if not v[y] then dfs(y,last);
 82       i:=e[i].next;
 83     end;
 84   end;
 85 
 86 begin
 87   readln(n,m);
 88   for i:=1 to n do
 89     read(c[i]);
 90   for i:=1 to n-1 do
 91   begin
 92     readln(x,y);
 93     add(x,y);
 94     add(y,x);
 95   end;
 96   t:=1;
 97   for i:=1 to n do
 98     if d[i]=1 then
 99     begin
100       fillchar(v,sizeof(v),false);
101       dfs(i,1);
102     end;
103   for i:=2 to t do
104     ans:=ans+mx[i]-mx[fa[i]];
105   writeln(ans);
106 end.
View Code

 

posted on 2015-07-21 17:42  acphile  阅读(338)  评论(0编辑  收藏  举报