洛谷P2178 品酒大会
题意:若两个字符开始的后面r个字符都一样,则称这两个字符是r相似的。它们也是r-1相似的。
对于r∈[0,n)分别求有多少种方案,其中权值最大方案权值是多少。此处权值是选出的两个字符的权值之积。
解:后缀自动机吊打后缀数组!!!
先看第一问,我们考虑后缀自动机上每个节点的贡献。显然cnt>1的节点才会有贡献。
它会对r ∈ len[fail[x]] + 1 ~ len[x]这一段的答案产生C(cntx,2)的贡献。这就是一个区间加法。
有个小问题,如果r减少那么相应的可选的其实会变多,但是此处我们不统计,那些会在以另一个字符结尾的别的节点上考虑到。
这样第一问就解决了。第二问?发现问题很大...一个节点的每个串都是结尾相同,开头不同。那么开头的权值之积就不好维护。
因为结尾相同,所以考虑反着建后缀自动机,然后就变成了开头相同了。那么如何维护乘积最大值呢?
考虑到一个节点的末尾所在位置,也就是它的right集合。显然就是fail树的子树中所有在主链上的节点。
于是每个点维护子树最大值即可。因为有负数所以还要最小值。
这样一个节点对第二问的贡献就是区间取max了。这两问的操作都可以用线段树搞定。
1 #include <cstdio> 2 #include <cstring> 3 #include <algorithm> 4 5 typedef long long LL; 6 const int N = 600010; 7 8 int tr[N][26], tot = 1, fail[N], len[N], cnt[N], bin[N], topo[N], last = 1, val[N], n; 9 int max1[N], max2[N], min1[N], min2[N]; 10 char str[N]; 11 LL tag[N << 1], large[N << 1]; 12 13 inline void insert(char c) { 14 int f = c - 'a'; 15 int p = last, np = ++tot; 16 last = np; 17 len[np] = len[p] + 1; 18 cnt[np] = 1; 19 while(p && !tr[p][f]) { 20 tr[p][f] = np; 21 p = fail[p]; 22 } 23 if(!p) { 24 fail[np] = 1; 25 } 26 else { 27 int Q = tr[p][f]; 28 if(len[Q] == len[p] + 1) { 29 fail[np] = Q; 30 } 31 else { 32 int nQ = ++tot; 33 len[nQ] = len[p] + 1; 34 fail[nQ] = fail[Q]; 35 fail[Q] = fail[np] = nQ; 36 memcpy(tr[nQ], tr[Q], sizeof(tr[Q])); 37 while(tr[p][f] == Q) { 38 tr[p][f] = nQ; 39 p = fail[p]; 40 } 41 } 42 } 43 return; 44 } 45 46 inline void update(int x, int y) { 47 int t[4]; 48 t[0] = max1[x]; 49 t[1] = max2[x]; 50 t[2] = max1[y]; 51 t[3] = max2[y]; 52 std::sort(t, t + 4); 53 max1[x] = t[3]; 54 max2[x] = t[2]; 55 t[0] = min1[x]; 56 t[1] = min2[x]; 57 t[2] = min1[y]; 58 t[3] = min2[y]; 59 std::sort(t, t + 4); 60 min1[x] = t[0]; 61 min2[x] = t[1]; 62 return; 63 } 64 65 inline void prework() { 66 for(int i = 1; i <= tot; i++) { 67 bin[len[i]]++; 68 } 69 for(int i = 1; i <= tot; i++) { 70 bin[i] += bin[i - 1]; 71 } 72 for(int i = 1; i <= tot; i++) { 73 topo[bin[len[i]]--] = i; 74 } 75 for(int i = tot; i >= 1; i--) { 76 int a = topo[i]; 77 cnt[fail[a]] += cnt[a]; 78 update(fail[a], a); 79 } 80 return; 81 } 82 83 inline void pushdown(int o) { 84 if(tag[o]) { 85 tag[o << 1] += tag[o]; 86 tag[o << 1 | 1] += tag[o]; 87 tag[o] = 0; 88 } 89 large[o << 1] = std::max(large[o << 1], large[o]); 90 large[o << 1 | 1] = std::max(large[o << 1 | 1], large[o]); 91 return; 92 } 93 94 void add(int L, int R, LL v, int l, int r, int o) { 95 if(L <= l && r <= R) { 96 tag[o] += v; 97 return; 98 } 99 int mid = (l + r) >> 1; 100 pushdown(o); 101 if(L <= mid) { 102 add(L, R, v, l, mid, o << 1); 103 } 104 if(mid < R) { 105 add(L, R, v, mid + 1, r, o << 1 | 1); 106 } 107 return; 108 } 109 110 void out(int l, int r, int o) { 111 if(l == r) { 112 if(r != n) { 113 printf("%lld %lld \n", tag[o], tag[o] ? large[o] : 0); 114 } 115 return; 116 } 117 int mid = (l + r) >> 1; 118 pushdown(o); 119 out(l, mid, o << 1); 120 out(mid + 1, r, o << 1 | 1); 121 return; 122 } 123 124 void change(int L, int R, LL v, int l, int r, int o) { 125 if(v <= large[o]) { 126 return; 127 } 128 if(L <= l && r <= R) { 129 large[o] = v; 130 return; 131 } 132 int mid = (l + r) >> 1; 133 pushdown(o); 134 if(L <= mid) { 135 change(L, R, v, l, mid, o << 1); 136 } 137 if(mid < R) { 138 change(L, R, v, mid + 1, r, o << 1 | 1); 139 } 140 return; 141 } 142 143 int main() { 144 memset(max1, ~0x3f, sizeof(max1)); 145 memset(max2, ~0x3f, sizeof(max2)); 146 memset(min1, 0x3f, sizeof(min1)); 147 memset(min2, 0x3f, sizeof(min2)); 148 memset(large, ~0x3f, sizeof(large)); 149 scanf("%d", &n); 150 scanf("%s", str + 1); 151 int l1 = -0x3f3f3f3f, l2 = -0x3f3f3f3f, s1 = 0x3f3f3f3f, s2 = 0x3f3f3f3f; 152 for(int i = 1; i <= n; i++) { 153 scanf("%d", &val[i]); 154 if(l1 < val[i]) { 155 l2 = l1; 156 l1 = val[i]; 157 } 158 else if(l2 < val[i]) { 159 l2 = val[i]; 160 } 161 if(s1 > val[i]) { 162 s2 = s1; 163 s1 = val[i]; 164 } 165 else if(s2 > val[i]) { 166 s2 = val[i]; 167 } 168 } 169 for(int i = n; i >= 1; i--) { 170 insert(str[i]); 171 max1[last] = min1[last] = val[i]; 172 } 173 prework(); 174 // 175 for(int i = 2; i <= tot; i++) { 176 if(cnt[i] < 2) { 177 continue; 178 } 179 // len[fail[i]] + 1 ~ len[i] 180 add(len[fail[i]] + 1, len[i], 1ll * cnt[i] * (cnt[i] - 1) / 2, 1, n, 1); 181 change(len[fail[i]] + 1, len[i], std::max(1ll * max1[i] * max2[i], 1ll * min1[i] * min2[i]), 1, n, 1); 182 } 183 184 printf("%lld %lld \n", 1ll * n * (n - 1) / 2, std::max(1ll * l1 * l2, 1ll * s1 * s2)); 185 out(1, n, 1); 186 return 0; 187 }