插头dp初探

问题描述

插头dp用于解决一类可基于图连通性递推的问题。用插头来表示轮廓线上的连通性,然后根据连通性与下一位结合讨论进行转移。

表示连通性的方法

<最小表示法> 与字符串循环最小表示不同,这种方法用于给轮廓线上的联通情况确定一个唯一对应的标号序列,做法是从左至右轮廓线扫描,每扫描到一个未标号的位置就新建一个标号,并将轮廓线以后与这一位联通的位置都标上此号,不被包含的点标号为0。举例本质相同的连通性\((3,3,2,1,3)\)\((2,2,3,1,2)\)都会被标记为\(1,1,2,3,1\)

<括号表示法> 用于解决路径(、回路)相关的连通性。做法是将轮廓线上方的回路链接到轮廓向上插头区别为左插头与右插头,逐格转移时讨论格子上边有左边的插头;求解任意路径问题时保留左右插头但不合并,并且引入“独立插头”表示只有一端链接到轮廓线路径的链接端;

<其它> 引入一些插头,然后直接转换为进制数表示的不清楚怎么分类的方法。

具体操作

\(f[x,s]\)为考虑到位置\(x\) 轮廓线状态为\(s\)的解。转移是个费脑子的事情,按下不表。连通性转换为进制数时选择\(2^t\)作为进制数可以更快速的取出、修改轮廓线上某一位的值,但时需要把所有的状态扔进hash表里。

练习题 (7/7)

luogu5056 【模板】插头dp

我就不造轮子了 讲解可以参考ladylex 的例题2(虽然不是一道题,但分类讨论差不多的而且有图解)

#include <bits/stdc++.h>
#define LL long long

const int mod=299987;

int n,m,endx,endy;
LL ans;
char a[20][20];

struct hash_set {
	LL val[mod];
	int siz,key[mod],hsh[mod];
	void clear() {
		memset(val,0,sizeof val);
		memset(key,-1,sizeof key);
		memset(hsh,0,sizeof hsh);
		siz=0;
	}
	void newhsh(int id,int vl) {
		hsh[id]=++siz,key[siz]=vl;
	}
	LL&operator[](const int &sta) {
		for(int i=sta%mod; ;i=(i+1==mod?0:i+1)) {
			if(!hsh[i]) newhsh(i,sta);
			if(key[hsh[i]]==sta) return val[hsh[i]];
		}
	}
} f[2];

int find(int sta,int id) {
	return (sta>>((id-1)<<1))&3;
}
void set(int &sta,int bit,int val) {
	bit=(bit-1)<<1;
	sta|=3<<bit;
	sta^=3<<bit;
	sta|=val<<bit; 
}
int link(int sta,int pos) {
	int cnt=0,dlt=(find(sta,pos)==1?1:-1);
	for(int i=pos; i&&i<=m+1; i+=dlt) {
		int plg=find(sta,i);
		if(plg==1) cnt++;
		else if(plg==2) cnt--;
		if(!cnt) return i;
	}
	return -1;
}
void p_dp(int x,int y) {
	int now=((x-1)*m+y)&1,lst=now^1,tot=f[lst].siz;
	f[now].clear();
	for(int i=1; i<=tot; ++i) {
		int sta=f[lst].key[i];
		LL val=f[lst].val[i];
		if(link(sta,y)==-1||link(sta,y+1)==-1) 
			continue; // 状态不可用 
		int p1=find(sta,y),p2=find(sta,y+1);
		if(a[x][y]!='.') {
			if(!p1&&!p2) f[now][sta]+=val;
		} else if(!p1&&!p2) {
			if(a[x+1][y]=='.'&&a[x][y+1]=='.') {
				set(sta,y,1);
				set(sta,y+1,2);
				f[now][sta]+=val;
			}
		} else if(p1&&!p2) {
			if(a[x+1][y]=='.') f[now][sta]+=val;
			if(a[x][y+1]=='.') {
				set(sta,y,0);
				set(sta,y+1,p1);
				f[now][sta]+=val;
			}			
		} else if(!p1&&p2) {
			if(a[x][y+1]=='.') f[now][sta]+=val;
			if(a[x+1][y]=='.') {
				set(sta,y,p2);
				set(sta,y+1,0);
				f[now][sta]+=val;
			}
		} else if(p1==1&&p2==1) { // '((' ))
			set(sta,link(sta,y+1),1);
			set(sta,y,0);
			set(sta,y+1,0);
			f[now][sta]+=val;
		} else if(p1==1&&p2==2) { // '()'
			if(x==endx&&y==endy) ans+=val;
		} else if(p1==2&&p2==1) { // ')(' => merge
			set(sta,y,0);
			set(sta,y+1,0);
			f[now][sta]+=val;
		} else if(p1==2&&p2==2) { //(( '))'
			set(sta,link(sta,y),2);
			set(sta,y,0);
			set(sta,y+1,0);
			f[now][sta]+=val;
		}
	}
}

