[题解] AC自动机

AC自动机用于多模式串与文本串匹配,方法是在当前点失配后跳转到最优的下一个匹配位置,即当前匹配位置的最长后缀,避免信息不必要的多次访问

TJOI 2013 单词

把单词拼起来,两两间隔一个非字母字符,然后跑放AC自动机上跑

在匹配时将匹配位置num++,在匹配结束后将所有节点的num传到其fail一直到根节点

代码
#include<bits/stdc++.h>
using namespace std;
const int N=2000011;
struct trie{
	bool jd;
	char x;
	int son[28];
	int num,fa;
}tre[N];
char tx[N],ms[N];
char ch1[2]="{";
int n;
int ans;
int lenx;
int tot=1,root=1;
int sum[N];
int jud[N];
int fail[N];
queue<int> dui;
inline int read()
{
	int s=0;
	char ch=getchar();
	while(ch>'9'||ch<'0')
		ch=getchar();
	while(ch>='0'&&ch<='9')
	{
		s=(s<<1)+(s<<3)+(ch^48);
		ch=getchar();
	}
	return s;
}
void insert(int x)
{
	int u=root;
	for(int i=0;i<lenx;i++)
	{
		if(tre[u].son[tx[i]-'a'+1])
			u=tre[u].son[tx[i]-'a'+1];
		else
		{
			tre[u].son[tx[i]-'a'+1]=++tot;
			tre[tot].fa=u;
			tre[tot].x=tx[i];
			u=tot;
		}
	}
	tre[u].jd=1;
	sum[x]=u;
	return;
}
void make_fail()
{
	int u=root;
	for(int i=1;i<=27;i++)
	{
		if(tre[u].son[i])
		{
			fail[tre[u].son[i]]=root;
			dui.push(tre[u].son[i]);
		}
		else
			tre[u].son[i]=root;
	}
	while(dui.size())
	{
		int x=dui.front();
		dui.pop();
		for(int i=1;i<=27;i++)
		{
			if(tre[x].son[i])
			{
				fail[tre[x].son[i]]=tre[fail[x]].son[i];
				dui.push(tre[x].son[i]);
			}
			else
				tre[x].son[i]=tre[fail[x]].son[i];
		}
	}
	return;
}
void query()
{
	int u=root;
	for(int i=0;i<lenx;i++)
	{
		int k=tre[u].son[ms[i]-'a'+1];
		jud[k]++;
		u=tre[u].son[ms[i]-'a'+1];
	}
	for(int i=2;i<=tot;i++)
	{
		if(!jud[i])
			continue;
		int k=fail[i];
		if(tre[i].jd)
			tre[i].num+=jud[i];
		while(k>1)
		{
			if(tre[k].jd)
				tre[k].num+=jud[i];
			k=fail[k];
		}
	}
	return;
}
void init()
{
	for(int i=1;i<=tot;i++)
	{
		fail[i]=0;
		jud[i]=0;
		for(int j=1;j<=27;j++)
			tre[i].son[j]=0;
		tre[i].fa=0;
		tre[i].x='\0';
		tre[i].jd=0;
		tre[i].num=0;
	}
	tot=1;
	return;
}
int main()
{
	n=read();
	ans=0;
	for(int i=1;i<=n;i++)
	{
		scanf("%s",tx);
		lenx=strlen(tx);
		strcat(ms,tx);
		strcat(ms,ch1);
		insert(i);
	}
	lenx=strlen(ms);
	make_fail();
	query();
	for(int i=1;i<=n;i++)
		printf("%d\n",tre[sum[i]].num);
	return 0;
}

\(\\\)
\(\\\)
\(\\\)

POI 2000 病毒

把所有病毒串放AC自动机上,然后在每个节点记录其是否为病毒串,每个节点不断用它的fail更新自己

然后在trie图上找环,满足环上没有病毒串标记

