P5319 [BJOI2019]奥术神杖
用时:正常的时间大概120min
这题spj...我能说我对着不存在的bug找了半天吗...我是傻逼(1/1)
因为要求神力值最大的字符串,不用求这个数值,
可以把每个值取\(\log\),
把几何平均数转化为算数平均数。
\(\sqrt[n]{\prod a_i}=\frac{1}{n}\sum \log a_i\)
多模式串匹配\(\rightarrow AC\)自动机
求平均值\(\rightarrow 0/1\)分数规划
设当前二分到的答案为\(mid\),
\(\frac{1}{n} \sum \limits_{i=1} ^{n} \log a_i \ge mid\)
\(\sum \limits_{i=1} ^{n} \log a_i \ge n \times mid\)
\(\sum \limits_{i=1} ^{n} \log (a_i-mid) \ge 0\)
在\(AC\)自动机上\(DP\)。
设:
\(f[i][j]\)表示当前匹配到第\(i\)个字符,且第\(i\)个字符在\(trie\)树上的编号为\(j\)。
\(f[i][trie[j].son[x]] = max(f[i-1][j] + val[trie[j].son[x]])\)
其中,
\(x \in [0,9]\),若\(x\)在原串上已确定则不用枚举\(x\)。
\(val[i] = \log (a_i- cnt_i \times mid)\)
每次转移时,记录这个状态是从哪个字符及\(trie\)树上的编号转移过来的。
注意,建立\(AC\)自动机时,每个节点都要继承\(fail\)的所有祖先的信息,即
trie[u].cnt += trie[trie[u].fail].cnt;
trie[u].sum += trie[trie[u].fail].sum;
最后判断是否\(max(f[n][x]) > 0\),
如果是,则记录当前字符串。
#include<cstdio>
#include<iostream>
#include<cmath>
#include<cstring>
#include<queue>
#define Mogeko qwq
using namespace std;
const int maxn = 1505;
const int INF = 0x3f3f3f3f;
const double eps = 1e-6;
int n,m,now,num;
double l,r,mid;
char t[maxn],s[maxn],T[maxn];
double val[maxn],f[maxn][maxn],w[maxn];
pair<int,int> pre[maxn][maxn];
struct Trie{
int son[10],fail;
double cnt,sum;
}trie[maxn];
void insert(char s[],int x){
int len = strlen(s);
int u = 0;
for(int i = 0;i < len;i++){
int v = s[i]-'0';
if(!trie[u].son[v]) trie[u].son[v] = ++num;
u = trie[u].son[v];
}
trie[u].cnt++;
trie[u].sum += w[x];
}
void getf(){
queue <int> q;
for(int i = 0;i < 10;i++)
if(trie[0].son[i]){
trie[trie[0].son[i]].fail = 0;
q.push(trie[0].son[i]);
}
while(!q.empty()){
int u = q.front();
q.pop();
trie[u].cnt += trie[trie[u].fail].cnt;
trie[u].sum += trie[trie[u].fail].sum;
for(int i = 0;i < 10;i++){
if(trie[u].son[i]){
trie[trie[u].son[i]].fail = trie[trie[u].fail].son[i];
q.push(trie[u].son[i]);
}
else
trie[u].son[i] = trie[trie[u].fail].son[i];
}
}
}
bool check(){
memset(f,-INF,sizeof f);
f[0][0] = 0;
for(int i = 0;i <= num;i++)
val[i] = trie[i].sum - trie[i].cnt * mid;
for(int i = 1;i <= n;i++)
for(int j = 0;j <= num;j++){
if(f[i-1][j] == -INF) continue;
if(t[i] == '.')
for(int k = 0;k < 10;k++){
int v = trie[j].son[k];
if(f[i][v] < f[i-1][j] + val[v]){
f[i][v] = f[i-1][j] + val[v];
pre[i][v] = make_pair(k,j);
}
}
else{
int k = t[i]-'0';
int v = trie[j].son[k];
if(f[i][v] < f[i-1][j] + val[v]){
f[i][v] = f[i-1][j] + val[v];
pre[i][v] = make_pair(k,j);
}
}
}
now = 0;
for(int i = 0;i <= num;i++)
if(f[n][i] > f[n][now]) now = i;
return (f[n][now] > 0);
}
void update(){
for(int i = n;i >= 1;i--){
T[i] = pre[i][now].first;
now = pre[i][now].second;
}
}
int main(){
scanf("%d%d",&n,&m);
scanf("%s",t+1);
for(int i = 1;i <= m;i++){
scanf("%s%lf",s,&w[i]);
w[i] = log(w[i]);
insert(s,i);
}
getf();
l = 0, r = log(INF);
while(r-l > eps){
mid = (l+r)*0.5;
if(check()) {
l = mid;
update();
}
else r = mid;
}
for(int i = 1; i <= n;i++)
printf("%c",T[i]+'0');
printf("\n");
return 0;
}