int main() {
	scanf("%d%d",&n,&m);
	for(int i=1; i<=n; ++i) {
		scanf("%s",a[i]+1); 
		for(int j=1; j<=m; ++j) {
			if(a[i][j]=='.') endx=i,endy=j;
		}
	}
	f[0].clear();
	f[0][0]=1;
	for(int i=1; i<=n; ++i) {
		for(int j=1; j<=m; ++j) p_dp(i,j);
		if(i!=n) {
			int now=(i*m)&1,tot=f[now].siz;
			for(int j=1; j<=tot; ++j) 
				f[now].key[j]<<=2;
		}
	}
	printf("%lld\n",ans);
	return 0;
}

luogu2289 邮递员

容易发现从\((1,1)\)出发再回到\((1,1)\)且所有点都恰号经过一次的方案数正是途中的曼哈顿回路数目*2(正着走和逆着走),高精度。

#include <bits/stdc++.h>
//using namespace std;

struct cint {
	static const int P=1e9;
	int bit[10];
	cint() { clear();}
	void clear() {
		memset(bit,0,sizeof bit);
	}
	void set(int t) {
		for(clear(); t; bit[++bit[0]]=t%P,t/=P);
	}
	int&operator[](const int &d) {
		return bit[d];
	}
	void print(char ed='\n') {
		printf("%d",bit[bit[0]]);
		for(int i=bit[0]-1; i>0; --i) printf("%09d",bit[i]);
		putchar(ed);
	}
	cint operator+(cint b) {
		cint c;
		c.clear();
		c[0]=std::max(bit[0],b[0])+1;
		for(int i=1; i<=c[0]; ++i) {
			c[i]+=bit[i]+b[i];
			c[i+1]+=c[i]/P;
			c[i]%=P;
		}
		while(!c[c[0]]) c[0]--;
		return c;
	}
	void operator+=(cint b) {
		*this=*this+b;
	}
	void operator=(int x) {
		set(x);
	}
} ans;
struct hash_map {
	static const int P=299987;
	cint val[P];
	int siz,key[P],hsh[P];
	void clear() {
		siz=0;
		memset(val,0,sizeof val);
		memset(key,-1,sizeof key);
		memset(hsh,0,sizeof hsh);
	}
	void new_hsh(int id,int vl) {
		hsh[id]=++siz,key[siz]=vl;
	}
	cint &operator[](const int &s) {
		for(int i=s%P; ; i=(i+1==P?0:i+1)) {
			if(!hsh[i]) new_hsh(i,s);
			if(key[hsh[i]]==s) return val[hsh[i]];
		} 
	}
} f[2];

