2018焦作现场赛 H. Can You Solve the Harder Problem?(后缀数组+rmq/线段树+单调栈)

题意

在一个数组中,求所有本质不同子段的贡献和。

每个子段的贡献为该子段中的最大值。

\(n \le 2\times 10^5,\ T \le 1000\)

传送门

思路

首先子段的贡献是子段中最大值,所以不难转化为求每个最大值对答案的贡献:

\(nxt[i] = \min\{\, j \mid i < j \le n+1,\ a[j] > a[i] \,\}\)(约定 \(a[n+1] = +\infty\),即右侧不存在更大元素时 \(nxt[i] = n+1\))

设 \(suf[i]\) 为以 \(i\) 为左端点的所有子段 \([i, j]\ (i \le j \le n)\) 的最大值之和,则 \(suf[i] = a[i] \cdot (nxt[i] - i) + suf[nxt[i]]\),其中 \(suf[n+1] = 0\)。\(nxt[i]\) 可用单调栈求出。

之后对于本质不同的子串,考虑用后缀数组。对所有后缀排序后,可以求出排名为 \(i\) 的后缀与排名为 \(i-1\) 的后缀的最长公共前缀长度 \(height[i]\)。排名为 \(i\) 的后缀(起点为 \(sa[i]\))新贡献的本质不同子段为 \([sa[i], j]\),其中 \(sa[i] + height[i] \le j \le n\),即该后缀中长度大于 \(height[i]\) 的所有前缀。

对于 \(height[i] = 0\) 的后缀,其所有前缀都是新子段,贡献即为 \(suf[sa[i]]\);

对于 \(height[i] \ne 0\) 的后缀,其贡献可以分成两段考虑:

首先设 \(p\) 为 \([sa[i], sa[i]+height[i]-1]\) 中最大值的下标(查询可用 RMQ 或线段树),以 \(nxt[p]\) 为分界点,将右端点 \(j\) 的取值范围 \([sa[i]+height[i], n]\) 分成两部分:

\((1)\ j \in [sa[i]+height[i],\ nxt[p]-1]\):此时子段最大值均为 \(a[p]\),贡献为 \(a[p] \cdot (nxt[p]-sa[i]-height[i])\);

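\((2)\ j \in [nxt[p], n]\):由于 \(a[nxt[p]]\) 严格大于 \([sa[i], nxt[p]-1]\) 中的所有元素,子段 \([sa[i], j]\) 的最大值等于 \([nxt[p], j]\) 的最大值,贡献为 \(suf[nxt[p]]\)。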

最后需要对数组中的值进行离散化:后缀数组的基数排序以元素值作为桶下标,离散化后值域不超过 \(n\)。

Code

#include <bits/stdc++.h>

using namespace std;

typedef long long ll;
const int inf = 0x3f3f3f3f;
// n <= 2e5 per the statement. Sizing every array for 1e6 made the sparse
// table alone (st[maxn][30]) take ~120 MB and risked exceeding the memory limit.
const int maxn = 2e5+10;

// T: number of test cases, n: sequence length, s[1..n]: the sequence
// (rank-compressed in place by main), has[1..pn]: sorted distinct original
// values, pn: number of distinct values.
int T, n, s[maxn], has[maxn], pn;

/**
 * Suffix array of the integer sequence s[1..n] (1-indexed), built by prefix
 * doubling with radix (counting) sort.
 *
 * Preconditions for build():
 *   - 1 <= s[i] <= pn for 1 <= i <= n (counting-sort buckets 0..pn);
 *   - s[n+1] differs from every s[i] (main stores inf there), so the LCP scan
 *     in getHeight() stops at the end of the sequence.
 *
 * After build():
 *   sa[r]     - start index of the suffix with rank r (1..n)
 *   rk[i]     - rank of the suffix starting at i
 *   height[r] - LCP length of suffixes sa[r] and sa[r-1]; height[1] = 0
 */
struct SuffixArray {
    int x[maxn], y[maxn], c[maxn];
    int sa[maxn], rk[maxn], height[maxn];

    void SA() {
        int m = pn;
        // Initial counting sort by the first element.
        for (int i = 0; i <= m; ++i) c[i] = 0;
        for (int i = 1; i <= n; ++i) ++c[(x[i]=s[i])];
        for (int i = 1; i <= m; ++i) c[i] += c[i-1];
        for (int i = n; i >= 1; --i) sa[c[x[i]]--] = i;

        for (int p, k = 1; k <= n; k <<= 1) {
            // y = suffixes ordered by second key (rank of position i+k);
            // suffixes with no second key come first.
            p = 0;
            for (int i = n-k+1; i <= n; ++i) y[++p] = i;
            for (int i = 1; i <= n; ++i) {
                if(sa[i] > k) y[++p] = sa[i] - k;
            }

            // Stable counting sort of y by first key.
            for (int i = 0; i <= m; ++i) c[i] = 0;
            for (int i = 1; i <= n; ++i) ++c[x[y[i]]];
            for (int i = 1; i <= m; ++i) c[i] += c[i-1];
            for (int i = n; i >= 1; --i) sa[c[x[y[i]]]--] = y[i];

            // Re-rank into y; -1 stands for "past the end" as second key.
            p = y[sa[1]] = 1;
            for (int i = 2; i <= n; ++i) {
                int a = sa[i]+k > n ? -1 : x[sa[i]+k];
                int b = sa[i-1]+k > n ? -1 : x[sa[i-1]+k];
                y[sa[i]] = (x[sa[i]] == x[sa[i-1]] && a == b) ? p : ++p;
            }
            // Only x[1..n] and y[1..n] are ever read, so exchange just that
            // range; swapping the whole maxn-sized arrays cost O(maxn) per
            // round regardless of n.
            std::swap_ranges(x + 1, x + n + 1, y + 1);
            if(p >= n) break;
            m = p;
        }
        for (int i = 1; i <= n; ++i) rk[sa[i]] = i;
    }

    // Kasai's algorithm: LCP of each suffix with its predecessor in sa order.
    void getHeight() {
        for (int k = 0, i = 1; i <= n; ++i) {
            // The LCP drops by at most one when moving from suffix i-1 to i.
            if(k) --k;
            if(rk[i] == 1) {
                // The smallest suffix has no predecessor; handle it explicitly
                // instead of comparing against the never-written sa[0] == 0.
                k = 0;
                height[1] = 0;
                continue;
            }
            int j = sa[rk[i]-1];
            while(s[i+k] == s[j+k]) ++k;
            height[rk[i]] = k;
        }
    }

    void build() {
        SA();
        getHeight();
    }

    // Debug dump of sa, rk and height (one line each).
    void write() {
        for (int i = 1; i <= n; ++i) printf("%d ", sa[i]); puts("");
        for (int i = 1; i <= n; ++i) printf("%d ", rk[i]); puts("");
        for (int i = 1; i <= n; ++i) printf("%d ", height[i]); puts("");
    }
}sa;

int sta[maxn], top;
int nxt[maxn], st[maxn][30];
ll suf[maxn];

/**
 * Range-maximum query on s[l..r] (1-indexed, requires l <= r) using the
 * sparse table st, where st[i][j] holds the index of a maximum of
 * s[i .. i + 2^j - 1]. Returns an index; when the two overlapping windows
 * have equal maxima the right window's index is returned.
 */
int query(int l, int r) {
    const int width = r - l + 1;
    // level = floor(log2(width))
    int level = 0;
    while ((2 << level) <= width) ++level;
    const int leftIdx = st[l][level];
    const int rightIdx = st[r - (1 << level) + 1][level];
    return s[leftIdx] > s[rightIdx] ? leftIdx : rightIdx;
}

int main() {
    scanf("%d", &T);
    while (T--) {
        scanf("%d", &n);
        for (int i = 1; i <= n; ++i) {
            scanf("%d", s + i);
            has[i] = s[i];
        }

        // Coordinate compression: s[i] becomes a rank in [1, pn] and
        // has[rank] recovers the original value.
        sort(has + 1, has + n + 1);
        pn = unique(has + 1, has + n + 1) - (has + 1);
        for (int i = 1; i <= n; ++i) {
            s[i] = lower_bound(has + 1, has + pn + 1, s[i]) - has;
            st[i][0] = i;
        }

        // Monotonic stack scanned from the right: nxt[i] is the first j > i
        // with s[j] > s[i]; the inf sentinel at n+1 is never popped.
        s[n + 1] = inf;
        top = 1;
        sta[top] = n + 1;
        for (int i = n; i >= 1; --i) {
            while (top && s[sta[top]] <= s[i]) --top;
            nxt[i] = sta[top];
            sta[++top] = i;
        }
        nxt[n + 1] = n + 1;
        st[n + 1][0] = n + 1;

        // Sparse table: st[i][j] = index of a maximum of s[i .. i + 2^j - 1].
        for (int j = 1; j <= 20; ++j) {
            const int half = 1 << (j - 1);
            const int width = (1 << j) - 1;
            for (int i = 1; i + width <= n; ++i) {
                const int lhs = st[i][j - 1];
                const int rhs = st[i + half][j - 1];
                st[i][j] = s[lhs] > s[rhs] ? lhs : rhs;
            }
        }

        // suf[i] = sum over j in [i, n] of max(a[i..j]); nxt[i] > i, so
        // suf[nxt[i]] is already final when i is processed.
        suf[n + 1] = 0;
        for (int i = n; i >= 1; --i)
            suf[i] = 1ll * has[s[i]] * (nxt[i] - i) + suf[nxt[i]];

        sa.build();

        // Suffix sa[r] contributes its prefixes longer than height[r]:
        // exactly the distinct subarrays it introduces.
        ll ans = 0;
        for (int r = 1; r <= n; ++r) {
            const int start = sa.sa[r];
            const int lcp = sa.height[r];
            if (lcp == 0) {
                ans += suf[start];
                continue;
            }
            const int last = start + lcp - 1;
            const int pos = query(start, last);
            ans += suf[nxt[pos]] + 1ll * has[s[pos]] * (nxt[pos] - last - 1);
        }
        printf("%lld\n", ans);
    }
    return 0;
}

水先生博客中的神仙读入挂优化后的代码:

#include <bits/stdc++.h>
 
using namespace std;
 
using ll = long long;
constexpr int inf = 0x3f3f3f3f;   // sentinel stored at s[n+1] by main
constexpr int maxn = 200000 + 10; // n <= 2e5 plus slack for index n+1
 
namespace fastIO{
#define BUF_SIZE 100000
#define OUT_SIZE 100000
#define ll long long
    //fread->read
    bool IOerror=0;
    inline char nc(){
        static char buf[BUF_SIZE],*p1=buf+BUF_SIZE,*pend=buf+BUF_SIZE;
        if (p1==pend){
            p1=buf; pend=buf+fread(buf,1,BUF_SIZE,stdin);
            if (pend==p1){IOerror=1;return -1;}
            //{printf("IO error!\n");system("pause");for (;;);exit(0);}
        }
        return *p1++;
    }
    inline bool blank(char ch){return ch==' '||ch=='\n'||ch=='\r'||ch=='\t';}
    inline void read(int &x){
        bool sign=0; char ch=nc(); x=0;
        for (;blank(ch);ch=nc());
        if (IOerror)return;
        if (ch=='-')sign=1,ch=nc();
        for (;ch>='0'&&ch<='9';ch=nc())x=x*10+ch-'0';
        if (sign)x=-x;
    }
    inline void read(ll &x){
        bool sign=0; char ch=nc(); x=0;
        for (;blank(ch);ch=nc());
        if (IOerror)return;
        if (ch=='-')sign=1,ch=nc();
        for (;ch>='0'&&ch<='9';ch=nc())x=x*10+ch-'0';
        if (sign)x=-x;
    }
    inline void read(double &x){
        bool sign=0; char ch=nc(); x=0;
        for (;blank(ch);ch=nc());
        if (IOerror)return;
        if (ch=='-')sign=1,ch=nc();
        for (;ch>='0'&&ch<='9';ch=nc())x=x*10+ch-'0';
        if (ch=='.'){
            double tmp=1; ch=nc();
            for (;ch>='0'&&ch<='9';ch=nc())tmp/=10.0,x+=tmp*(ch-'0');
        }
        if (sign)x=-x;
    }
    inline void read(char *s){
        char ch=nc();
        for (;blank(ch);ch=nc());
        if (IOerror)return;
        for (;!blank(ch)&&!IOerror;ch=nc())*s++=ch;
        *s=0;
    }
    inline void read(char &c){
        for (c=nc();blank(c);c=nc());
        if (IOerror){c=-1;return;}
    }
    //getchar->read
    inline void read1(int &x){
        char ch;int bo=0;x=0;
        for (ch=getchar();ch<'0'||ch>'9';ch=getchar())if (ch=='-')bo=1;
        for (;ch>='0'&&ch<='9';x=x*10+ch-'0',ch=getchar());
        if (bo)x=-x;
    }
    inline void read1(ll &x){
        char ch;int bo=0;x=0;
        for (ch=getchar();ch<'0'||ch>'9';ch=getchar())if (ch=='-')bo=1;
        for (;ch>='0'&&ch<='9';x=x*10+ch-'0',ch=getchar());
        if (bo)x=-x;
    }
    inline void read1(double &x){
        char ch;int bo=0;x=0;
        for (ch=getchar();ch<'0'||ch>'9';ch=getchar())if (ch=='-')bo=1;
        for (;ch>='0'&&ch<='9';x=x*10+ch-'0',ch=getchar());
        if (ch=='.'){
            double tmp=1;
            for (ch=getchar();ch>='0'&&ch<='9';tmp/=10.0,x+=tmp*(ch-'0'),ch=getchar());
        }
        if (bo)x=-x;
    }
    inline void read1(char *s){
        char ch=getchar();
        for (;blank(ch);ch=getchar());
        for (;!blank(ch);ch=getchar())*s++=ch;
        *s=0;
    }
    inline void read1(char &c){for (c=getchar();blank(c);c=getchar());}
    //scanf->read
    inline void read2(int &x){scanf("%d",&x);}
    inline void read2(ll &x){
#ifdef _WIN32
        scanf("%I64d",&x);
#else
#ifdef __linux
        scanf("%lld",&x);
#else
        puts("error:can't recognize the system!");
#endif
#endif
    }
    inline void read2(double &x){scanf("%lf",&x);}
    inline void read2(char *s){scanf("%s",s);}
    inline void read2(char &c){scanf(" %c",&c);}
    inline void readln2(char *s){gets(s);}
    //fwrite->write
    struct Ostream_fwrite{
        char *buf,*p1,*pend;
        Ostream_fwrite(){buf=new char[BUF_SIZE];p1=buf;pend=buf+BUF_SIZE;}
        void out(char ch){
            if (p1==pend){
                fwrite(buf,1,BUF_SIZE,stdout);p1=buf;
            }
            *p1++=ch;
        }
        void print(int x){
            static char s[15],*s1;s1=s;
            if (!x)*s1++='0';if (x<0)out('-'),x=-x;
            while(x)*s1++=x%10+'0',x/=10;
            while(s1--!=s)out(*s1);
        }
        void println(int x){
            static char s[15],*s1;s1=s;
            if (!x)*s1++='0';if (x<0)out('-'),x=-x;
            while(x)*s1++=x%10+'0',x/=10;
            while(s1--!=s)out(*s1); out('\n');
        }
        void print(ll x){
            static char s[25],*s1;s1=s;
            if (!x)*s1++='0';if (x<0)out('-'),x=-x;
            while(x)*s1++=x%10+'0',x/=10;
            while(s1--!=s)out(*s1);
        }
        void println(ll x){
            static char s[25],*s1;s1=s;
            if (!x)*s1++='0';if (x<0)out('-'),x=-x;
            while(x)*s1++=x%10+'0',x/=10;
            while(s1--!=s)out(*s1); out('\n');
        }
        void print(double x,int y){
            static ll mul[]={1,10,100,1000,10000,100000,1000000,10000000,100000000,
                             1000000000,10000000000LL,100000000000LL,1000000000000LL,10000000000000LL,
                             100000000000000LL,1000000000000000LL,10000000000000000LL,100000000000000000LL};
            if (x<-1e-12)out('-'),x=-x;x*=mul[y];
            ll x1=(ll)floor(x); if (x-floor(x)>=0.5)++x1;
            ll x2=x1/mul[y],x3=x1-x2*mul[y]; print(x2);
            if (y>0){out('.'); for (size_t i=1;i<y&&x3*mul[i]<mul[y];out('0'),++i); print(x3);}
        }
        void println(double x,int y){print(x,y);out('\n');}
        void print(char *s){while (*s)out(*s++);}
        void println(char *s){while (*s)out(*s++);out('\n');}
        void flush(){if (p1!=buf){fwrite(buf,1,p1-buf,stdout);p1=buf;}}
        ~Ostream_fwrite(){flush();}
    }Ostream;
    inline void print(int x){Ostream.print(x);}
    inline void println(int x){Ostream.println(x);}
    inline void print(char x){Ostream.out(x);}
    inline void println(char x){Ostream.out(x);Ostream.out('\n');}
    inline void print(ll x){Ostream.print(x);}
    inline void println(ll x){Ostream.println(x);}
    inline void print(double x,int y){Ostream.print(x,y);}
    inline void println(double x,int y){Ostream.println(x,y);}
    inline void print(char *s){Ostream.print(s);}
    inline void println(char *s){Ostream.println(s);}
    inline void println(){Ostream.out('\n');}
    inline void flush(){Ostream.flush();}
    //puts->write
    char Out[OUT_SIZE],*o=Out;
    inline void print1(int x){
        static char buf[15];
        char *p1=buf;if (!x)*p1++='0';if (x<0)*o++='-',x=-x;
        while(x)*p1++=x%10+'0',x/=10;
        while(p1--!=buf)*o++=*p1;
    }
    inline void println1(int x){print1(x);*o++='\n';}
    inline void print1(ll x){
        static char buf[25];
        char *p1=buf;if (!x)*p1++='0';if (x<0)*o++='-',x=-x;
        while(x)*p1++=x%10+'0',x/=10;
        while(p1--!=buf)*o++=*p1;
    }
    inline void println1(ll x){print1(x);*o++='\n';}
    inline void print1(char c){*o++=c;}
    inline void println1(char c){*o++=c;*o++='\n';}
    inline void print1(char *s){while (*s)*o++=*s++;}
    inline void println1(char *s){print1(s);*o++='\n';}
    inline void println1(){*o++='\n';}
    inline void flush1(){if (o!=Out){if (*(o-1)=='\n')*--o=0;puts(Out);}}
    struct puts_write{
        ~puts_write(){flush1();}
    }_puts;
    inline void print2(int x){printf("%d",x);}
    inline void println2(int x){printf("%d\n",x);}
    inline void print2(char x){printf("%c",x);}
    inline void println2(char x){printf("%c\n",x);}
    inline void print2(ll x){
#ifdef _WIN32
        printf("%I64d",x);
#else
#ifdef __linux
        printf("%lld",x);
#else
        puts("error:can't recognize the system!");
#endif
#endif
    }
    inline void println2(ll x){print2(x);printf("\n");}
    inline void println2(){printf("\n");}
#undef ll
#undef OUT_SIZE
#undef BUF_SIZE
};
using namespace fastIO;
 
// Test count and sequence length.
int T, n;
// s[1..n]: sequence (rank-compressed in place by main);
// has[1..pn]: sorted distinct original values.
int s[maxn], has[maxn];
int pn;
// Suffix-array scratch buffers used by SA().
int x[maxn], y[maxn], c[maxn];
// sa[rank] -> start, rk[start] -> rank, height[rank] = LCP with rank-1.
int sa[maxn], rk[maxn], height[maxn];
 
// Builds the suffix array of s[1..n] (values in [1, pn]) by prefix doubling
// with radix (counting) sort. Fills sa[] (rank -> start) and rk[]
// (start -> rank); x, y, c are scratch buffers.
void SA() {
    int m = pn;
    // Initial counting sort by the first element.
    for (int i = 0; i <= m; ++i) c[i] = 0;
    for (int i = 1; i <= n; ++i) ++c[(x[i]=s[i])];
    for (int i = 1; i <= m; ++i) c[i] += c[i-1];
    for (int i = n; i >= 1; --i) sa[c[x[i]]--] = i;

    for (int p, k = 1; k <= n; k <<= 1) {
        // y = suffixes ordered by second key (rank of position i+k);
        // suffixes with no second key come first.
        p = 0;
        for (int i = n-k+1; i <= n; ++i) y[++p] = i;
        for (int i = 1; i <= n; ++i) {
            if(sa[i] > k) y[++p] = sa[i] - k;
        }

        // Stable counting sort of y by first key.
        for (int i = 0; i <= m; ++i) c[i] = 0;
        for (int i = 1; i <= n; ++i) ++c[x[y[i]]];
        for (int i = 1; i <= m; ++i) c[i] += c[i-1];
        for (int i = n; i >= 1; --i) sa[c[x[y[i]]]--] = y[i];

        // Re-rank into y; -1 stands for "past the end" as second key.
        p = y[sa[1]] = 1;
        for (int i = 2; i <= n; ++i) {
            int a = sa[i]+k > n ? -1 : x[sa[i]+k];
            int b = sa[i-1]+k > n ? -1 : x[sa[i-1]+k];
            y[sa[i]] = (x[sa[i]] == x[sa[i-1]] && a == b) ? p : ++p;
        }
        // Only x[1..n] and y[1..n] are ever read, so exchange just that
        // range; swap(x, y) on the global maxn-sized arrays cost O(maxn)
        // per round regardless of n, which dominates with many test cases.
        std::swap_ranges(x + 1, x + n + 1, y + 1);
        if(p >= n) break;
        m = p;
    }
    for (int i = 1; i <= n; ++i) rk[sa[i]] = i;
}
 
// Kasai's algorithm: height[r] = LCP length of suffixes sa[r] and sa[r-1],
// with height[1] = 0. Requires rk/sa from SA() and a sentinel s[n+1] that
// differs from every s[i] (main sets s[n+1] = inf before building).
void getHeight() {
    for (int k = 0, i = 1; i <= n; ++i) {
        // The LCP drops by at most one when moving from suffix i-1 to i.
        if(k) --k;
        if(rk[i] == 1) {
            // The smallest suffix has no predecessor; handle it explicitly
            // instead of comparing against the never-written sa[0] == 0.
            k = 0;
            height[1] = 0;
            continue;
        }
        int j = sa[rk[i]-1];
        while(s[i+k] == s[j+k]) ++k;
        height[rk[i]] = k;
    }
}
 
// Suffix array first (sa/rk), then the adjacent-LCP array (height) that depends on it.
void build() { SA(); getHeight(); }
 
int sta[maxn], top;
int nxt[maxn], st[maxn][30];
// pw2[len] = number of powers of two <= len (filled in main), so
// pw2[len] - 1 == floor(log2(len)).
int pw2[maxn];
ll suf[maxn];

/**
 * Range-maximum query on s[l..r] (1-indexed, requires l <= r) using the
 * sparse table st of argmax indices. When the two overlapping windows have
 * equal maxima the right window's index is returned.
 */
int query(int l, int r) {
    const int level = pw2[r - l + 1] - 1;
    const int leftIdx = st[l][level];
    const int rightIdx = st[r - (1 << level) + 1][level];
    return s[leftIdx] > s[rightIdx] ? leftIdx : rightIdx;
}
 
int main() {
    for (int i = 1; i < maxn; i <<= 1) pw2[i] = 1;
    for (int i = 1; i < maxn; ++i) pw2[i] += pw2[i-1];
    read(T);
    while(T--) {
        read(n);
        for (int i = 1; i <= n; ++i) {
            read(s[i]);
            has[i] = s[i];
        }
 
        sort(has+1, has+1+n);
        pn = unique(has+1, has+1+n) - has - 1;
        for (int i = 1; i <= n; ++i) {
            s[i] = lower_bound(has+1, has+1+pn, s[i]) - has;
            st[i][0] = i;
        }
 
        s[n+1] = inf, sta[(top = 1)] = n+1;
        for (int i = n; i >= 1; --i) {
            while(top && s[sta[top]] <= s[i]) --top;
            nxt[i] = sta[top];
            sta[++top] = i;
        }
        nxt[n+1] = st[n + 1][0] = n + 1;
 
        for (int j = 1; j <= 20; ++j) {
            int p = 1<<j-1, l = (1<<j)-1;
            for (int i = 1; i + l <= n; ++i) {
                if(s[st[i][j - 1]] > s[st[i + p][j - 1]]) st[i][j] = st[i][j - 1];
                else st[i][j] = st[i + p][j - 1];
            }
        }
 
        suf[n+1] = 0;
        for (int i = n; i >= 1; --i) suf[i] = 1ll * has[s[i]] * (nxt[i] - i);
        for (int i = n; i >= 1; --i) suf[i] += suf[nxt[i]];
 
        build();
 
        ll ans = 0;
        for (int i = 1; i <= n; ++i) {
            int h = height[i];
            if(h == 0) {
                ans += suf[sa[i]];
            } else {
                int r = sa[i] + h - 1;
                int l = sa[i];
                int pos = query(l, r);
                ans += suf[nxt[pos]] + 1ll * has[s[pos]] * (nxt[pos] - r - 1);
            }
        }
        println(ans);
//        printf("%lld\n", ans);
    }
    return 0;
}
posted @ 2019-10-11 17:20  Acerkoo  阅读(337)  评论(0编辑  收藏  举报