【2017 江苏集训】子串

题意

  给定一个长度为 \(n\) 的由 \(['0'\cdots '9']\) 组成的字符串 \(s\)\(v[i,j]\) 表示由字符串 \(s\)\(i\) 到第 \(j\) 位组成的十进制数字。
  将它的某一个上升序列定义为:将这个字符串切割成 \(m\) 段不含前导 \('0'\) 的串,切点分别为 \(k_1,k_2\cdots k_{m-1}\),使得 \(v[1,k_1]\lt v[k_1+1,k_2]\lt \cdots\lt v[k_{m-1},k_m]\)
  请你求出该字符串 \(s\) 的上升序列个数,答案对 \(10^9+7\) 取模。
  \(n\le 5000\)

题解

  很套路的 \(\text{dp}\)以前也做过(好像山东 2018 年有一道集训题叫湿纸巾鼠)
  设 \(dp(i,j)\) 表示令 \(j\) 为一个切点,上一个切点为 \(i\) 时,得到的上升序列的方案数。这里令一个切点为一段字符串的最后一位
  若 \((i,i+(i-j)]\gt (j,i]\),则 \(dp(i,i+(i-j))+=dp(j,i)\),否则 \(dp(i,i+(i-j)+1)+=dp(j,i)\)
  问题转化为求两段字符串的 \(\text{LCP}\)
  做法比较多,一个比较暴力的做法就是 \(O(\log n)\) 哈希(我不知道为啥随便设个模数都能过,可能是数据太水了),但造个全是 \(1\) 的数据就能卡到 \(6s+\)
  考虑 \(O(1)\) 方法求:\(\text{SA}\)(scb 神仙怎么天天切这种神仙板子鸭);\(O(n^2)\) 预处理后 \(O(1)\) 回答(当时看到这个做法才发现自己严重降智商了)

  由于数据是随机造的,哈希特判了前三位就直接跑了 \(\text{rk1}\)……

#include<bits/stdc++.h>
#define ll long long
#define N 5001
#define mul 131
#define mod 1000000007
using namespace std;
inline int read(){
    int x=0; bool f=1; char c=getchar();
    for(;!isdigit(c); c=getchar()) if(c=='-') f=0;
    for(; isdigit(c); c=getchar()) x=(x<<3)+(x<<1)+(c^'0');
    if(f) return x; return 0-x;
}
int n,ans;
char s[N];
void dfs(int x, ll lst, ll cur){
    if(x>n){
        if(cur) return;
        ans=(ans+1)%mod;
        //for(int i=1; i<=cnt; ++i) printf("%lld ",stk[i]); putchar('\n');
        return;
    }
    ll tmp=cur*10+s[x]-'0';
    dfs(x+1,lst,tmp);
    if((x==n || s[x+1]!='0') && tmp>lst) dfs(x+1,tmp,0);
}
int dp[N][N],h[N],p[N];
inline int getHash(int l, int r){
    return (((ll)h[r] - (ll)h[l-1]*p[r-l+1]%mod) % mod + mod) % mod;
}
int main(){
    //freopen("2.in","r",stdin);
    //freopen("2.out","w",stdout);
    n=read();
    scanf("%s",s+1);
    if(s[1]=='0'){puts("0"); return 0;}
    if(n<=0){
        dfs(1,0,0), printf("%d\n",ans);
    }
    else{
        p[0]=1;
        for(int i=1; i<=n; ++i) h[i]=((ll)h[i-1]*mul+s[i]-'0')%mod, p[i]=(ll)p[i-1]*mul%mod;
        for(int i=1; i<=n; ++i) dp[0][i]=1;
        for(int i=1; i<=n; ++i){
            if(s[i+1]=='0') continue;
            for(int j=0; j<i; ++j) if(i-j<=n-i+1){
                int l=1, r=i-j, mid, ret=r+1;
                if(s[j+1]!=s[i+1]) ret=1;
                else if(s[j+2]!=s[i+2]) ret=2;
                else if(s[j+3]!=s[i+3]) ret=3;
                else{
                    while(l<=r){
                        mid=l+r>>1;
                        if(getHash(j+1,j+mid)!=getHash(i+1,i+mid)) ret=mid, r=mid-1;
                        else l=mid+1;
                    }
                }
                if(ret<i-j+1 && s[j+ret]<s[i+ret]) (dp[i][i+(i-j)]+=dp[j][i])%=mod;
                else if(i+(i-j)+1<=n) (dp[i][i+(i-j)+1]+=dp[j][i])%=mod;
            }
            for(int j=i; j<=n; ++j) (dp[i][j]+=dp[i][j-1])%=mod;
            //printf("%d:\n",i);
            //for(int j=1; j<=n; ++j) printf("%d ",dp[i][j]); putchar('\n');
        } 
        int ans=0;
        for(int i=0; i<=n; ++i) (ans+=dp[i][n])%=mod;
        printf("%d\n",ans);
    }
    return 0;
}
#include<bits/stdc++.h>
#define ll long long
#define N 5001
#define mul 131
#define mod 1000000007
using namespace std;
inline int read(){
    int x=0; bool f=1; char c=getchar();
    for(;!isdigit(c); c=getchar()) if(c=='-') f=0;
    for(; isdigit(c); c=getchar()) x=(x<<3)+(x<<1)+(c^'0');
    if(f) return x; return 0-x;
}
int n,ans,lcp[N][N];
char s[N];
void dfs(int x, ll lst, ll cur){
    if(x>n){
        if(cur) return;
        ans=(ans+1)%mod;
        //for(int i=1; i<=cnt; ++i) printf("%lld ",stk[i]); putchar('\n');
        return;
    }
    ll tmp=cur*10+s[x]-'0';
    dfs(x+1,lst,tmp);
    if((x==n || s[x+1]!='0') && tmp>lst) dfs(x+1,tmp,0);
}
int dp[N][N];
int main(){
    //freopen("2.in","r",stdin);
    //freopen("2.out","w",stdout);
    n=read();
    scanf("%s",s+1);
    if(s[1]=='0'){puts("0"); return 0;}
    if(n<=0){
        dfs(1,0,0), printf("%d\n",ans);
    }
    else{
        for(int i=n; i>=1; --i)
            for(int j=i+1; j<=n; ++j)
                if(s[i]==s[j]) lcp[i][j]=lcp[i+1][j+1]+1;
        for(int i=1; i<=n; ++i) dp[0][i]=1;
        for(int i=1; i<=n; ++i){
            if(s[i+1]=='0') continue;
            for(int j=0; j<i; ++j) if(i-j<=n-i+1){
                int ret=lcp[j+1][i+1]+1;
                if(ret<i-j+1 && s[j+ret]<s[i+ret]) (dp[i][i+(i-j)]+=dp[j][i])%=mod;
                else if(i+(i-j)+1<=n) (dp[i][i+(i-j)+1]+=dp[j][i])%=mod;
            }
            for(int j=i; j<=n; ++j) (dp[i][j]+=dp[i][j-1])%=mod;
            //printf("%d:\n",i);
            //for(int j=1; j<=n; ++j) printf("%d ",dp[i][j]); putchar('\n');
        } 
        int ans=0;
        for(int i=0; i<=n; ++i) (ans+=dp[i][n])%=mod;
        printf("%d\n",ans);
    }
    return 0;
}
posted @ 2019-08-23 15:30  大本营  阅读(170)  评论(0编辑  收藏  举报