题解 [LNOI2022] 题

传送门

全世界就 tm 我不会签到题

考虑一个 \(O(n^7)\) 的 DP：按位置扫描，状态是 6 个计数器。
合法的排列只有 3 个：132、213、321。
状态只需记录未完成的组中，当前前缀分别为 1, 2, 3, 21, 32, 13 的组各有几个。
注意精细处理各维循环的上界（既控制状态数，也避免数组越界）。

点击查看代码
#include <bits/stdc++.h>
using namespace std;
// Shared constants and global input state used by both solvers below.
constexpr int INF=0x3f3f3f3f;   // sentinel value; not referenced by the visible code
constexpr int N=100010;         // capacity of the 1-indexed input / work buffers
using ll=long long;
//#define int long long

int n;       // pattern length (main multiplies the value it reads by 3)
char s[N];   // pattern, 1-indexed; '0' is a free position, '1'..'3' are fixed values
const ll mod=1e9+7;
// Presumably precomputed answers for an all-'0' pattern, indexed by n/3 (0..5).
// Only referenced by the disabled fast path in main. TODO confirm tab[0] (the DP yields 1 for n=0).
const ll tab[]={0, 3, 180, 45360, 29937600, 864823720};

namespace force{
	char t[N];
	ll fac[N], ans;
	const int nxt[][2]={{2,1},{0,2},{1,0}};
	unordered_map<int, ll> mp[40];
	inline int encode(int n, char* sta) {int ans=0; for (int i=1; i<=n; ++i) ans=ans*3+sta[i]; return ans;}
	inline void decode(char* sta, int n, int s) {for (int i=n; i; --i) sta[i]=s%3, s/=3;}
	ll dfs(int n, int s) {
		if (!n) return 1;
		if (mp[n].find(s)!=mp[n].end()) return mp[n][s];
		ll ans=0;
		char sta[n+1], tem[n+1];
		decode(sta, n, s);
		for (int j=2; j<=n; ++j) if (sta[j]==nxt[sta[1]][0]) {
			for (int k=j+1; k<=n; ++k) if (sta[k]==nxt[sta[1]][1]) {
				int tot=0;
				for (int l=2; l<=n; ++l) if (l!=j&&l!=k) tem[++tot]=sta[l];
				ans=(ans+dfs(n-3, encode(n-3, tem)))%mod;
			}
		}
		return mp[n][s]=ans;
	}
	void dfs1(int u) {
		if (u>n) {
			int cnt[4];
			memset(cnt, 0, sizeof(cnt));
			for (int i=1; i<=n; ++i) ++cnt[t[i]];
			if (!(cnt[1]==cnt[2] && cnt[2]==cnt[3])) return ;
			// cout<<"t: "; for (int i=1; i<=n; ++i) cout<<int(t[i])<<' '; cout<<endl;
			for (int i=1; i<=n; ++i) --t[i];
			ans=(ans+dfs(n, encode(n, t)))%mod;
			for (int i=1; i<=n; ++i) ++t[i];
			return ;
		}
		if (s[u]) t[u]=s[u], dfs1(u+1);
		else for (int i=1; i<=3; ++i) t[u]=i, dfs1(u+1);
	}
	void solve() {
		fac[0]=fac[1]=1; ans=0;
		// cout<<"s: "<<s+1<<endl;
		for (int i=1; i<=n; ++i) s[i]-='0';
		for (int i=1; i<=n; ++i) fac[i]=fac[i-1]*i%mod;
		dfs1(1);
		printf("%lld\n", ans*fac[n/3]%mod);
	}
}