int n,m;
int find(int sta,int id) {
	return (sta>>((id-1)<<1))&3;
}
void set(int&sta,int bit,int val) {
	bit=(bit-1)<<1;
	sta|=3<<bit;
	sta^=3<<bit;
	sta|=val<<bit;
}
int link(int sta,int pos) {
	int cnt=0,dlt=(find(sta,pos)==1?1:-1);
	for(int i=pos; i&&i<=m+1; i+=dlt) {
		int plg=find(sta,i);
		if(plg==1) cnt++;
		else if(plg==2) cnt--;
		if(!cnt) return i;
	}
	return -1;
}
void p_dp(int x,int y) {
	int now=((x-1)*m+y)&1,lst=now^1;
	f[now].clear();
	for(int i=1; i<=f[lst].siz; ++i) {
		int sta=f[lst].key[i];
		cint val=f[lst].val[i];
		if(link(sta,y)==-1||link(sta,y+1)==-1) continue;
		int p1=find(sta,y),p2=find(sta,y+1);
		if(!p1&&!p2) {
			if(x!=n&&y!=m) {
				set(sta,y,1);
				set(sta,y+1,2);
				f[now][sta]+=val;
			}
		} else if(p1&&!p2) {
			if(x!=n) f[now][sta]+=val;
			if(y!=m) {
				set(sta,y,0);
				set(sta,y+1,p1);
				f[now][sta]+=val;
			} 
		} else if(!p1&&p2) {
			if(y!=m) f[now][sta]+=val;
			if(x!=n) {
				set(sta,y,p2);
				set(sta,y+1,0);
				f[now][sta]+=val;
			} 
		} else if(p1==1&&p2==1) {
			set(sta,link(sta,y+1),1);
			set(sta,y,0);
			set(sta,y+1,0);
			f[now][sta]+=val;
		} else if(p1==1&&p2==2) {
			if(x==n&&y==m) ans+=val;
		} else if(p1==2&&p2==1) {
			set(sta,y,0);
			set(sta,y+1,0);
			f[now][sta]+=val;
		} else {
			set(sta,link(sta,y),2);
			set(sta,y,0);
			set(sta,y+1,0);
			f[now][sta]+=val;
		}
	}
}

int main() {
	scanf("%d%d",&n,&m);
	if(n==1||m==1) {
		puts("1");
		return 0;
	}
	if(n<m) std::swap(n,m);
	f[0].clear();
	f[0][0]=1;
	for(int i=1; i<=n; ++i) {
		for(int j=1; j<=m; ++j) p_dp(i,j);
		if(i!=n) {
			int now=(i*m)&1;
			for(int j=1; j<=f[now].siz; ++j) 
				f[now].key[j]<<=2;
		}
	}
	ans+=ans;
	ans.print();
	return 0;
}

bzoj2310 ParkII

求最大权任意路径,引入了独立插头

#include <bits/stdc++.h>
#define upd(sta,x) f[now][sta]=max(f[now][sta],(x))
using std::max;

struct hash_map {
	static const int P=23333;
	int siz,hsh[P],val[P],key[P]; 
	void clear() {
		siz=0;
		memset(hsh,0,sizeof hsh);
		memset(key,-1,sizeof key);
		memset(val,-0x3f,sizeof val);
	}
	void new_hsh(int id,int sta) {
		hsh[id]=++siz,key[siz]=sta;
	}
	int &operator[](int sta) {
		for(int i=sta%P; ; i=(i+1==P?0:i+1)) {
			if(!hsh[i]) new_hsh(i,sta);
			if(key[hsh[i]]==sta) return val[hsh[i]];
		}
	}
} f[2];

