LOJ 3089 「BJOI2019」奥术神杖——AC自动机DP+0/1分数规划
题目:https://loj.ac/problem/3089
没想到把根号之类的求对数变成算数平均值。写了个只能得15分的暴力。
#include<cstdio> #include<cstring> #include<algorithm> #define db double using namespace std; const int N=1505,K=10; const db eps=1e-8; int n,m,tot=1,c[N][K],fl[N]; db ans; int hd[N],xnt,to[N<<1],nxt[N<<1],vl[N],sm[N]; int l[N],v[N],dy[N],q[N]; db vp[N][N]; char s[N],ch[N],prn[N]; db pw(db x,int k) {db ret=1;while(k){if(k&1)ret*=x;x*=x;k>>=1;}return ret;} void add(int x,int y){to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;} int ins(int len) { int cr=1; for(int i=1;i<=len;i++) { int w=ch[i]-'0'; if(!c[cr][w])c[cr][w]=++tot; cr=c[cr][w]; } return cr; } void get_fl() { int he=0,tl=0; for(int i=0;i<10;i++) if(c[1][i])q[++tl]=c[1][i],fl[c[1][i]]=1; else c[1][i]=1; while(he<tl) { int k=q[++he],pr=fl[k]; for(int i=0;i<10;i++) if(c[k][i])q[++tl]=c[k][i],fl[c[k][i]]=c[pr][i]; else c[k][i]=c[pr][i]; } for(int i=2;i<=tot;i++)add(fl[i],i); } void m_dfs(int cr) { sm[cr]=vl[cr]; for(int i=hd[cr],v;i;i=nxt[i]) m_dfs(v=to[i]), sm[cr]+=sm[v]; } db fnd(int ct,db tp) { db l=1,r=tp,ret=0; while(r-l>eps) { db mid=(l+r)/2; if(pw(mid,ct)<=tp)ret=mid,l=mid+eps; else r=mid-eps; } return ret; } void dfs(int cr,int nw) { if(cr>n) { m_dfs(1); db tp=1; int ct=0; for(int i=1;i<=m;i++) { tp*=vp[i][sm[dy[i]]]; ct+=sm[dy[i]];} db d=fnd(ct,tp); if(d>ans) { ans=d;for(int i=1;i<=n;i++)prn[i]=s[i];} return; } if(s[cr]!='.') { nw=c[nw][s[cr]-'0']; vl[nw]++; dfs(cr+1,nw); vl[nw]--; return; } int ynw=nw; for(int i=0;i<10;i++) { s[cr]=i+'0'; nw=c[ynw][i]; vl[nw]++; dfs(cr+1,nw); vl[nw]--; } s[cr]='.';//// } void solve1() { for(int i=1;i<=m;i++) { vp[i][0]=1;//[0]!! for(int j=1,lm=n-l[i]+1;j<=lm;j++) vp[i][j]=vp[i][j-1]*v[i]; } dfs(1,1); for(int i=1;i<=n;i++)putchar(prn[i]); puts(""); } namespace S2{ bool en[N],dp[N][N]; int pr[N][N],pw[N][N]; void dfs(int cr) { for(int i=hd[cr],v;i;i=nxt[i]) en[v=to[i]]|=en[cr], dfs(v); } void solve() { for(int i=1;i<=m;i++)en[dy[i]]=1; m_dfs(1); dp[1][1]=1; int ni=0,nj; for(int i=1;i<=n&&!ni;i++) for(int j=1;j<=tot;j++) if(dp[i][j]) { if(en[j]){ni=i;nj=j;break;} if(s[i]=='.') { for(int w=0;w<10;w++) if(!dp[i+1][c[j][w]]) { pr[i+1][c[j][w]]=j; pw[i+1][c[j][w]]=w; dp[i+1][c[j][w]]=1;} } else { int w=s[i]-'0'; for(int j=1;j<=tot;j++) if(!dp[i+1][c[j][w]]) { pr[i+1][c[j][w]]=j; pw[i+1][c[j][w]]=w; dp[i+1][c[j][w]]=1;} } } for(int i=ni;i>1;i--) s[i-1]=pw[i][nj]='0', nj=pr[i][nj]; for(int i=1;i<=n;i++) if(s[i]=='.')putchar('0'); else putchar(s[i]); puts(""); } } int main() { scanf("%d%d",&n,&m); scanf("%s",s+1); int ct=0; for(int i=1;i<=n;i++) ct+=(s[i]=='.'); bool fg=0; for(int i=1;i<=m;i++) { scanf("%s",ch+1); scanf("%d",&v[i]); l[i]=strlen(ch+1); dy[i]=ins(l[i]); if(i>1&&v[i]!=v[i-1])fg=1; } get_fl(); if(n<=6||ct<=3)solve1(); else if(!fg)S2::solve(); return 0; }
\( log(a^x) = x*log(a) \) , \( log(a*b) = log(a)+log(b) \) 。和乘法之类有关的应该考虑一下 log 。
求平均值就用 0/1 分数规划。
不要每次合法的时候都把方案存一遍。二分完再做一次,得到方案。自己用 pr[ ][ ] 和 pw[ ][ ] 记录 dp[ i ][ j ] 是从 dp[ i-1 ] 的哪里转移来的、填了什么值。
二分的 r 的范围是 max( log(v) ) 。
m 个串里没出现过的字符都是等价的。记录一下 “出现过的字符” , 如果有没出现过的就选一个加入字符集,然后用那个字符集来转移即可。
这样就能支持 1e-8 的精度了。不过 1e-4 就够了。
#include<cstdio> #include<cstring> #include<algorithm> #include<cmath> #define db double using namespace std; db Mx(db a,db b){return a>b?a:b;} const int N=1505,K=10; const db INF=3e7,eps=1e-8; int n,m,tot=1,c[N][K],fl[N],q[N]; int ct[N],pr[N][N],pw[N][N]; db v[N],dp[N][N]; char s[N],ch[N],prn[N]; int tw[K],top; bool vis[K]; db get_lg(int d) { db l=0,r=10,ret=0;//ret=0 while(r-l>eps) { db mid=(l+r)/2; if(pow(10,mid)<=d)ret=mid,l=mid+eps; else r=mid-eps; } return ret; } void ins(int len,db d) { int cr=1; for(int i=1;i<=len;i++) { int w=ch[i]-'0'; if(!vis[w])vis[w]=1,tw[++top]=w; if(!c[cr][w])c[cr][w]=++tot; cr=c[cr][w]; } v[cr]=d; ct[cr]=1; } void get_fl() { int he=0,tl=0; for(int i=0,j;i<10;i++) if((j=c[1][i]))q[++tl]=j,fl[j]=1; else c[1][i]=1; while(he<tl) { int k=q[++he],pr=fl[k]; v[k]+=v[pr]; ct[k]+=ct[pr]; for(int i=0,j;i<10;i++) if((j=c[k][i]))q[++tl]=j,fl[j]=c[pr][i]; else c[k][i]=c[pr][i]; } } int chk(db L) { for(int i=0;i<=n;i++) for(int j=1;j<=tot;j++)dp[i][j]=-INF; dp[0][1]=0; for(int i=0;i<n;i++) { for(int j=1;j<=tot;j++) if(dp[i][j]>-INF) { if(s[i+1]!='.') { int w=s[i+1]-'0',tj=c[j][w]; db d=dp[i][j]+v[tj]-L*ct[tj]; if(d>dp[i+1][tj]) { dp[i+1][tj]=d;pr[i+1][tj]=j;pw[i+1][tj]=w; } continue; } for(int k=1;k<=top;k++) { int w=tw[k],tj=c[j][w]; db d=dp[i][j]+v[tj]-L*ct[tj]; if(d>dp[i+1][tj]) { dp[i+1][tj]=d;pr[i+1][tj]=j; pw[i+1][tj]=w; } } } } int j=1; for(int i=2;i<=tot;i++)if(dp[n][i]>dp[n][j])j=i; return j; } int main() { scanf("%d%d",&n,&m);scanf("%s",s+1); db d,l=0,r=0,ans; for(int i=1;i<=m;i++) { scanf("%s%lf",ch+1,&d); d=get_lg(d); r=Mx(r,d);//Mx not sm int len=strlen(ch+1); ins(len,d); } get_fl(); for(int i=0;i<10;i++)if(!vis[i]){tw[++top]=i;break;} while(r-l>eps) { db mid=(l+r)/2; if(dp[n][chk(mid)]>0) //>0 not >=0 for almost always can =0 { ans=mid; l=mid+eps;} else r=mid-eps; } int j=chk(ans); for(int i=n;i;i--) prn[i]=pw[i][j]+'0', j=pr[i][j]; printf("%s\n",prn+1); return 0; }