[BJOI2019] 奥术神杖 [取log+AC自动机+dp]
题面
思路
首先,看到这个乘起来开根号的形式,应该能想到用取$\log$的方式做一个转化:
$\sqrt[n]{\prod_i a_i}=\frac{1}{n}\sum_i \log_b a_i$
这里我们把$b$取到$e$,就是$\ln a_i$了,不过实际上$b$取什么都没有问题
那么,这个问题就转化为了求所有匹配的宝石序列的最大平均值
遇到这种多模式串、单模板串的情况,应当第一时间想到AC自动机
我们建立模式串的AC自动机,并在上面跑dp即可完成题目的求解
建立AC自动机的时候,注意每个节点需要继承$fail$树上所有祖先的信息!
遇到这种有对选取的元素求平均值的最值的情况,应当第一时间想到0-1分数规划
我们二分最大平均值的答案,设当前为$C$
那么若有一组匹配方式能达到这个$C$,或以上,则有:
$\frac{1}{siz}\sum_{i=1}^{siz}w_i\geq C$
$\sum_{i=1}^{siz}w_i\geq C\ast siz$
$\sum_{i=1}^{siz}(w_i-C)\geq 0$
所以我们把每一个取过$\log$的元素减去当前二分的$C$,在AC自动机上跑dp
这样的好处是避免了需要在dp中维护已匹配元素个数的一个维度,可以优化一个$n$的时间复杂度
建立AC自动机后,设$dp[i][u]$表示当前遍历完成了模板串的前$i$个字符,匹配指针位置在AC自动机节点$u$上的情况时的最大值。
若模板串的下一个字符是确定的,就直接走到对应的儿子即可
否则需要更新每一个$dp[i+1][son_u]$的值
详细的更新方式见代码
Code
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cassert>
#include<queue>
#include<cmath>
#define ll long long
using namespace std;
inline int read(){
int re=0,flag=1;char ch=getchar();
while(!isdigit(ch)){
if(ch=='-') flag=-1;
ch=getchar();
}
while(isdigit(ch)) re=(re<<1)+(re<<3)+ch-'0',ch=getchar();
return re*flag;
}
int n,m;double w[1510];
char a[1510],s[1510][1510];
int ch[1510][10],cntn=0,num[1510],fail[1510];
double sum[1510];
//AC Automaton
inline void insert(int x,int len){
int i,cur=0,tmp;
for(i=1;i<=len;i++){
tmp=s[x][i]-'0';
if(!ch[cur][tmp]) ch[cur][tmp]=++cntn;
cur=ch[cur][tmp];
}
num[cur]++;sum[cur]+=w[x];
}
queue<int>q;
inline void build(){
int i,u,v;
for(i=0;i<10;i++){
if(!ch[0][i]) continue;
q.push(ch[0][i]);fail[ch[0][i]]=0;
}
while(!q.empty()){
u=q.front();q.pop();
num[u]+=num[fail[u]];
sum[u]+=sum[fail[u]];
for(i=0;i<10;i++){
v=ch[u][i];
if(v) fail[v]=ch[fail[u]][i],q.push(v);
else ch[u][i]=ch[fail[u]][i];
}
}
}
//Dynamic Programming
double dp[1510][1510];
int from[1510][1510][2],endpos;
inline bool check(double mid){
int i,j,k,son;
for(i=0;i<=n;i++) for(j=0;j<=cntn;j++) dp[i][j]=-2e20;
for(i=0;i<=cntn;i++){
sum[i]-=mid*(double)num[i];//cut the value according to binary search process
}
dp[0][0]=0;
for(i=1;i<=n;i++){
for(j=0;j<=cntn;j++){
if(dp[i-1][j]==-2e20) continue;
if(a[i]!='.'){//character is fixed in original S
son=ch[j][a[i]-'0'];
if(dp[i][son]<dp[i-1][j]+sum[son]){
dp[i][son]=dp[i-1][j]+sum[son];
from[i][son][0]=j;//record the source of the maximum value
from[i][son][1]=a[i]-'0';//record the corresponding character
}
}
else{//character is unfixed
for(k=0;k<10;k++){
son=ch[j][k];
if(dp[i][son]<dp[i-1][j]+sum[son]){
dp[i][son]=dp[i-1][j]+sum[son];
from[i][son][0]=j;
from[i][son][1]=k;
}
}
}
}
}
int pos=0;
for(i=1;i<=cntn;i++) if(dp[n][i]>dp[n][pos]) pos=i;
for(i=0;i<=cntn;i++) sum[i]+=mid*(double)num[i];//repair the value cut
endpos=pos;return dp[n][pos]>0;//determine if largest value is over zero
}
char re[1510];
int main(){
n=read();m=read();int i;
scanf("%s",a+1);
for(i=1;i<=m;i++){
scanf("%s",s[i]+1);
w[i]=read();
w[i]=log((double)w[i]);
insert(i,strlen(s[i]+1));
}
build();
double l=0,r=log(1e9+7),mid,ans=0;
while(r-l>1e-6){//binary search
mid=(l+r)*0.5;
if(check(mid)) ans=mid,l=mid;
else r=mid;
}
check(ans);
for(i=n;i>=1;i--){//get the answer string
re[i]=from[i][endpos][1]+'0';
endpos=from[i][endpos][0];
}
for(i=1;i<=n;i++) putchar(re[i]);
putchar('\n');
}