2018焦作现场赛 H. Can You Solve the Harder Problem?(后缀数组+rmq/线段树+单调栈)
题意
在一个数组中,求所有本质不同子段的贡献和。
每个子段的贡献为该子段中的最大值。
\(n \leq 2e5 , T \leq 1000\)
思路
首先子段的贡献是子段中最大值,所以不难转化为求每个最大值对答案的贡献:
设 \(nxt[i]\) 代表 \(min\{j|i<j \&\&j<=n+1 \&\& a[j] > a[i] \}\) ,
则贡献 \(suf[i] = a[i] * (nxt[i] - i) + suf[nxt[i]]\), 对于 \(nxt[i]\) 的求解可用单调栈。
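之后对于本质不同的子串,考虑用后缀数组。后缀数组对后缀排完序之后,可以求出排名为 \(i\) 的后缀与排名为 \(i-1\) 的后缀的 lcp 值 \(height[i]\)。排名为 \(i\) 的后缀的前 \(height[i]\) 个前缀已经在排名为 \(i-1\) 的后缀中出现过,因此它新增的本质不同子段是以 \(sa[i]\) 为左端点、右端点落在 \([sa[i]+height[i], n]\) 内的子段,该后缀的贡献即为这些子段的贡献之和。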
对于 \(height[i] == 0\) 的后缀,它对答案的贡献为整个后缀的贡献,即 \(suf[sa[i]]\);
对于 \(height[i] \neq 0\) 的后缀,其对答案的贡献可以分成两段考虑:
首先设 \(p\) 为 \([sa[i], sa[i]+height[i]-1]\) 中最大值的下标(查询可用 rmq 或者线段树),以 \(nxt[p]\) 为分界点将右端点的取值范围 \([sa[i]+height[i], n]\) 分成两部分:
\((1):[sa[i]+height[i], nxt[p]-1]\):这些子段的最大值都是 \(a[p]\),易得贡献 \(a[p]*(nxt[p]-sa[i]-height[i])\)
\((2):[nxt[p], n]\):这些子段的最大值与以 \(nxt[p]\) 为左端点、右端点相同的子段的最大值相同,贡献为 \(suf[nxt[p]]\)。
最后可以对数组中的值进行离散化,提高效率。
Code
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
// Sentinel larger than any discretized value; main stores it in s[n+1].
const int inf = 0x3f3f3f3f;
// Capacity of every global array in this program.
const int maxn = 1e6+10;
// T: remaining test cases; n: length of the current array;
// s[1..n]: input values, replaced in main by their 1-based rank in has[];
// has[1..pn]: sorted distinct input values; pn: number of distinct values.
int T, n, s[maxn], has[maxn], pn;
// Suffix array (prefix doubling + counting sort) and LCP table over the
// global 1-indexed sequence s[1..n], whose values lie in [1, pn].
//
//   sa[i]     : start index of the suffix with rank i (1-based)
//   rk[p]     : rank of the suffix starting at p
//   height[i] : length of the common prefix of suffixes sa[i] and sa[i-1]
//
// getHeight() does no bounds checking: it relies on s[n+1] holding a value
// that occurs nowhere else (main stores inf there) and on sa[0] / s[0] being
// zero, so that the rank-1 suffix is compared against an empty match.
struct SuffixArray {
    int x[maxn], y[maxn], c[maxn];
    int sa[maxn], rk[maxn], height[maxn];

    // Fills sa[1..n] and rk[1..n].
    void SA() {
        int m = pn;
        // Initial counting sort by the first value of every suffix.
        for (int i = 0; i <= m; ++i) c[i] = 0;
        for (int i = 1; i <= n; ++i) ++c[(x[i] = s[i])];
        for (int i = 1; i <= m; ++i) c[i] += c[i - 1];
        for (int i = n; i >= 1; --i) sa[c[x[i]]--] = i;
        for (int p, k = 1; k <= n; k <<= 1) {
            // y[] = suffixes ordered by their second key (rank of position +k).
            p = 0;
            for (int i = n - k + 1; i <= n; ++i) y[++p] = i;
            for (int i = 1; i <= n; ++i) {
                if (sa[i] > k) y[++p] = sa[i] - k;
            }
            // Stable counting sort by the first key x[].
            for (int i = 0; i <= m; ++i) c[i] = 0;
            for (int i = 1; i <= n; ++i) ++c[x[y[i]]];
            for (int i = 1; i <= m; ++i) c[i] += c[i - 1];
            for (int i = n; i >= 1; --i) sa[c[x[y[i]]]--] = y[i];
            // Re-rank into y[]: equal (first, second) key pairs share a rank.
            p = y[sa[1]] = 1;
            for (int i = 2; i <= n; ++i) {
                int a = sa[i] + k > n ? -1 : x[sa[i] + k];
                int b = sa[i - 1] + k > n ? -1 : x[sa[i - 1] + k];
                y[sa[i]] = (x[sa[i]] == x[sa[i - 1]] && a == b) ? p : ++p;
            }
            // Only indices 1..n are ever read, so exchange just that prefix.
            // swap(x, y) on whole arrays would touch all maxn elements on
            // every round, for every test case, regardless of n.
            std::swap_ranges(x, x + n + 1, y);
            if (p >= n) break;  // all ranks distinct: order is final
            m = p;
        }
        for (int i = 1; i <= n; ++i) rk[sa[i]] = i;
    }

    // Fills height[1..n], walking suffixes in text order so that the
    // matched length drops by at most one between consecutive positions.
    void getHeight() {
        for (int k = 0, i = 1; i <= n; ++i) {
            if (k) --k;
            int j = sa[rk[i] - 1];
            while (s[i + k] == s[j + k]) ++k;
            height[rk[i]] = k;
        }
    }

    // Builds sa/rk first, then height (which reads both).
    void build() {
        SA();
        getHeight();
    }

    // Debug dump of sa, rk and height, one array per line.
    void write() {
        for (int i = 1; i <= n; ++i) printf("%d ", sa[i]);
        puts("");
        for (int i = 1; i <= n; ++i) printf("%d ", rk[i]);
        puts("");
        for (int i = 1; i <= n; ++i) printf("%d ", height[i]);
        puts("");
    }
} sa;
// sta/top: array-backed stack of indices used by main's right-to-left scan that fills nxt.
int sta[maxn], top;
// nxt[i]: nearest index j > i with s[j] > s[i]; n+1 when there is none.
// st[i][j]: index of a maximum of s over [i, i + 2^j - 1]; main fills levels 0..20 only.
// NOTE(review): with maxn = 1e6+10 and 30 levels this table alone is roughly
// 120 MB if int is 4 bytes — confirm it fits the judge's memory limit.
int nxt[maxn], st[maxn][30];
// suf[i] = value(i) * (nxt[i] - i) + suf[nxt[i]], with suf[n+1] = 0, where
// value(i) is the original (undiscretized) input at position i.
ll suf[maxn];
// Range-maximum query on the sparse table st: returns the index of a maximum
// of s over [l, r] (l <= r). The two overlapping power-of-two windows are
// compared with a strict '>', so on equal values the right window's answer wins.
int query(int l, int r) {
    const int span = r - l + 1;
    int level = 0;
    while ((1 << (level + 1)) <= span) level++;
    const int width = 1 << level;
    const int leftBest = st[l][level];
    const int rightBest = st[r - width + 1][level];
    return s[leftBest] > s[rightBest] ? leftBest : rightBest;
}
// Per test case: reads the array, discretizes it, builds nxt / sparse table /
// suf, builds the suffix array, then sums the contribution of every suffix in
// rank order and prints the total.
int main() {
    // freopen("input.in", "r", stdin);
    scanf("%d", &T);
    while (T--) {
        scanf("%d", &n);
        for (int i = 1; i <= n; ++i) {
            scanf("%d", &s[i]);
            has[i] = s[i];
        }

        // Coordinate compression: s[i] becomes its 1-based rank in has[1..pn].
        int *const vals = has + 1;
        sort(vals, vals + n);
        pn = unique(vals, vals + n) - vals;
        for (int i = 1; i <= n; ++i) {
            s[i] = lower_bound(vals, vals + pn, s[i]) - has;
            st[i][0] = i;
        }

        // Next strictly greater element to the right; s[n+1] = inf is a
        // sentinel that is never popped from the stack.
        s[n + 1] = inf;
        top = 1;
        sta[top] = n + 1;
        for (int i = n; i >= 1; --i) {
            while (top && s[sta[top]] <= s[i]) --top;
            nxt[i] = sta[top];
            sta[++top] = i;
        }
        nxt[n + 1] = st[n + 1][0] = n + 1;

        // Sparse table of argmax indices; ties go to the right half.
        for (int j = 1; j <= 20; ++j) {
            const int half = 1 << (j - 1);
            const int full = 1 << j;
            for (int i = 1; i + full - 1 <= n; ++i) {
                const int lo = st[i][j - 1];
                const int hi = st[i + half][j - 1];
                st[i][j] = (s[lo] > s[hi]) ? lo : hi;
            }
        }

        // suf[i]: nxt[i] > i, so suf[nxt[i]] is already final when i is reached.
        suf[n + 1] = 0;
        for (int i = n; i >= 1; --i) {
            suf[i] = 1ll * has[s[i]] * (nxt[i] - i) + suf[nxt[i]];
        }

        sa.build();

        ll ans = 0;
        for (int i = 1; i <= n; ++i) {
            const int start = sa.sa[i];
            const int h = sa.height[i];
            if (h == 0) {
                ans += suf[start];
                continue;
            }
            // The first h prefixes were already counted for the previous
            // suffix; only right endpoints beyond r are new.
            const int r = start + h - 1;
            const int pos = query(start, r);
            ans += suf[nxt[pos]] + 1ll * has[s[pos]] * (nxt[pos] - r - 1);
        }
        printf("%lld\n", ans);
    }
    return 0;
}
附水先生博客中的神仙读入挂优化后的代码:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
// Sentinel larger than any discretized value; main stores it in s[n+1].
const int inf = 0x3f3f3f3f;
// Capacity of every global array in this program.
const int maxn = 2e5+10;
// Collection of fast I/O helpers.
//   Input : read   (fread-buffered), read1   (getchar), read2/readln2 (scanf / line)
//   Output: print  (fwrite-buffered, flushed by Ostream's destructor),
//           print1 (appended to the in-memory buffer Out, emitted with puts at exit),
//           print2 (printf)
// The integer printers negate negative arguments in place, so the most
// negative value of the type is not handled (NOTE(review): x=-x overflows there).
namespace fastIO{
    #define BUF_SIZE 100000
    #define OUT_SIZE 100000
    #define ll long long
    //fread->read
    // Set once fread delivers no more data; the read() overloads bail out after that.
    bool IOerror=0;
    // Returns the next byte of stdin from a BUF_SIZE-byte buffer, or -1 at end of input.
    inline char nc(){
        static char buf[BUF_SIZE],*p1=buf+BUF_SIZE,*pend=buf+BUF_SIZE;
        if (p1==pend){
            p1=buf; pend=buf+fread(buf,1,BUF_SIZE,stdin);
            if (pend==p1){IOerror=1;return -1;}
            //{printf("IO error!\n");system("pause");for (;;);exit(0);}
        }
        return *p1++;
    }
    // True for space, newline, carriage return and tab.
    inline bool blank(char ch){return ch==' '||ch=='\n'||ch=='\r'||ch=='\t';}
    // Skips blanks, then parses an optional '-' followed by decimal digits.
    inline void read(int &x){
        bool sign=0; char ch=nc(); x=0;
        for (;blank(ch);ch=nc());
        if (IOerror)return;
        if (ch=='-')sign=1,ch=nc();
        for (;ch>='0'&&ch<='9';ch=nc())x=x*10+ch-'0';
        if (sign)x=-x;
    }
    inline void read(ll &x){
        bool sign=0; char ch=nc(); x=0;
        for (;blank(ch);ch=nc());
        if (IOerror)return;
        if (ch=='-')sign=1,ch=nc();
        for (;ch>='0'&&ch<='9';ch=nc())x=x*10+ch-'0';
        if (sign)x=-x;
    }
    // Same as above plus an optional fractional part after '.'.
    inline void read(double &x){
        bool sign=0; char ch=nc(); x=0;
        for (;blank(ch);ch=nc());
        if (IOerror)return;
        if (ch=='-')sign=1,ch=nc();
        for (;ch>='0'&&ch<='9';ch=nc())x=x*10+ch-'0';
        if (ch=='.'){
            double tmp=1; ch=nc();
            for (;ch>='0'&&ch<='9';ch=nc())tmp/=10.0,x+=tmp*(ch-'0');
        }
        if (sign)x=-x;
    }
    // Reads one blank-delimited token into s and NUL-terminates it; no length check.
    inline void read(char *s){
        char ch=nc();
        for (;blank(ch);ch=nc());
        if (IOerror)return;
        for (;!blank(ch)&&!IOerror;ch=nc())*s++=ch;
        *s=0;
    }
    // Reads the next non-blank character; c is set to -1 at end of input.
    inline void read(char &c){
        for (c=nc();blank(c);c=nc());
        if (IOerror){c=-1;return;}
    }
    //getchar->read
    // The read1 family skips every non-digit before the number, remembering
    // whether a '-' was seen on the way.
    inline void read1(int &x){
        char ch;int bo=0;x=0;
        for (ch=getchar();ch<'0'||ch>'9';ch=getchar())if (ch=='-')bo=1;
        for (;ch>='0'&&ch<='9';x=x*10+ch-'0',ch=getchar());
        if (bo)x=-x;
    }
    inline void read1(ll &x){
        char ch;int bo=0;x=0;
        for (ch=getchar();ch<'0'||ch>'9';ch=getchar())if (ch=='-')bo=1;
        for (;ch>='0'&&ch<='9';x=x*10+ch-'0',ch=getchar());
        if (bo)x=-x;
    }
    inline void read1(double &x){
        char ch;int bo=0;x=0;
        for (ch=getchar();ch<'0'||ch>'9';ch=getchar())if (ch=='-')bo=1;
        for (;ch>='0'&&ch<='9';x=x*10+ch-'0',ch=getchar());
        if (ch=='.'){
            double tmp=1;
            for (ch=getchar();ch>='0'&&ch<='9';tmp/=10.0,x+=tmp*(ch-'0'),ch=getchar());
        }
        if (bo)x=-x;
    }
    inline void read1(char *s){
        char ch=getchar();
        for (;blank(ch);ch=getchar());
        for (;!blank(ch);ch=getchar())*s++=ch;
        *s=0;
    }
    inline void read1(char &c){for (c=getchar();blank(c);c=getchar());}
    //scanf->read
    inline void read2(int &x){scanf("%d",&x);}
    inline void read2(ll &x){
        #ifdef _WIN32
            scanf("%I64d",&x);
        #else
        #ifdef __linux
            scanf("%lld",&x);
        #else
            puts("error:can't recognize the system!");
        #endif
        #endif
    }
    inline void read2(double &x){scanf("%lf",&x);}
    inline void read2(char *s){scanf("%s",s);}
    inline void read2(char &c){scanf(" %c",&c);}
    // Reads one line from stdin into s without the trailing newline.
    // gets() was removed from the language in C++14, so the line is read
    // with getchar() instead. As with gets(), s is left untouched when end of
    // input is hit before any character, and the caller must supply a buffer
    // large enough for the whole line (the interface carries no capacity).
    inline void readln2(char *s){
        int ch=getchar();
        if (ch==EOF)return;
        for (;ch!=EOF&&ch!='\n';ch=getchar())*s++=(char)ch;
        *s=0;
    }
    //fwrite->write
    // Output stream with a BUF_SIZE-byte heap buffer written out via fwrite
    // whenever it fills, on flush(), and on destruction.
    struct Ostream_fwrite{
        char *buf,*p1,*pend;
        Ostream_fwrite(){buf=new char[BUF_SIZE];p1=buf;pend=buf+BUF_SIZE;}
        void out(char ch){
            if (p1==pend){
                fwrite(buf,1,BUF_SIZE,stdout);p1=buf;
            }
            *p1++=ch;
        }
        // Digits are produced least-significant first into s, then emitted in reverse.
        void print(int x){
            static char s[15],*s1;s1=s;
            if (!x)*s1++='0';if (x<0)out('-'),x=-x;
            while(x)*s1++=x%10+'0',x/=10;
            while(s1--!=s)out(*s1);
        }
        void println(int x){
            static char s[15],*s1;s1=s;
            if (!x)*s1++='0';if (x<0)out('-'),x=-x;
            while(x)*s1++=x%10+'0',x/=10;
            while(s1--!=s)out(*s1); out('\n');
        }
        void print(ll x){
            static char s[25],*s1;s1=s;
            if (!x)*s1++='0';if (x<0)out('-'),x=-x;
            while(x)*s1++=x%10+'0',x/=10;
            while(s1--!=s)out(*s1);
        }
        void println(ll x){
            static char s[25],*s1;s1=s;
            if (!x)*s1++='0';if (x<0)out('-'),x=-x;
            while(x)*s1++=x%10+'0',x/=10;
            while(s1--!=s)out(*s1); out('\n');
        }
        // Prints x with y digits after the decimal point (y indexes mul[], so 0 <= y <= 17).
        void print(double x,int y){
            static ll mul[]={1,10,100,1000,10000,100000,1000000,10000000,100000000,
                1000000000,10000000000LL,100000000000LL,1000000000000LL,10000000000000LL,
                100000000000000LL,1000000000000000LL,10000000000000000LL,100000000000000000LL};
            if (x<-1e-12)out('-'),x=-x;x*=mul[y];
            ll x1=(ll)floor(x); if (x-floor(x)>=0.5)++x1;
            ll x2=x1/mul[y],x3=x1-x2*mul[y]; print(x2);
            // Pad the fractional part with leading zeros up to y digits.
            if (y>0){out('.'); for (size_t i=1;i<y&&x3*mul[i]<mul[y];out('0'),++i); print(x3);}
        }
        void println(double x,int y){print(x,y);out('\n');}
        void print(char *s){while (*s)out(*s++);}
        void println(char *s){while (*s)out(*s++);out('\n');}
        void flush(){if (p1!=buf){fwrite(buf,1,p1-buf,stdout);p1=buf;}}
        ~Ostream_fwrite(){flush();}
    }Ostream;
    inline void print(int x){Ostream.print(x);}
    inline void println(int x){Ostream.println(x);}
    inline void print(char x){Ostream.out(x);}
    inline void println(char x){Ostream.out(x);Ostream.out('\n');}
    inline void print(ll x){Ostream.print(x);}
    inline void println(ll x){Ostream.println(x);}
    inline void print(double x,int y){Ostream.print(x,y);}
    inline void println(double x,int y){Ostream.println(x,y);}
    inline void print(char *s){Ostream.print(s);}
    inline void println(char *s){Ostream.println(s);}
    inline void println(){Ostream.out('\n');}
    inline void flush(){Ostream.flush();}
    //puts->write
    // print1/println1 append to Out through the cursor o with no capacity
    // check; the whole buffer is emitted once by flush1().
    char Out[OUT_SIZE],*o=Out;
    inline void print1(int x){
        static char buf[15];
        char *p1=buf;if (!x)*p1++='0';if (x<0)*o++='-',x=-x;
        while(x)*p1++=x%10+'0',x/=10;
        while(p1--!=buf)*o++=*p1;
    }
    inline void println1(int x){print1(x);*o++='\n';}
    inline void print1(ll x){
        static char buf[25];
        char *p1=buf;if (!x)*p1++='0';if (x<0)*o++='-',x=-x;
        while(x)*p1++=x%10+'0',x/=10;
        while(p1--!=buf)*o++=*p1;
    }
    inline void println1(ll x){print1(x);*o++='\n';}
    inline void print1(char c){*o++=c;}
    inline void println1(char c){*o++=c;*o++='\n';}
    inline void print1(char *s){while (*s)*o++=*s++;}
    inline void println1(char *s){print1(s);*o++='\n';}
    inline void println1(){*o++='\n';}
    // puts() appends its own newline, so one trailing '\n' is dropped first.
    inline void flush1(){if (o!=Out){if (*(o-1)=='\n')*--o=0;puts(Out);}}
    // Emits the Out buffer when this static object is destroyed at program exit.
    struct puts_write{
        ~puts_write(){flush1();}
    }_puts;
    inline void print2(int x){printf("%d",x);}
    inline void println2(int x){printf("%d\n",x);}
    inline void print2(char x){printf("%c",x);}
    inline void println2(char x){printf("%c\n",x);}
    inline void print2(ll x){
        #ifdef _WIN32
            printf("%I64d",x);
        #else
        #ifdef __linux
            printf("%lld",x);
        #else
            puts("error:can't recognize the system!");
        #endif
        #endif
    }
    inline void println2(ll x){print2(x);printf("\n");}
    inline void println2(){printf("\n");}
    #undef ll
    #undef OUT_SIZE
    #undef BUF_SIZE
};
using namespace fastIO;
// T: remaining test cases; n: length of the current array;
// s[1..n]: input values, replaced in main by their 1-based rank in has[];
// has[1..pn]: sorted distinct input values; pn: number of distinct values.
int T, n, s[maxn], has[maxn], pn;
// Scratch arrays of SA(): x = current rank keys, y = second-key order / new ranks,
// c = counting-sort buckets.
int x[maxn], y[maxn], c[maxn];
// sa[i]: start of the suffix with rank i; rk[p]: rank of the suffix starting at p;
// height[i]: common-prefix length of suffixes sa[i] and sa[i-1].
int sa[maxn], rk[maxn], height[maxn];
// Builds the suffix array of the global sequence s[1..n] (values in [1, pn])
// by prefix doubling with counting sort. On return sa[1..n] lists suffix
// start positions in increasing lexicographic order and rk[] is its inverse.
void SA() {
    int m = pn;
    // Initial counting sort by the first value of every suffix.
    for (int i = 0; i <= m; ++i) c[i] = 0;
    for (int i = 1; i <= n; ++i) ++c[(x[i] = s[i])];
    for (int i = 1; i <= m; ++i) c[i] += c[i - 1];
    for (int i = n; i >= 1; --i) sa[c[x[i]]--] = i;
    for (int p, k = 1; k <= n; k <<= 1) {
        // y[] = suffixes ordered by their second key (rank of position +k);
        // suffixes with no second key come first.
        p = 0;
        for (int i = n - k + 1; i <= n; ++i) y[++p] = i;
        for (int i = 1; i <= n; ++i) {
            if (sa[i] > k) y[++p] = sa[i] - k;
        }
        // Stable counting sort by the first key x[].
        for (int i = 0; i <= m; ++i) c[i] = 0;
        for (int i = 1; i <= n; ++i) ++c[x[y[i]]];
        for (int i = 1; i <= m; ++i) c[i] += c[i - 1];
        for (int i = n; i >= 1; --i) sa[c[x[y[i]]]--] = y[i];
        // Re-rank into y[]: equal (first, second) key pairs share a rank.
        p = y[sa[1]] = 1;
        for (int i = 2; i <= n; ++i) {
            int a = sa[i] + k > n ? -1 : x[sa[i] + k];
            int b = sa[i - 1] + k > n ? -1 : x[sa[i - 1] + k];
            y[sa[i]] = (x[sa[i]] == x[sa[i - 1]] && a == b) ? p : ++p;
        }
        // Only indices 1..n are ever read, so exchange just that prefix.
        // swap(x, y) on whole arrays would touch all maxn elements on every
        // round, for every test case, regardless of n.
        std::swap_ranges(x, x + n + 1, y);
        if (p >= n) break;  // all ranks distinct: order is final
        m = p;
    }
    for (int i = 1; i <= n; ++i) rk[sa[i]] = i;
}
void getHeight() {
for (int k = 0, i = 1; i <= n; ++i) {
if(k) --k;
int j = sa[rk[i]-1];
while(s[i+k] == s[j+k]) ++k;
height[rk[i]] = k;
}
}
// Builds the suffix array and then the height table. The order matters:
// getHeight() reads the sa[] and rk[] arrays that SA() fills.
void build() {
SA();
getHeight();
}
// sta/top: array-backed stack of indices used by main's right-to-left scan that fills nxt.
int sta[maxn], top;
// nxt[i]: nearest index j > i with s[j] > s[i]; n+1 when there is none.
// st[i][j]: index of a maximum of s over [i, i + 2^j - 1]; main fills levels 0..20 only.
int nxt[maxn], st[maxn][30];
// pw2[len]: number of powers of two that are <= len (filled once at the start of main),
// so pw2[len]-1 is the sparse-table level used for a range of that length.
int pw2[maxn];
// suf[i] = value(i) * (nxt[i] - i) + suf[nxt[i]], with suf[n+1] = 0, where
// value(i) is the original (undiscretized) input at position i.
ll suf[maxn];
// Range-maximum query on the sparse table st: returns the index of a maximum
// of s over [l, r] (l <= r). The level comes from the precomputed pw2 table
// instead of a loop. The two overlapping windows are compared with a strict
// '>', so on equal values the right window's answer wins.
int query(int l, int r) {
    const int level = pw2[r - l + 1] - 1;
    const int width = 1 << level;
    const int leftBest = st[l][level];
    const int rightBest = st[r - width + 1][level];
    if (s[leftBest] > s[rightBest]) {
        return leftBest;
    }
    return rightBest;
}
// Precomputes the pw2 level table once, then for each test case: reads the
// array through fastIO::read, discretizes it, builds nxt / sparse table / suf,
// builds the suffix array, sums the contribution of every suffix in rank
// order and queues the total for output with fastIO::println.
int main() {
    // pw2[len] = number of powers of two <= len (marks, then prefix sums).
    for (int bit = 1; bit < maxn; bit <<= 1) pw2[bit] = 1;
    for (int len = 1; len < maxn; ++len) pw2[len] += pw2[len - 1];

    read(T);
    while (T--) {
        read(n);
        for (int i = 1; i <= n; ++i) {
            read(s[i]);
            has[i] = s[i];
        }

        // Coordinate compression: s[i] becomes its 1-based rank in has[1..pn].
        int *const vals = has + 1;
        sort(vals, vals + n);
        pn = unique(vals, vals + n) - vals;
        for (int i = 1; i <= n; ++i) {
            s[i] = lower_bound(vals, vals + pn, s[i]) - has;
            st[i][0] = i;
        }

        // Next strictly greater element to the right; s[n+1] = inf is a
        // sentinel that is never popped from the stack.
        s[n + 1] = inf;
        top = 1;
        sta[top] = n + 1;
        for (int i = n; i >= 1; --i) {
            while (top && s[sta[top]] <= s[i]) --top;
            nxt[i] = sta[top];
            sta[++top] = i;
        }
        nxt[n + 1] = st[n + 1][0] = n + 1;

        // Sparse table of argmax indices; ties go to the right half.
        for (int j = 1; j <= 20; ++j) {
            const int half = 1 << (j - 1);
            const int full = 1 << j;
            for (int i = 1; i + full - 1 <= n; ++i) {
                const int lo = st[i][j - 1];
                const int hi = st[i + half][j - 1];
                st[i][j] = (s[lo] > s[hi]) ? lo : hi;
            }
        }

        // suf[i]: nxt[i] > i, so suf[nxt[i]] is already final when i is reached.
        suf[n + 1] = 0;
        for (int i = n; i >= 1; --i) {
            suf[i] = 1ll * has[s[i]] * (nxt[i] - i) + suf[nxt[i]];
        }

        build();

        ll ans = 0;
        for (int i = 1; i <= n; ++i) {
            const int start = sa[i];
            const int h = height[i];
            if (h == 0) {
                ans += suf[start];
                continue;
            }
            // The first h prefixes were already counted for the previous
            // suffix; only right endpoints beyond r are new.
            const int r = start + h - 1;
            const int pos = query(start, r);
            ans += suf[nxt[pos]] + 1ll * has[s[pos]] * (nxt[pos] - r - 1);
        }
        println(ans);
        // printf("%lld\n", ans);
    }
    return 0;
}