namespace task{
	ll fac[N];
	int dp[2][20][20][20][20][20][20], now;
	inline void md(int &a, ll b) {a=(a+b)%mod;}
	void solve() {
		fac[0]=fac[1]=1;
		dp[now][0][0][0][0][0][0]=1;
		for (int i=2; i<=n; ++i) fac[i]=fac[i-1]*i%mod;
		for (int i=1; i<=n; ++i,now^=1) {
			// memset(dp[now^1], 0, sizeof(dp[now^1]));
			for (int a=0; a<=min(i, n/3); ++a)
				for (int b=0; b<=min(i, n/3) && a+b<=i && a+b<=n/3; ++b)
					for (int c=0; c<=min(i, n/3) && a+b+c<=i && a+b+c<=n/3; ++c)
						for (int d=0; d<=min(i/2+1, n/3) && a+b+c+2*d<=i && a+b+c+d<=n/3; ++d)
							for (int e=0; e<=min(i/2+1, n/3) && a+b+c+2*d+2*e<=i && a+b+c+d+e<=n/3; ++e)
								for (int f=0; f<=min(i/2+1, n/3) && a+b+c+2*d+2*e+2*f<=i && a+b+c+d+e+f<=n/3; ++f)
									dp[now^1][a][b][c][d][e][f]=0;
			for (int a=0; a<=min(i-1, n/3); ++a) {
				for (int b=0; b<=min(i-1, n/3) && a+b<i && a+b<=n/3; ++b) {
					for (int c=0; c<=min(i-1, n/3) && a+b+c<i && a+b+c<=n/3; ++c) {
						for (int d=0; d<=min(i/2, n/3) && a+b+c+2*d<i && a+b+c+d<=n/3; ++d) {
							for (int e=0; e<=min(i/2, n/3) && a+b+c+2*d+2*e<i && a+b+c+d+e<=n/3; ++e) {
								for (int f=0; f<=min(i/2, n/3) && a+b+c+2*d+2*e+2*f<i && a+b+c+d+e+f<=n/3; ++f) if (dp[now][a][b][c][d][e][f]) {
									if (s[i]=='0'||s[i]=='1') {
										md(dp[now^1][a+1][b][c][d][e][f], (ll)dp[now][a][b][c][d][e][f]);
										if (b) md(dp[now^1][a][b-1][c][d][e+1][f], (ll)dp[now][a][b][c][d][e][f]*b);
										if (d) md(dp[now^1][a][b][c][d-1][e][f], (ll)dp[now][a][b][c][d][e][f]*d);
									}
									if (s[i]=='0'||s[i]=='2') {
										md(dp[now^1][a][b+1][c][d][e][f], (ll)dp[now][a][b][c][d][e][f]);
										if (c) md(dp[now^1][a][b][c-1][d+1][e][f], (ll)dp[now][a][b][c][d][e][f]*c);
										if (f) md(dp[now^1][a][b][c][d][e][f-1], (ll)dp[now][a][b][c][d][e][f]*f);
									}
									if (s[i]=='0'||s[i]=='3') {
										md(dp[now^1][a][b][c+1][d][e][f], (ll)dp[now][a][b][c][d][e][f]);
										if (a) md(dp[now^1][a-1][b][c][d][e][f+1], (ll)dp[now][a][b][c][d][e][f]*a);
										if (e) md(dp[now^1][a][b][c][d][e-1][f], (ll)dp[now][a][b][c][d][e][f]*e);
									}
								}
							}
						}
					}
				}
			}
		}
		printf("%lld\n", dp[now][0][0][0][0][0][0]*fac[n/3]%mod);
	}
}

signed main()
{
	// freopen("problem.in", "r", stdin);
	// freopen("problem.out", "w", stdout);

	int T;
	// Stop cleanly on empty/malformed input instead of looping on an uninitialised T.
	if (scanf("%d", &T)!=1) return 0;
	while (T--) {
		// Each case reads a count n and a 1-indexed pattern of 3n characters;
		// n is rescaled to the pattern length before solving.
		if (scanf("%d%s", &n, s+1)!=2) break;
		n*=3;
		// Alternative dispatch (disabled). An all-'0' pattern could be answered from
		// tab[], and other small cases by force::solve(). The `zero` flag it needs was
		// computed here but never read, so it survives only as part of this note:
		// bool zero=1;
		// for (int i=1; i<=n; ++i) if (s[i]!='0') zero=0;
		// if (zero) printf("%lld\n", tab[n/3]);
		// else force::solve();
		task::solve();
	}

	return 0;
}
posted @ 2022-06-08 08:49  Administrator-09  阅读(2)  评论(0)  编辑  收藏  举报