BZOJ4824 [Cqoi2017]老C的键盘 【树形dp】

题目链接

BZOJ4824

题解

观察出题目中的关系实际上是完全二叉树的父子关系

我们设\(f[i][j]\)为以\(i\)为根的节点在其子树中排名为\(j\)的方案数
转移时,枚举左右子树分别有几个节点比\(i\)小,进行转移

乍一看是\(O(n^3)\)的,但其复杂度分析和某一题很像
就是在根处枚举两个子树大小,实质上就等于枚举任意两点\(lca\),是\(O(n^2)\)

#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<map>
#define Redge(u) for (int k = h[u],to; k; k = ed[k].nxt)
#define REP(i,n) for (int i = 1; i <= (n); i++)
#define mp(a,b) make_pair<int,int>(a,b)
#define cls(s) memset(s,0,sizeof(s))
#define cp pair<int,int>
#define LL long long int
#define ls (u << 1)
#define rs (u << 1 | 1)
using namespace std;
const int maxn = 205,maxm = 100005,INF = 1000000000,P = 1000000007;
inline int read(){
	int out = 0,flag = 1; char c = getchar();
	while (c < 48 || c > 57){if (c == '-') flag = -1; c = getchar();}
	while (c >= 48 && c <= 57){out = (out << 3) + (out << 1) + c - 48; c = getchar();}
	return out * flag;
}
LL C[maxn][maxn],f[maxn][maxn],siz[maxn],n,typ[maxn];
LL suml[maxn][maxn],sumr[maxn][maxn];
char s[maxn];
void init(){
	for (int i = 0; i <= 100; i++){
		C[i][0] = C[i][i] = 1;
		for (int j = 1; j <= (i >> 1); j++)
			C[i][j] = C[i][i - j] = (C[i - 1][j - 1] + C[i - 1][j]) % P;
	}
}
void dfs(int u){
	if (u > n) return;
	dfs(ls); dfs(rs);
	siz[u] = siz[ls] + 1 + siz[rs];
	if (ls > n) f[u][1] = 1;
	else if (rs > n){
		for (int i = 0; i <= siz[ls]; i++)
			if (!typ[ls]) f[u][i + 1] = sumr[ls][i + 1];
			else f[u][i + 1] = suml[ls][i];
	}
	else {
		LL t1,t2;
		for (int i = 0; i <= siz[ls]; i++){
			if (!typ[ls]) t1 = sumr[ls][i + 1];
			else t1 = suml[ls][i];
			for (int j = 0; j <= siz[rs]; j++){
				if (!typ[rs]) t2 = sumr[rs][j + 1];
				else t2 = suml[rs][j];
				f[u][i + j + 1] = (f[u][i + j + 1] + C[i + j][i] * C[siz[u] - (i + j + 1)][siz[ls] - i] % P * t1 % P * t2 % P) % P;
			}
		}
	}
	for (int i = 1; i <= n ; i++) suml[u][i] = (suml[u][i - 1] + f[u][i]) % P;
	for (int i = n; i >= 0; i--) sumr[u][i] = (sumr[u][i + 1] + f[u][i]) % P;
}
int main(){
	init();
	n = read();
	scanf("%s",s + 2);
	for (int i = 2; i <= n; i++)
		typ[i] = s[i] == '<' ? 0 : 1;
	dfs(1);
	printf("%lld\n",suml[1][n]);
	return 0;
}

posted @ 2018-05-11 15:36  Mychael  阅读(202)  评论(0编辑  收藏  举报