int n,m,ans=-0x3f3f3f3f,a[101][101];
int find(int sta,int id) {
	return (sta>>((id-1)<<1))&3;
}
void set(int &sta,int bit,int val) {
	bit=(bit-1)<<1;
	sta|=3<<bit;
	sta^=3<<bit;
	sta|=val<<bit;
}
int link(int sta,int pos) {
	int cnt=0,dlt=(find(sta,pos)==1?1:-1);
	for(int i=pos; i&&i<=m+1; i+=dlt) {
		int plg=find(sta,i);
		if(plg==1) cnt++;
		else if(plg==2) cnt--;
		if(cnt==0) return i;
	}
	return -1;
}
bool check(int sta) {
	int cnt=0,cnt1=0;
	for(int i=1; i<=m+1; ++i) {
		int plg=find(sta,i);
		if(plg==3) cnt++;
		else if(plg==1) cnt1++;
		else if(plg==2) cnt1--;
		if(cnt>2/*||cnt1<0*/) break;
	} 
	return cnt<=2&&cnt1==0;
} 
void p_dp(int x,int y) {
	int now=((x-1)*m+y)&1,lst=now^1;
	f[now].clear();
	for(int i=1; i<=f[lst].siz; ++i) {
		int sta=f[lst].key[i];
		int val=f[lst].val[i];
		if(!check(sta)||sta>=(1<<((m+1)<<1))) continue;
		int p1=find(sta,y);
		int p2=find(sta,y+1);
		int idl=sta;
		set(idl,y,0);
		set(idl,y+1,0);
		int ept1=idl,ept2=idl;
		if(!p1&&!p2) {
			upd(idl,val); //跳过这个格子 
			if(x<n&&y<m) set(sta,y,1),set(sta,y+1,2),upd(sta,val+a[x][y]); //新建一对括号 
			if(x<n) set(ept1,y,3),upd(ept1,val+a[x][y]); //新建向下的独立插头 
			if(y<m) set(ept2,y+1,3),upd(ept2,val+a[x][y]); //新建向右的独立插头
		} else if(p1&&!p2) {
			if(x<n) upd(sta,val+a[x][y]); //向下扩展p1 
			if(y<m) set(ept1,y+1,p1),upd(ept1,val+a[x][y]); //向右扩展p1 
			if(p1==3) {if(!idl) ans=max(ans,val+a[x][y]);} 
			else set(ept2,link(sta,y),3),upd(ept2,val+a[x][y]); //停止扩展p1,p1的另一头改为独立插头 
		} else if(!p1&&p2) { 
			if(y<m) upd(sta,val+a[x][y]); //向右扩展p2 
			if(x<n) set(ept2,y,p2),upd(ept2,val+a[x][y]); //向下扩展p2
			if(p2==3) {if(!idl) ans=max(ans,val+a[x][y]);} 
			else set(ept1,link(sta,y+1),3),upd(ept1,val+a[x][y]); //停止扩展p2,p2的另一头改为独立插头 
		} 
		else if(p1==1&&p2==1) set(ept1,link(sta,y+1),1),upd(ept1,val+a[x][y]); //'((')) 
		else if(p1==1&&p2==2) continue; //形成回路,不合法 
		else if(p1==2&&p2==1) upd(idl,val+a[x][y]); //(')(') 连接 
		else if(p1==2&&p2==2) set(ept2,link(sta,y),2),upd(ept2,val+a[x][y]); //(('))' 
		else if(p1==3&&p2==3) {if(!idl) ans=max(ans,val+a[x][y]);} 
		else if(p2==3) set(ept1,link(sta,y),3),upd(ept1,val+a[x][y]); //连接
		else if(p1==3) set(ept2,link(sta,y+1),3),upd(ept2,val+a[x][y]); //连接 
	}
}

int main() {
	scanf("%d%d",&n,&m);
	for(int i=1; i<=n; ++i) {
		for(int j=1; j<=m; ++j) {
			scanf("%d",&a[i][j]);
			ans=max(ans,a[i][j]);
		}
	}
	f[0].clear();
	f[0][0]=0;
	for(int i=1; i<=n; ++i) {
		for(int j=1; j<=m; ++j) p_dp(i,j);
		if(i!=n) {
			int now=(i*m)&1;
			for(int j=1; j<=f[now].siz; ++j) 
				f[now].key[j]<<=2;
		} 
	}
	printf("%d\n",ans);
	return 0;
}

bzoj2331 [SCOI2011]地板

轮廓线上的状态0表示无插头,1表示有一个没有拐弯的插头,2表示拐过弯的插头。

#include <bits/stdc++.h>
const int mod=20110520;