代码
#include<bits/stdc++.h>
using namespace std;
const int N=100011;
struct trie{
	int jud;
	bool liv;
	int son[2];
}tre[N];
char tx[N];
int n;
int lenx;
int root=1;
int tot=1;
int fail[N];
int sl;
queue<int> dui;
inline int read()
{
	int s=0;
	char ch=getchar();
	while(ch>'9'||ch<'0')
		ch=getchar();
	while(ch>='0'&&ch<='9')
	{
		s=(s<<1)+(s<<3)+(ch^48);
		ch=getchar();
	}
	return s;
}
void insert(int x)
{
	int u=root;
	for(int i=1;i<=lenx;i++)
	{
		if(tre[u].son[tx[i]-'0'])
			u=tre[u].son[tx[i]-'0'];
		else
		{
			tre[u].son[tx[i]-'0']=++tot;
			u=tot;
		}
	}
	tre[u].liv=1;
	return;
}
void make_fail()
{
	int u=root;
	fail[u]=u;
	for(int i=0;i<=1;i++)
	{
		if(tre[u].son[i])
		{
			fail[tre[u].son[i]]=u;
			dui.push(tre[u].son[i]);
		}
		else
			tre[u].son[i]=u;
	}
	while(dui.size())
	{
		int x=dui.front();
		dui.pop();
		for(int i=0;i<=1;i++)
		{
			if(tre[x].son[i])
			{
				fail[tre[x].son[i]]=tre[fail[x]].son[i];
				dui.push(tre[x].son[i]);
			}
			else
				tre[x].son[i]=tre[fail[x]].son[i];
		}
	}
	return;
}
void solve()
{
	for(int i=2;i<=tot;i++)
	{
		int x=fail[i];
		if(tre[i].liv)
			continue;
		while(x>1)
		{
			tre[i].liv|=tre[x].liv;
			x=fail[x];
		}
	}
	return;
}
bool judge(int x)
{
	if(tre[tre[x].son[0]].liv&&tre[tre[x].son[1]].liv)
		return 0;
	bool jd=0;
	tre[x].jud=1;
	if(!tre[tre[x].son[1]].liv)
	{
		if(tre[tre[x].son[1]].jud)
			return 1;
		else
			jd|=judge(tre[x].son[1]);
	}
	if(!tre[tre[x].son[0]].liv)
	{
		if(tre[tre[x].son[0]].jud)
			return 1;
		else
			jd|=judge(tre[x].son[0]);
	}
	if(!jd)
	{
		tre[x].jud=0;
		tre[x].liv=1;
	}
	return jd;
}
int main()
{
	n=read();
	for(int i=1;i<=n;i++)
	{
		cin>>tx+1;
		lenx=strlen(tx+1);
		insert(i);
	}
	make_fail();
	solve();
	if(judge(1))
		printf("TAK\n");
	else
		printf("NIE\n");
	return 0;
}

\(\\\)
\(\\\)
\(\\\)

\(\\\)
\(\\\)
\(\\\)

HNOI2006 最短母串

\(f[i][j][k]\)表示当前选的集合为k,最左端字符串为i,最右端字符串为j的最优字符串

转移的话就考虑向左端加一个串,右端加一个串,然后计算出转移后的长度和字符串取最小值就行了

其实不用往左右加,直接向右加就能包括所有情况,做的时候没想到

需要注意的是,左右端的串不能是新加入的串的字串,也就是说,预处理把所有具有包含关系的字符串留下最大的那个就行了

状压储存节点覆盖关系,跳fail匹配时可以直接判断某个x的后缀是否同时是y的前缀

