[BJOI2019]奥术神杖(AC自动机,DP,分数规划)
题目大意:
给出一个长度 $n$ 的字符串 $T$,只由数字和点组成。你可以把每个点替换成一个任意的数字。再给出 $m$ 个数字串 $S_i$,第 $i$ 个权值为 $t_i$。
对于一个替换方案,这样定义它的价值:
如果数字串 $S_i$ 在 $T$ 中出现了,那么将 $t_i$ 加入多重集。如果出现多次也要加多次。
它的价值就是这个多重集元素的几何平均数(所有 $c$ 个数的乘积开 $c$ 次方根)。
请构造出一个替换方案,使得这个值最大。不用输出这个价值。
$1\le n\le 1500,1\le \sum|S_i|\le 1500,1\le t_i\le 10^9$。
首先这个奇怪的式子是乘积的形式。如果将它取对数:(其实不一定要用 $\ln$,任意底数都可以)
$$\dfrac{1}{c}\sum\limits_{i=1}^c\ln x_i$$
$$\dfrac{\sum\limits_{i=1}^c\ln x_i}{\sum\limits_{i=1}^c 1}$$
这就是分数规划经典模式了。
二分 $x$,看看这个式子的值能不能 $>v$。可以就缩小左边界,否则缩小右边界。
$$\dfrac{1}{c}\sum\limits_{i=1}^c\ln x_i>v$$
$$\sum\limits_{i=1}^c\ln x_i>vc$$
$$\sum\limits_{i=1}^c(\ln x_i-v)>0$$
设第 $i$ 个串的新权值 $w_i=\ln x_i-v$,那么就是要找出一个方案使得 $w_i$ 之和最大,看看是否 $>0$ 即可。
如何求最大值?实际上是个套路 DP 了。
先建出所有 $S_i$ 串的 AC 自动机。然后设 $wsum_u$ 为从 $u$ 开始跳 fail 指针,跳过的所有点的 $w_i$ 之和。要记录这个是因为在 AC 自动机上走到点 $u$ 时,沿着 $u$ 跳 fail 能跳到的所有点都是可以匹配的。
令 $f[i][j]$ 表示按照 $T$ 串走了 $i$ 步,走到点 $j$ 的最大值。
可以从 $f[i][j]+wsum_c$ 转移到 $f[i+1][c]$。其中 $T_{i+1}$ 已经确定时 $c$ 就是对应的儿子,否则就要枚举替换成什么字符,再转移到相应的儿子。
由于要输出方案,要记录从哪个状态转移过来。
时间复杂度 $O(n\sum|S_i|\log)$。
#include<bits/stdc++.h> using namespace std; const int maxn=1555; #define FOR(i,a,b) for(int i=(a);i<=(b);i++) #define ROF(i,a,b) for(int i=(a);i>=(b);i--) #define MEM(x,v) memset(x,v,sizeof(x)) inline int read(){ char ch=getchar();int x=0,f=0; while(ch<'0' || ch>'9') f|=ch=='-',ch=getchar(); while(ch>='0' && ch<='9') x=x*10+ch-'0',ch=getchar(); return f?-x:x; } int n,m,nd,ch[maxn][10],fail[maxn],cnt[maxn],q[maxn],h,r,fr[maxn][maxn],id[maxn][maxn]; double val[maxn],hhh[maxn],dp[maxn][maxn]; char s[maxn],t[maxn],tmp[maxn]; void insert(const char *s,int x){ int now=0,l=strlen(s+1); FOR(i,1,l){ int p=s[i]-'0'; if(!ch[now][p]) ch[now][p]=++nd; now=ch[now][p]; } val[now]+=log(x); cnt[now]++; } void build(){ h=1;r=0; FOR(i,0,9) if(ch[0][i]) q[++r]=ch[0][i]; while(h<=r){ int u=q[h++]; val[u]+=val[fail[u]]; cnt[u]+=cnt[fail[u]]; FOR(i,0,9) if(ch[u][i]) fail[q[++r]=ch[u][i]]=ch[fail[u]][i]; else ch[u][i]=ch[fail[u]][i]; } } void trans(int i,int j,int k){ int c=ch[j][k]; if(dp[i][j]+hhh[c]>dp[i+1][c]){ dp[i+1][c]=dp[i][j]+hhh[c]; fr[i+1][c]=j; id[i+1][c]=k; } } bool check(double x){ FOR(i,0,nd) hhh[i]=val[i]-cnt[i]*x; FOR(i,0,n) FOR(j,0,nd) dp[i][j]=-1e9; dp[0][0]=0; FOR(i,0,n-1) FOR(j,0,nd){ if(s[i+1]=='.') FOR(k,0,9) trans(i,j,k); else trans(i,j,s[i+1]-'0'); } int mxid=0; FOR(i,1,nd) if(dp[n][i]>dp[n][mxid]) mxid=i; if(dp[n][mxid]<=0) return false; int at=mxid; ROF(i,n,1){ t[i]=id[i][at]+'0'; at=fr[i][at]; } return true; } int main(){ n=read();m=read(); scanf("%s",s+1); FOR(i,1,m){ scanf("%s",tmp+1); insert(tmp,read()); } build(); double l=0,r=log(1e9); while(r-l>1e-8){ double mid=(l+r)/2; if(check(mid)) l=mid; else r=mid; } check(l); printf("%s\n",t+1); }