struct hash_map {
	static const int P=233333;
	int siz,hsh[P],val[P],key[P]; 
	void clear() {
		siz=0;
		memset(hsh,0,sizeof hsh);
		memset(key,-1,sizeof key);
		memset(val,0,sizeof val);
	}
	void new_hsh(int id,int sta) {
		hsh[id]=++siz,key[siz]=sta;
	}
	int &operator[](int sta) {
		for(int i=sta%P; ; i=(i+1==P?0:i+1)) {
			if(!hsh[i]) new_hsh(i,sta);
			if(key[hsh[i]]==sta) return val[hsh[i]];
		}
	}
} f[2];

int n,m,edx,edy,ans;
char a[102][102];

int find(int sta,int id) {
	return (sta>>((id-1)<<1))&3;
}
void set(int &sta,int bit,int val) {
	bit=(bit-1)<<1;
	sta|=3<<bit;
	sta^=3<<bit;
	sta|=val<<bit;
}
#define upd(val) f[now][sta]=(f[now][sta]+(val))%mod;
void p_dp(int x,int y) {
	int now=((x-1)*m+y)&1,lst=now^1;
	f[now].clear();
	for(int i=1; i<=f[lst].siz; ++i) {
		int sta=f[lst].key[i];
		int val=f[lst].val[i];
		int p1=find(sta,y);
		int p2=find(sta,y+1);
		if(sta>=(1<<((m+1)<<1))) continue;
		if(a[x][y]!='_') {
			if(!p1&&!p2) upd(val);
		} else if(!p1&&!p2) {
			if(a[x+1][y]=='_') set(sta,y,1),set(sta,y+1,0),upd(val);
			if(a[x][y+1]=='_') set(sta,y,0),set(sta,y+1,1),upd(val);
			if(a[x][y+1]=='_'&&a[x+1][y]=='_') set(sta,y,2),set(sta,y+1,2),upd(val); 
		} else if(!p1&&p2) {
			if(p2==1) {
				if(a[x+1][y]=='_') set(sta,y,p2),set(sta,y+1,0),upd(val);
				if(a[x][y+1]=='_') set(sta,y,0),set(sta,y+1,2),upd(val);
			} else {
				set(sta,y,0),set(sta,y+1,0),upd(val);
				if(x==edx&&y==edy&&!sta) ans=(ans+val)%mod;
				if(a[x+1][y]=='_') set(sta,y,2),upd(val); 
			}
		} else if(p1&&!p2) {
			if(p1==1) {
				if(a[x][y+1]=='_') set(sta,y,0),set(sta,y+1,1),upd(val);
				if(a[x+1][y]=='_') set(sta,y,2),set(sta,y+1,0),upd(val);
			} else {
				set(sta,y,0),set(sta,y+1,0),upd(val);
				if(x==edx&&y==edy&&!sta) ans=(ans+val)%mod;
				if(a[x][y+1]=='_') set(sta,y+1,2),upd(val);
			}
		} else if(p1==1&&p2==1) {
			set(sta,y,0),set(sta,y+1,0),upd(val);
			if(x==edx&&y==edy&&!sta) ans=(ans+val)%mod;
		} //其余情况不合法 
	} 
}

int main() {
	scanf("%d%d",&n,&m);
	for(int i=1; i<=n; ++i) {
		scanf("%s",a[i]+1);
	} 
	if(n<m) { //转置 
		for(int i=1; i<=n; ++i) {
			for(int j=i+1; j<=m; ++j) 
				std::swap(a[i][j],a[j][i]);
		}
		std::swap(n,m);
	}
	for(int i=1; i<=n; ++i) {
		for(int j=1; j<=m; ++j) {
			if(a[i][j]!='*') a[i][j]='_';
			if(a[i][j]=='_') edx=i,edy=j; 
		}
	}
	f[0].clear();
	f[0][0]=1;
	for(int i=1; i<=n; ++i) {
		for(int j=1; j<=m; ++j) p_dp(i,j);
		if(i!=n) {	
			int now=(i*m)&1; 
			for(int j=1; j<=f[now].siz; ++j) 
				f[now].key[j]<<=2;
		} 
	}
	printf("%d",ans);
	return 0;
}

luogu3886 [JLOI2009]神秘的生物