代码
#include<bits/stdc++.h>
using namespace std;
const int N=611;
struct trie{
	int jds[27];
	int son[27];
	int fm;
	int sta;
}tre[N];
string f[13][13][1<<13];
struct hf{
	int i,j;
}a;
bool jud[N];
string ch[14];
int n;
int num;
int tot=1,root=1;
int len[N];
int dep[N];
int leg[N];
int fail[N];
int loc[N];
vector<hf> vct[1<<13];
queue<int> dui;
inline int read()
{
	int s=0;
	char ch=getchar();
	while(ch>'9'||ch<'0')
		ch=getchar();
	while(ch>='0'&&ch<='9')
	{
		s=(s<<1)+(s<<3)+(ch^48);
		ch=getchar();
	}
	return s;
}
void insert(int x)
{
	int u=root;
	for(int i=0;i<len[x];i++)
	{
		if(tre[u].son[ch[x][i]-'A'+1])
			u=tre[u].son[ch[x][i]-'A'+1];
		else
		{
			tre[u].son[ch[x][i]-'A'+1]=++tot;
			tre[u].jds[ch[x][i]-'A'+1]=1;
			u=tot;
		}
	}
	tre[u].fm=x;
	loc[x]=u;
	return;
}
void make_fail()
{
	int u=root;
	for(int i=1;i<=26;i++)
	{
		if(tre[u].son[i])
		{
			fail[tre[u].son[i]]=root;
			dui.push(tre[u].son[i]);
		}
		else
			tre[u].son[i]=root;
	}
	while(dui.size())
	{
		int x=dui.front();
		dui.pop();
		for(int i=1;i<=26;i++)
		{
			if(tre[x].son[i])
			{
				fail[tre[x].son[i]]=tre[fail[x]].son[i];
				dui.push(tre[x].son[i]);
			}
			else
				tre[x].son[i]=tre[fail[x]].son[i];
		}
	}
	return;
}
int mch(int x,int y)
{
	x=leg[x];
	y=leg[y];
	int k=loc[x];
	int lenx=0;
	while(k>1)
	{
		if(dep[k]<=ch[y].length()&&(tre[k].sta&(1<<y-1)))
		{
			lenx=dep[k];
			break;
		}
		k=fail[k];
	}
	return lenx;
}
void dfs(int x)
{
	if(tre[x].fm)
		tre[x].sta|=1<<tre[x].fm-1;
	for(int i=1;i<=26;i++)
		if(tre[x].jds[i])
		{
			dep[tre[x].son[i]]=dep[x]+1;
			dfs(tre[x].son[i]);
			if(tre[tre[x].son[i]].fm)
			{
				tre[x].fm=tre[tre[x].son[i]].fm;
				tre[x].sta|=tre[tre[x].son[i]].sta;
			}
		}
	jud[tre[x].fm]=1;
	return;
}
int main()
{
	n=read();
	for(int i=1;i<=n;i++)
	{
		cin>>ch[i];
		len[i]=ch[i].length();
		insert(i);
	}
	make_fail();
	dfs(1);
	for(int i=1;i<=n;i++)
		if(jud[i])
		{
			leg[++num]=i;
			f[num][num][1<<num-1]=ch[i];
		}
	int M=1<<num;
	for(int i=1;i<M;i++)
		for(int j=1;j<=num;j++)
			for(int k=1;k<=num;k++)
				if(((1<<j-1)&i)&&((1<<k-1)&i))
				{
					a.i=j;
					a.j=k;
					vct[i].push_back(a);
				}
	for(int k=1;k<M;k++)
		for(int v=0;v<vct[k].size();v++)
		{
			int i=vct[k][v].i;
			int j=vct[k][v].j;
			if(f[i][j][k].length())
				for(int h=1;h<=num;h++)
					if(!((1<<h-1)&k))
					{
						int len1=len[leg[h]]-mch(h,i);
						int len2=len[leg[h]]-mch(j,h);
						string ch1=ch[leg[h]].substr(0,len1);
						string ch2=ch[leg[h]].substr(len[leg[h]]-len2,len2);
						int sta=(1<<h-1)|k;
						if((f[h][j][sta].length()>f[i][j][k].length()+len1)||!f[h][j][sta].length()||(f[h][j][sta].length()==f[i][j][k].length()+len1&&(f[i][j][k]+ch1<f[h][j][sta])))
							f[h][j][sta]=ch1+f[i][j][k];
						if((f[i][h][sta].length()>f[i][j][k].length()+len2)||!f[i][h][sta].length()||(f[i][h][sta].length()==f[i][j][k].length()+len2&&(f[i][j][k]+ch2<f[i][h][sta])))
							f[i][h][sta]=f[i][j][k]+ch2;
					}
		}
	int minlen=0x7fffffff;
	string ch1="\0";
	for(int i=1;i<=n;i++)
		for(int j=1;j<=n;j++)
			if(f[i][j][M-1].length()!=0)
			{
				if(f[i][j][M-1].length()<minlen)
				{
					minlen=f[i][j][M-1].length();
					ch1=f[i][j][M-1];
				}
				else if(f[i][j][M-1].length()==minlen&&f[i][j][M-1]<ch1)
					ch1=f[i][j][M-1];
			}
	cout<<ch1<<endl;
	return 0;
}

\(\\\)
\(\\\)
\(\\\)

JSOI 2007文本生成器

f[i][j]表示文章生成到第i位,trie图上节点编号为j的非法方案数,g[i]表示文章生成到第i位的合法方案数

转移的话,新加入一个字符等于在trie上找儿子

