HDU-3553 Just a String (二分 + 后缀数组)
题意:找出文本串中字典序第 k 大的字符串
思路:
首先我们不能仅仅按后缀数组排完序后每个字符串的大小来找,因为重复字符也参与排名,比如 AAB 2, 结果是 A 而不是 AA。
注:以下第 i 个后缀均指排完序后第 i 小的后缀。
所以我们二分找第 k 大的字符串位于哪个区间,假定我们现在确定目标位于后缀区间 \([le, ri]\) (排完序的),我们求出 \(LCP(le, ri) = x\),
并找出最小的LCP对应的后缀 \(mid\) 如果 \(x \times(ri - le + 1) \geq k\) ,那么我们就可以确定该字符串的长度
\(len = k / (ri - le + 1) + k \% (ri - le + 1)\), 这个公式自己写个样例很容易得出,然后我们考虑
\(x \times(ri - le + 1) < k\) 的情况,我们首先先将 \(k = k - x \times (ri - mid)\),然后看区间 \([le, mid]\) 的字符串总数量 \(sum\) 是否大于等于\(k\),如果大于等
于\(k\),那么我们可以确定目标一定在区间 \([le, mid]\) 中, 因为区间 \([le, mid]\) 的任意字符串字典序一定小于区间 \([mid + 1, ri]\) 中的所有长度大于 \(x\) 的字符串。
如果区间 \([le, mid]\) 的字符串总数量小于 \(k\), 那么目标就一定在区间 \([mid + 1, ri]\) 中,那我们把
\(k = k + x \times (ri - mid) - sum\) ,因为可以保证区间 \([le, mid]\) 中所有字符串的都小于 \(k\),因为我们最好控制每次开始查询的 \([le, ri]\) 区间里的字符串还未被减去, 这样方便我们编程。关于具体细节看下面代码。
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int maxn = 1e5 + 50;
int Sa[maxn], Height[maxn], Tax[maxn], Rank[maxn], tp[maxn], a[maxn], n, m;
LL sum[maxn];
int ca = 0;
char str[maxn];
void Rsort(){
for(int i = 0; i <= m; i++) Tax[i] = 0;
for(int i = 1; i <= n; i++) Tax[Rank[tp[i]]]++;
for(int i = 1; i <= m; i++) Tax[i] += Tax[i - 1];
for(int i = n; i >= 1; i--) Sa[Tax[Rank[tp[i]]]--] = tp[i];
}
int cmp(int *f, int x, int y, int w){
if(x + w > n || y + w > n) return 0; // 注意防止越界,多组输入的时候这条必须有
return f[x] == f[y] && f[x + w] == f[y + w];
}
void Suffix(){
for(int i = 1; i <= n; i++) Rank[i] = a[i], tp[i] = i;
m = 500, Rsort();
int p = 0;
for(int w = 1, i; p < n; w <<= 1, m = p){
for(p = 0, i = n - w + 1; i <= n; i++) tp[++p] = i;
for(i = 1; i <= n; i++) if(Sa[i] > w) tp[++p] = Sa[i] - w;
Rsort();
for(int i = 1;i <= n;i++) tp[i] = Rank[i];
Rank[Sa[1]] = p = 1;
for(int i = 2; i <= n; i++) Rank[Sa[i]] = cmp(tp, Sa[i], Sa[i - 1], w) ? p : ++p;
}
int j, k = 0;
for(int i = 1; i <= n; Height[Rank[i++]] = k){
for(k = k ? k - 1 : k, j = Sa[Rank[i] - 1]; i + k <= n && j + k <= n && a[i + k] == a[j + k]; ++k);
}
}
int min_st(int p1, int p2){
if(Height[p1] <= Height[p2]) return p1;
else return p2;
}
int dpmi[maxn][30];
void RMQ(){
for(int i = 1; i <= n; i++) dpmi[i][0] = i;
for(int j = 1; (1 << j) <= n; j++){
for(int i = 1; i + (1 << j) - 1 <= n; i++){
dpmi[i][j] = min_st(dpmi[i][j - 1], dpmi[i + (1 << (j - 1))][j - 1]);
}
}
}
int QueryMin(int le, int ri){
int k = log2(ri - le + 1);
return min_st(dpmi[le][k], dpmi[ri - (1 << k) + 1][k]);
}
int QueryLcp(int le, int ri){
if(le > ri) swap(le, ri);
le++;
return QueryMin(le, ri);
}
void Solve(LL k){
int le = 1, ri = n;
while(le <= ri){
if(le == ri){
for(int i = 0; i < k; i++){
printf("%c", str[Sa[le] + i]);
}
printf("\n");
break;
}
int mid = QueryLcp(le, ri) - 1;
if(k <= 1LL * Height[mid + 1] * (ri - le + 1)){
int len = k / (ri - le + 1);
if(k % (ri - le + 1)) len++;
for(int i = 0; i < len; i++){
printf("%c", str[Sa[le] + i]);
}
printf("\n");
break;
} else {
k -= 1LL * (ri - mid) * Height[mid + 1];
if(sum[mid] - sum[le - 1] >= k){
ri = mid;
} else {
k += 1LL * (ri - mid) * Height[mid + 1];
k -= (sum[mid] - sum[le - 1]);
le = mid + 1;
}
}
}
}
int main(int argc, char const *argv[])
{
int tt;
scanf("%d", &tt);
while(tt--){
scanf("%s", str + 1);
n = strlen(str + 1);
LL k;
scanf("%lld", &k);
for(int i = 1; i <= n; i++) {
a[i] = str[i];
}
Suffix();
for(int i = 1; i <= n; i++){
sum[i] = sum[i - 1] + n - Sa[i] + 1;
}
RMQ();
printf("Case %d: ", ++ca);
Solve(k);
}
return 0;
}