CF1594 E2. Rubik's Cube Coloring (hard version)

做出来了但是没交


题意

白色节点不能与白色和黄色节点相邻;
黄色节点不能与白色和黄色节点相邻;
绿色节点不能与绿色和蓝色节点相邻;
蓝色节点不能与绿色和蓝色节点相邻;
红色节点不能与红色和橙色节点相邻;
橙色节点不能与红色和橙色节点相邻;

据此给一个高为n(\(n<60\))的完全二叉树染色。求方案数:

不难推出公式: 设\(fi\)表示以高度为i的树,确定根节点颜色,有多少方案, \(f_i=f_i^2*16\)(根节点确定, 每个子节点有4种可行颜色)

E2在E1的基础上, 有M个点的的颜色已经给出, 求给其他点染色的方案数。


题解

我们先把E1的\(f_{i, j}\)数组预处理出来(高度为i的树没有预给定颜色, 且根节点颜色为j的方案数)

考虑直接dp:如果这个节点的子树中没有预定义节点, 直接返回f数组, 否则直接暴力dp

此时需要处理的节点个数为子树中包含预定义节点的节点个数, 可以证明这样的点数量为\(O(nm)\)(每个预定义点都在叶子节点最大)

现在的问题是如何判断一个点是否包含预定义节点呢:
一种暴力的做法是直接枚举所有给出的点, 求出高度进行判断。但是这样的复杂度多了nm的常数
考虑反过来, 对于每个预定义点, 把它的所有祖先的标记一下, 不难想到, 让节点编号每次除2就行。需要一个set维护, 同时还需要一个map来记录颜色。

总复杂度\(O(nm\log(nm))\)以及一个较大的常数


实现

赛时代码过丑抱歉

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
#include <string>
#include <cassert>
#include <cstdlib>
#include <map>
#include <set>
#define int long long
using namespace std;

int read(){
	int num=0, flag=1; char c=getchar();
	while(!isdigit(c) && c!='-') c=getchar();
	if(c=='-') flag=-1, c=getchar();
	while(isdigit(c)) num=1ll*num*10+c-'0', c=getchar();
	return num*flag; 
}

int readc(){
	char c=getchar();
	while(c<'a' || c>'z') c=getchar();
	return c;
}

void reads(){
	char c=getchar();
	while(c<'a' || c>'z') c=getchar();
	while(c>='a' && c<='z') c=getchar();
}

int checkColor(int a, int b){
	if(a > b) swap(a, b);
	if(a == b) return 1;
	if(a==4 && b==5) return 1;
	if(a==2 && b==3) return 1;
	if(a==0 && b==1) return 1;
	return 0;
}

const int N = 3055, M = 5e6;
const int mod = 1e9+7;
int T, n, m;
int r[M][6], sz=0;
set<int> st;
map<int, int> col;

int get(char s){
	if(s == 'w') return 0;
	if(s == 'y') return 1;
	if(s == 'g') return 2;
	if(s == 'b') return 3;
	if(s == 'r') return 4;
	return 5; 
}

int p[N], c;
int res[N][6];

int getlas(int id){
	int num = 0;
	while(id) num++, id>>=1;
	return n-num+1;
}

int cnt = 0;

int check(int id){
	return st.count(id);
}

int dp(int id, int h) {
	if(h == 1){
		
		int lres = (++sz)%M;
		if(col.count(id)) {
			for(int i=0; i<6; i++) r[lres][i] = 0;
			r[lres][col[id]] = 1;
			return lres;
		}else{
			for(int i=0; i<6; i++) r[lres][i] = 1;
			return lres;
		}
	}
	
	if(check(id)){
		int r1=dp(id*2, h-1), r2=dp(id*2+1, h-1);
		
		int lres = (++sz)%M;
		if(col.count(id)) {
			int i = col[id];
			for(int j=0; j<6; j++) r[lres][j]=0;
			
			for(int j=0; j<6; j++){
					for(int k=0; k<6; k++){
						if(checkColor(i, j) || checkColor(i, k)) continue;
						r[lres][i] += r[r1][j]*r[r2][k]%mod;
						r[lres][i] %= mod;
					}
			}
			
			return lres;
		}else{
			for(int j=0; j<6; j++) r[lres][j]=0;
			for(int i=0; i<6; i++){
				for(int j=0; j<6; j++){
					for(int k=0; k<6; k++){
						if(checkColor(i, j) || checkColor(i, k)) continue;
						r[lres][i] += r[r1][j]*r[r2][k]%mod;
						r[lres][i] %= mod;
					}
				}
			}
			return lres;
		}
		
	}else{
		int lres = (++sz)%M;
		for(int i=0; i<6; i++){
			r[lres][i] = res[h][i];
		}
		return lres;
	}
}

signed main(){
	n=read(), m=read();
	for(int i=1; i<=m; i++){
		p[i]=read();
		c=get(readc()); reads();
		
		int cur = p[i];
		while(cur) {
			if(!st.count(cur)) st.insert(cur);
			cur /= 2; 
		} 
		col[p[i]] = c;
	}
	
	int num = 1;
	for(int i=1; i<=n; i++){
		for(int j=0; j<6; j++) res[i][j] = num;
		num = (num*num%mod*16%mod); 
	}
	
	int x = dp(1, n);
	int ans = 0;
	for(int i=0; i<6; i++) ans = (ans + r[x][i]) % mod;
	cout << ans << endl;
	
	return 0;
}

posted @ 2021-10-09 10:07  ltdJcoder  阅读(101)  评论(0编辑  收藏  举报