如果儿子是合法的,\(g[i+1]+=f[i][j]\)

如果儿子非法,\(f[i+1][son[j]]+=f[i][j]\)

g的转移\(g[i+1]+=g[i]*26\)

答案就是g[m]

需要注意的是,要用一个串的子串更新自己合法,直接跳fail更新就行了

代码
#include<bits/stdc++.h>
using namespace std;
const int N=111;
const int mod=1e4+7;
struct trie{
	bool liv;
	int son[27];
}tre[60*N];
char ch[65][N];
int n;
int root=1,tot=1;
int g[N];
int fail[60*N];
int f[N][60*N];
int lenx,len[N];
queue<int> dui;
inline int read()
{
	int s=0;
	char ch=getchar();
	while(ch>'9'||ch<'0')
		ch=getchar();
	while(ch>='0'&&ch<='9')
	{
		s=(s<<1)+(s<<3)+(ch^48);
		ch=getchar();
	}
	return s;
}
void insert(int x)
{
	int u=root;
	for(int i=1;i<=len[x];i++)
	{
		if(tre[u].son[ch[x][i]-'A'+1])
			u=tre[u].son[ch[x][i]-'A'+1];
		else
		{
			tre[u].son[ch[x][i]-'A'+1]=++tot;
			u=tot;
		}
	}
	tre[u].liv=1;
	return;
}
void make_fail()
{
	int u=root;
	for(int i=1;i<=26;i++)
	{
		if(tre[u].son[i])
		{
			fail[tre[u].son[i]]=root;
			dui.push(tre[u].son[i]);
		}
		else
			tre[u].son[i]=root;
	}
	while(!dui.empty())
	{
		int x=dui.front();
		dui.pop();
		for(int i=1;i<=26;i++)
		{
			if(tre[x].son[i])
			{
				fail[tre[x].son[i]]=tre[fail[x]].son[i];
				dui.push(tre[x].son[i]);
			}
			else
				tre[x].son[i]=tre[fail[x]].son[i];
		}
	}
	return;
}
void update()
{
	for(int i=2;i<=tot;i++)
	{
		if(tre[i].liv)
			continue;
		int k=fail[i];
		while(k>1)
		{
			tre[i].liv|=tre[k].liv;
			if(tre[i].liv)
				break;
			k=fail[k];
		}
	}
	return;
}
int main()
{
	n=read();
	lenx=read();
	for(int i=1;i<=n;i++)
	{
		cin>>ch[i]+1;
		len[i]=strlen(ch[i]+1);
		insert(i);
	}
	make_fail();
	update();
	f[0][1]=1;
	for(int i=0;i<=lenx-1;i++)
	{
		(g[i+1]+=g[i]*26%mod)%=mod;
		for(int j=1;j<=tot;j++)
			if(f[i][j])
				for(int k=1;k<=26;k++)
				{
					if(tre[tre[j].son[k]].liv)
						(g[i+1]+=f[i][j])%=mod;
					else
						(f[i+1][tre[j].son[k]]+=f[i][j])%=mod;
				}
	}
	cout<<g[lenx]<<endl;
	return 0;
}

\(\\\)
\(\\\)
\(\\\)

bzoj2905背单词

剪枝大法好!!!

设f[i]表示选到的单词在trie上的点为i,强制选这个单词的最大收益

\[f[i]=max(f[i的祖先的所有fail])+w[i] \]