很裸的最小表示法的题目。

#include <bits/stdc++.h>
//using namespace std;
const int inf=0x3f3f3f3f;

struct hash_map {
	static const int P=23333;
	int siz,hsh[P],key[P],val[P];
	void clear() {
		siz=0;
		memset(hsh,0,sizeof hsh);
		memset(key,-1,sizeof key);
		memset(val,-inf,sizeof val);
	}
	void new_hsh(int id,int sta) {
		hsh[id]=++siz,key[siz]=sta;
	}
	int&operator[](int sta) {
		for(int i=sta%P; ; i=(i+1==P?0:i+1)) {
			if(!hsh[i]) new_hsh(i,sta);
			if(key[hsh[i]]==sta) return val[hsh[i]];
		}
	}
} f[2];

int n,ans=-inf;
int now,lst=1,a[10][10];
int find(int sta,int bit) {
	if(!bit) return 0;
	return (sta>>(3*(bit-1)))&7;
} 
void set(int&sta,int bit,int val) {
	bit=3*(bit-1);
	sta|=7<<bit;
	sta^=7<<bit;
	sta|=val<<bit;
}
int count(int sta,int val) {
	int c=0;
	for(int i=1; i<=n; ++i,sta>>=3) 
		if((sta&7)==val) c++;
	return c;
}
int relabel(int sta) {
	static int hsh,cnt,id[10],w[10];
	memset(id,-1,sizeof id);
	hsh=cnt=id[0]=0;
	for(int i=1; i<=n; ++i,sta>>=3) w[i]=sta&7;
	for(int i=n; i; --i) {
		if(id[w[i]]==-1) id[w[i]]=++cnt;
		hsh=hsh<<3|id[w[i]];
	} 
	return hsh;
}
bool unicom(int sta) {
	bool exi=0;
	for(int i=1; i<=n; ++i,sta>>=3) {
		if((sta&7)>1) return 0;
		if((sta&7)==1) exi=1;
	}
	return exi;
}

#define upd(sta,val) f[now][sta]=std::max(f[now][sta],(val))

void p_dp(int x,int y) {
	now=lst,lst^=1;
	f[now].clear();
	for(int i=1; i<=f[lst].siz; ++i) {
		int sta=f[lst].key[i];
		int val=f[lst].val[i];
		int p1=find(sta,y-1);
		int p2=find(sta,y);
		if(!p1&&!p2) {
			upd(sta,val);
			set(sta,y,7);
			upd(relabel(sta),val+a[x][y]);
		} else if(!p1&&p2) {
			if(count(sta,p2)==1) upd(sta,val+a[x][y]);
			else {
				upd(sta,val+a[x][y]);
				set(sta,y,0);
				upd(relabel(sta),val);
			}
		} else if(p1&&!p2) {
			upd(sta,val);
			set(sta,y,p1);
			upd(relabel(sta),val+a[x][y]);
		} else if(p1==p2) {
			upd(sta,val+a[x][y]);
			set(sta,y,0);
			upd(relabel(sta),val);
		} else {
			if(count(sta,p2)==1) {
				for(int j=1,tmp=sta; j<=n; ++j,tmp>>=3) 
					if((tmp&7)==p1) set(sta,j,p2);
				upd(relabel(sta),val+a[x][y]);
			} else {
				int tmp=sta;
				set(tmp,y,0);
				upd(relabel(tmp),val);
				tmp=sta;
				for(int j=1; j<=n; ++j,tmp>>=3) 
					if((tmp&7)==p1) set(sta,j,p2);
				upd(relabel(sta),val+a[x][y]);
			}
		}
	}
	for(int i=1; i<=f[now].siz; ++i) {
		if(unicom(f[now].key[i])) ans=std::max(ans,f[now].val[i]);
	}
}

int main() {
	scanf("%d",&n);
	f[now].clear();
	f[now][0]=0;
	for(int i=1; i<=n; ++i) {
		for(int j=1; j<=n; ++j) {
			scanf("%d",&a[i][j]);
			p_dp(i,j);
		}
	}
	printf("%d\n",ans);
}


bzoj2595 [WC2008]游览计划

斯坦纳树的解法详见最小斯坦纳树初探
插头dp(最小表示法)详见QAQ

bzoj1494 [NOI2007]生成树计数

假设已经考虑了前i-1个点,此时轮廓线定义为i-k到i-1的连通性(状态设为\(f[i-1,s]\))。k很小,搜索可知连续k位的连通性表示(最小表示法)不会超过55个。而且在i>=k时的状态的转移显然可以矩乘优化,只需要处理考虑i=k时前k位的连通性\(s\)的方案数。

可以参考论文

#include <bits/stdc++.h>
using namespace std;
const int N=55;
const int mod=65521;

int k,cnt,expr[N][6];
long long n;

struct mtr {
	int a[N][N];
	int*operator[](int x) {return a[x];}
	mtr operator*(mtr b) {
		static mtr c;
		memset(&c,0,sizeof c);
		for(int i=1; i<=cnt; ++i) {
			for(int k=1; k<=cnt; ++k) {
				for(int j=1; j<=cnt; ++j) {
					c[i][j]=(c[i][j]+1u*a[i][k]*b[k][j])%mod;
				}
			}
		} 
		return c;
	}
	mtr pow(long long y) {
		static mtr x,c;
		x=*this;
		memset(&c,0,sizeof c);
		for(int i=1; i<=cnt; ++i) c[i][i]=1;
		for(; y; y>>=1,x=x*x) if(y&1) c=c*x;
		return c; 
	}
} ans,A,B;

int id[8000],t[7],tmp[7],vis[7],cpl[6];
int qpow(int x,int y) {
	int c=1;
	for(; y>0; y>>=1,x=1u*x*x%mod) 
		if(y&1) c=1u*c*x%mod;
	return c;
}
void dfs(int dep,int mx) {
	if(dep==k+1) {
		cnt++;
		int hs=0;
		memset(vis,0,sizeof vis);
		for(int i=1; i<=k; ++i) {
			expr[cnt][i]=t[i];
			hs=hs*6+t[i];
			vis[t[i]]++;
		}
		id[hs]=cnt;
		B[cnt][1]=1;
		for(int i=1; vis[i]; ++i) {
			B[cnt][1]=1u*B[cnt][1]*cpl[vis[i]]%mod;
		}
		return;
	}
	for(int i=1; i<=mx; ++i) {
		t[dep]=i,dfs(dep+1,mx);
	}
	t[dep]=mx+1;
	dfs(dep+1,mx+1);
}
void init() {
	for(int i=1; i<=k; ++i) cpl[i]=qpow(i,i-2);
	dfs(1,0);
	for(int i=1; i<=cnt; ++i) {
		copy(expr[i]+1,expr[i]+k+1,t), t[k]=6;
		copy(t,t+k+1,tmp);
		for(int j=0; j<(1<<k); ++j) {
			bool ok=1;
			for(int p=0; p<k; ++p) if((j>>p)&1) {
				int c=t[p];
				if(c==6) { ok=0; break;}
				for(int q=0; q<k; ++q) if(t[q]==c) t[q]=6;
			}
			if(ok) {
				int tot=0,hs=0;
				memset(vis,0,sizeof vis);
				for(int p=1; p<=k; ++p) {
					if(!vis[t[p]]) vis[t[p]]=++tot;
					hs=hs*6+(t[p]=vis[t[p]]);
				}
				if(vis[t[0]]) (A[id[hs]][i]+=1)%=mod;
			}
			copy(tmp,tmp+k+1,t);
		}
	}
}

int main() {
	scanf("%d%lld",&k,&n);
	if(n<=k) {
		printf("%d\n",qpow(n,n-2));
		return 0;
	}
	init();
	ans=A.pow(n-k)*B;
	printf("%d\n",ans[1][1]);
	return 0;
}
posted @ 2019-02-22 10:33  nosta  阅读(293)  评论(0编辑  收藏  举报