因为一直跳父亲的fail会有重复访问,标记一下访问过谁,碰见有标记的就break(比正解还快

这个方法的复杂度是\(O(n(每个字符串的子串数量))\),挺悬乎的

正解是线段树维护fail树的dfs序优化dp

每次更新这个字符串在fail树上的子树,更新自己时在trie上跳父亲query取max

复杂度\(O(len \log len)\)

代码 朴素dp加剪枝优化
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int N=1e6+11;
struct trie{
	int son[27];
}tre[2*N];
bool jud[N];
char tx[N];
int n,lenx,num;
int fail[N],fa[N];
int root=1,tot=1;
int w[N];
int f[N];
int now[N];
int sum[N];
queue<int> dui;
inline int read()
{
	int s=0,w=1;
	char ch=getchar();
	while(ch>'9'||ch<'0')
	{
		if(ch=='-')
			w=-1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9')
	{
		s=(s<<1)+(s<<3)+(ch^48);
		ch=getchar();
	}
	return s*w;
}
void insert(int x,int w)
{
	if(w<=0)
		return;
	int u=root;
	for(int i=1;i<=lenx;i++)
	{
		if(tre[u].son[tx[i]-'a'+1])
			u=tre[u].son[tx[i]-'a'+1];
		else
		{
			tre[u].son[tx[i]-'a'+1]=++tot;
			fa[tot]=u;
			u=tot;
		}
	}
	now[x]=u;
	return;
}
void make_fail()
{
	int u=root;
	for(int i=1;i<=26;i++)
	{
		if(tre[u].son[i])
		{
			fail[tre[u].son[i]]=u;
			dui.push(tre[u].son[i]);
		}
		else
			tre[u].son[i]=u;
	}
	while(dui.size())
	{
		int x=dui.front();
		dui.pop();
		for(int i=1;i<=26;i++)
		{
			if(tre[x].son[i])
			{
				fail[tre[x].son[i]]=tre[fail[x]].son[i];
				dui.push(tre[x].son[i]);
			}
			else
				tre[x].son[i]=tre[fail[x]].son[i];
		}
	}
	return;
}
void init()
{
	for(int i=1;i<=tot;i++)
	{
		memset(tre[i].son,0,sizeof(tre[i].son));
		fail[i]=0;
		fa[i]=0;
		f[i]=0;
		w[i]=0;
		now[i]=0;
	}
	tot=1;
	return;
}
signed main()
{
	int t=read();
	while(t--)
	{
		init();
		n=read();
		for(int i=1;i<=n;i++)
		{
			scanf("%s",tx+1);
			lenx=strlen(tx+1);
			insert(i,w[i]=read());
		}
		make_fail();
		int fmax=-0x7fffffffff;
		for(int i=1;i<=n;i++)
		{
			fmax=max(fmax,w[i]);
			if(!now[i])
				continue;
			for(int j=1;j<=num;j++)
			{
				jud[sum[j]]=0;
				sum[j]=0;
			}
			num=0;
			int maxx=0;
			for(int j=now[i];j>1;j=fa[j])
				for(int k=j;k>1;k=fail[k])
				{
					if(jud[k])
						break;
					jud[sum[++num]=k]=1;
					maxx=max(maxx,f[k]);
				}
			f[now[i]]=w[i]+maxx;
		}
		if(fmax<0)
		{
			cout<<fmax<<endl;
			continue;
		}
		int maxx=0;
		for(int i=1;i<=tot;i++)
			maxx=max(maxx,f[i]);
		cout<<maxx<<endl;
	}
	return 0;
}
代码 线段树优化dp
#include<bits/stdc++.h>
using namespace std;
const int N=3e5+11;
struct tree{
	int l,r;
	int sum;
	int lazy;
}tre_[5*N];
struct trie{
	int son[27];
}tre[N];
struct qxxx{
	int v,next;
}cc[2*N];
char tx[N];
int lenx,n,st;
int tot=1,root=1;
int w[N];
int fa[N];
int now[N];
int dep[N];
int size[N];
int fail[N];
int dfn[N],rle[N];
int first[N],cnt;
queue<int> dui;
inline int read()
{
	int s=0,w=1;
	char ch=getchar();
	while(ch>'9'||ch<'0')
	{
		if(ch=='-')
			w=-1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9')
	{
		s=(s<<1)+(s<<3)+(ch^48);
		ch=getchar();
	}
	return s*w;
}
inline int max_(int a,int b)
{
	return a<b?b:a;
}
void insert(int x,int w)
{
	if(w<=0)
		return;
	int u=root;
	for(int i=1;i<=lenx;i++)
	{
		if(tre[u].son[tx[i]-'a'+1])
			u=tre[u].son[tx[i]-'a'+1];
		else
		{
			tre[u].son[tx[i]-'a'+1]=++tot;
			fa[tot]=u;
			u=tot;
		}
	}
	now[x]=u;
	return;
}
void qxx(int u,int v)
{
	cc[++cnt].v=v;
	cc[cnt].next=first[u];
	first[u]=cnt;
	return;
}
void make_fail()
{
	int u=root;
	fail[u]=u;
	for(int i=1;i<=26;i++)
	{
		if(tre[u].son[i])
		{
			fail[tre[u].son[i]]=u;
			qxx(tre[u].son[i],u);
			qxx(u,tre[u].son[i]);
			dui.push(tre[u].son[i]);
		}
		else
			tre[u].son[i]=u;
	}
	while(dui.size())
	{
		int x=dui.front();
		dui.pop();
		for(int i=1;i<=26;i++)
		{
			if(tre[x].son[i])
			{
				fail[tre[x].son[i]]=tre[fail[x]].son[i];
				qxx(tre[x].son[i],tre[fail[x]].son[i]);
				qxx(tre[fail[x]].son[i],tre[x].son[i]);
				dui.push(tre[x].son[i]);
			}
			else
				tre[x].son[i]=tre[fail[x]].son[i];
		}
	}
	return;
}
void dfs(int x)
{
	size[x]=1;
	dfn[x]=++st;
	rle[st]=x;
	for(int i=first[x];i;i=cc[i].next)
		if(!dfn[cc[i].v])
		{
			dep[cc[i].v]=dep[x]+1;
			dfs(cc[i].v);
			size[x]+=size[cc[i].v];
		}
	return;
}
void build(int i,int l,int r)
{
	tre_[i].l=l;
	tre_[i].r=r;
	tre_[i].lazy=tre_[i].sum=0;
	if(l==r)
		return;
	int mid=(l+r)>>1;
	build(i<<1,l,mid);
	build(i<<1|1,mid+1,r);
	return;
}
void pushdown(int i)
{
	if(tre_[i].l==tre_[i].r)
	{
		tre_[i].sum=max_(tre_[i].sum,tre_[i].lazy);
		return;	
	}
	tre_[i<<1].lazy=max_(tre_[i<<1].lazy,tre_[i].lazy);
	tre_[i<<1|1].lazy=max_(tre_[i<<1|1].lazy,tre_[i].lazy);
	return;
}
int query(int i,int x)
{
	if(tre_[i].lazy)
		pushdown(i);
	if(tre_[i].l==tre_[i].r)
		return tre_[i].sum;
	int mid=(tre_[i].l+tre_[i].r)>>1;
	if(mid>=x)
		return query(i<<1,x);
	else
		return query(i<<1|1,x);
	return 0;
}
void insert(int i,int l,int r,int sum)
{
	if(tre_[i].lazy)
		pushdown(i);
	if(tre_[i].l>=l&&tre_[i].r<=r)
	{
		tre_[i].lazy=sum;
		return;
	}
	int mid=(tre_[i].l+tre_[i].r)>>1;
	if(mid>=l)
		insert(i<<1,l,r,sum);
	if(mid<r)
		insert(i<<1|1,l,r,sum);
	return;
}
int query_max(int i)
{
	if(tre_[i].lazy)
		pushdown(i);
	if(tre_[i].l==tre_[i].r)
		return tre_[i].sum;
	return max_(query_max(i<<1),query_max(i<<1|1));
}
void init()
{
	for(int i=1;i<=tot;i++)
	{
		fail[i]=0;
		w[i]=0;
		fa[i]=0;
		now[i]=0;
		fa[i]=0;
		dep[i]=0;
		first[i]=0;
		size[i]=0;
		dfn[i]=0;
		rle[i]=0;
		cc[i].next=cc[i].v=0;
		memset(tre[i].son,0,sizeof(tre[i].son));
	}
	st=0;
	cnt=0;
	tot=1;
	return;
}
signed main()
{
	int t=read();
	while(t--)
	{
		init();
		n=read();
		for(int i=1;i<=n;i++)
		{
			scanf("%s",tx+1);
			lenx=strlen(tx+1);
			insert(i,w[i]=read());
		}
		make_fail();
		dep[1]=1;
		dfs(1);
		build(1,1,tot);
		for(int i=1;i<=n;i++)
		{
			if(!now[i])
				continue;
			int maxx=0;
			for(int j=now[i];j>1;j=fa[j])
				maxx=max(maxx,query(1,dfn[j]));
			insert(1,dfn[now[i]],dfn[now[i]]+size[now[i]]-1,maxx+w[i]);
		}
		cout<<query_max(1)<<endl;
	}
	return 0;
}

\(\\\)
\(\\\)
\(\\\)

JSOI2009密码

设f[i][j][s]表示填到i位,在trie上节点为j,已包含的密码集合为s的方案数

\[f[i+1][tre[j][p]][s']+=f[i][j][s] \]

答案就是\(\sum^{tot}_{i=1}f[len][i][全集]\)

复杂度\(O(26*tot*len*2^{n})\)

至于输出方案,用vector记录过程中的字符串不行,记前趋也不行。因为没用的状态太多,不仅浪费空间,vector申请内存还贼慢

一个方法是暴搜,模拟刚才的dp,但是从最终状态往前搜,判断一下搜到的状态是否能到达当前状态,能的话更新一下字符串接着往前搜,复杂度不会证

还有一个办法,考虑这个方案数为何是42;当只有1个字符串,且可以随便插一个字母时,方案数是26*2>42,也就是说,字符串肯定是一个个拼起来的,也是搜,这个复杂度是\(O(n!)\),可以通过

代码
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int N=1511;
struct trie{
	int son[27];
	int liv;
}tre[N];
char tx[N];
int lenx;
int n,l,num;
int root=1,tot=1;
int f[28][N][N];
int fail[N];
string st[N];
queue<int> dui;
inline int read()
{
	int s=0;
	char ch=getchar();
	while(ch>'9'||ch<'0')
		ch=getchar();
	while(ch>='0'&&ch<='9')
	{
		s=(s<<1)+(s<<3)+(ch^48);
		ch=getchar();
	}
	return s;
}
void insert(int x)
{
	int u=1;
	for(int i=1;i<=lenx;i++)
	{
		if(tre[u].son[tx[i]-'a'+1])
			u=tre[u].son[tx[i]-'a'+1];
		else
		{
			tre[u].son[tx[i]-'a'+1]=++tot;
			u=tot;
		}
	}
	tre[u].liv|=1<<x-1;
	return;
}
void make_fail()
{
	int u=root;
	for(int i=1;i<=26;i++)
	{
		if(tre[u].son[i])
		{
			fail[tre[u].son[i]]=u;
			dui.push(tre[u].son[i]);
		}
		else
			tre[u].son[i]=u;
	}
	while(dui.size())
	{
		u=dui.front();
		dui.pop();
		for(int i=1;i<=26;i++)
		{
			if(tre[u].son[i])
			{
				fail[tre[u].son[i]]=tre[fail[u]].son[i];
				dui.push(tre[u].son[i]);
			}
			else
				tre[u].son[i]=tre[fail[u]].son[i];
		}
	}
	return;
}
void get_leg()
{
	for(int i=2;i<=tot;i++)
	{
		int k=fail[i];
		while(k>1)
		{
			tre[i].liv|=tre[k].liv;
			k=fail[k];
		}
	}
	return;
}
void dfs(int i,int j,int k,string a)
{
	if(!i)
	{
		st[++num]=a;
		return;
	}
	for(int h=1;h<=tot;h++)
		for(int y=0;y<(1<<n);y++)
			for(int x=1;x<=26;x++)
				if(tre[h].son[x]==j&&((y|(tre[j].liv))==k)&&f[i-1][h][y])
				{
					char ch=x+'a'-1;
					dfs(i-1,h,y,ch+a);
				}
	return;
}
signed main()
{
	l=read();
	n=read();
	for(int i=1;i<=n;i++)
	{
		cin>>tx+1;
		lenx=strlen(tx+1);
		insert(i);
	}
	make_fail();
	get_leg();
	int M=1<<n;
	f[0][1][0]=1;
	for(int i=0;i<l;i++)
		for(int j=1;j<=tot;j++)
			for(int k=0;k<M;k++)
				for(int h=1;h<=26;h++)
					f[i+1][tre[j].son[h]][k|tre[tre[j].son[h]].liv]+=f[i][j][k];
	int ans=0;
	for(int i=1;i<=tot;i++)
		ans+=f[l][i][M-1];
	cout<<ans<<endl;
	if(ans<=42)
	{
		for(int i=1;i<=tot;i++)
			if(f[l][i][M-1])
				dfs(l,i,M-1,"");
		sort(st+1,st+num+1);
		for(int i=1;i<=num;i++)
			cout<<st[i]<<endl;
	}
	return 0;
}

\(\\\)
\(\\\)
\(\\\)

BJWC2011禁忌

这个跟之前的区别的是,把串断开,每个字母只属于一个串

解决方法也很简单,把禁忌串的所有儿子指向1的所有儿子

设f[i][j]表示当前串长为i,在trie上节点为j的伤害,只有它不能转移,原因是每次有若干串增加伤害,需要知道串的个数

所以设siz[i][j]表示在这个时候的串个数,转移

\(f[i+1][tre[j][p]]+=f[i][j]+siz[i][j],tre[j][p]\)是串的末尾
\(f[i+1][tre[j][p]]+=f[i][j],tre[j][p]\)不是串的末尾
\(siz[i+1][tre[j][p]]+=siz[i][j]\)

然后矩阵快速幂,填系数时需要注意同一个位置可能会填多次,不能只赋为1

代码
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define double long double
const int N=111;
double sl;
struct trie{
	int son[27];
	bool jds[27];
	bool liv;
}tre[N];
struct mat_{
	int h,l;
	double mat[151][151];
	friend mat_ operator*(mat_ a,mat_ b)
	{
		mat_ c;
		memset(&c,0,sizeof(c));
		c.h=a.h;
		c.l=b.l;
		for(int i=1;i<=a.h;i++)
			for(int j=1;j<=b.l;j++)
			{
				for(int k=1;k<=a.l;k++)
					c.mat[i][j]+=a.mat[i][k]*b.mat[k][j];
				c.mat[i][j]/=sl;
			}
		return c;
	}
}mt,hs;
char tx[N];
int n,lenx,len;
int root=1,tot=1;
int fail[N];
queue<int> dui;
inline int read()
{
	int s=0;
	char ch=getchar();
	while(ch>'9'||ch<'0')
		ch=getchar();
	while(ch>='0'&&ch<='9')
	{
		s=(s<<1)+(s<<3)+(ch^48);
		ch=getchar();
	}
	return s;
}
void insert()
{
	int u=root;
	for(int i=1;i<=lenx;i++)
	{
		if(tre[u].son[tx[i]-'a'+1])
		{
			u=tre[u].son[tx[i]-'a'+1];
			if(tre[u].liv)
				return;
		}
		else
		{
			tre[u].son[tx[i]-'a'+1]=++tot;
			tre[u].jds[tx[i]-'a'+1]=1;
			u=tot;
		}
	}
	tre[u].liv=1;
	return;
}
void make_fail()
{
	int u=root;
	for(int i=1;i<=sl;i++)
	{
		if(tre[u].son[i])
		{
			fail[tre[u].son[i]]=u;
			dui.push(tre[u].son[i]);
		}
		else
			tre[u].son[i]=u;
	}
	while(dui.size())
	{
		int x=dui.front();
		dui.pop();
		for(int i=1;i<=sl;i++)
		{
			if(tre[x].son[i])
			{
				fail[tre[x].son[i]]=tre[fail[x]].son[i];
				tre[tre[x].son[i]].liv|=tre[tre[fail[x]].son[i]].liv;
				dui.push(tre[x].son[i]);
			}
			else
				tre[x].son[i]=tre[fail[x]].son[i];
		}
	}
	return;
}
void init()
{
	mt.h=2*tot;
	mt.l=2*tot;
	hs.h=2*tot;
	hs.l=1;
	hs.mat[1][1]=1;
	for(int i=1;i<=tot;i++)
	{
		for(int j=1;j<=sl;j++)
		{
			mt.mat[tre[i].son[j]][i]+=1.0;
			mt.mat[tot+tre[i].son[j]][i+tot]+=1.0;
			mt.mat[tot+tre[i].son[j]][i]+=tre[tre[i].son[j]].liv;
		}
	}
	return;
}
void fma()
{
	while(len)
	{
		if(len&1)
			hs=mt*hs;
		mt=mt*mt;
		len>>=1;
	}
	return;
}
signed main()
{
	n=read();
	len=read();
	sl=read();
	for(int i=1;i<=n;i++)
	{
		cin>>tx+1;
		lenx=strlen(tx+1);
		insert();
	}
	make_fail();
	for(int j=2;j<=tot;j++)
	{
		if(!tre[j].liv)
			continue;
		for(int i=1;i<=sl;i++)
			tre[j].son[i]=tre[1].son[i];
	}
	init();
	fma();
	double ans=0;
	for(int i=1;i<=tot;i++)
		ans+=hs.mat[i+tot][1];
	printf("%.17Lf",ans);
	return 0;
}
posted @ 2021-08-20 21:12  sitiy  阅读(18)  评论(0编辑  收藏  举报