[SCOI2003]严格N元树

嘟嘟嘟


这题自己搞了个dp,然后卡空间一上午终于过了……


\(dp[i][j]\)表示深度为\(i\),这一层有\(n * j\)个节点的\(n\)元树的个数。然后我们枚举上一层的节点个数进行转移:$$dp[i][j] = \sum _ {k = 1} ^ {n ^ {i - 2}}dp[i - 1][k] * C_{nk} ^ {j}$$
然后因为出题人要恶心人没有取模,所以要写高精。
写完后自信满满一交MLE了,发现只有128MB……于是开始卡空间……
首先把dp变成滚动数组,但还是超时。这时候想到了压位高精,于是把每8位压成1位,然后高精数组长度改成\(\frac{200}{8} + 1 = 26\),这样就过了……多开一点就会MLE。


其实我这不是最简单的写法,最简单的写法是令\(dp[i]\)表示深度不超过\(i\)\(n\)元树的个数,这样答案就是\(dp[d] - dp[d - 1]\)。转移的时候考虑新添加一个根连向\(i - 1\)的所有树根,于是有\(dp[i] = dp[i - 1] ^ n + 1\)\(+1\)是因为没有儿子节点的情况。

#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<vector>
#include<stack>
#include<queue>
#include<assert.h>
using namespace std;
#define enter puts("") 
#define space putchar(' ')
#define Mem(a, x) memset(a, x, sizeof(a))
#define In inline
typedef long long ll;
typedef double db;
const int INF = 0x3f3f3f3f;
const db eps = 1e-8;
const int maxn = 35;
const int N = 1024;
In ll read()
{
  ll ans = 0;
  char ch = getchar(), last = ' ';
  while(!isdigit(ch)) last = ch, ch = getchar();
  while(isdigit(ch)) ans = (ans << 1) + (ans << 3) + ch - '0', ch = getchar();
  if(last == '-') ans = -ans;
  return ans;
}
In void write(ll x)
{
  if(x < 0) x = -x, putchar('-');
  if(x >= 10) write(x / 10);
  putchar(x % 10 + '0');
}
In void MYFILE()
{
#ifndef mrclr
  freopen("ha.in", "r", stdin);
  freopen("ha.out", "w", stdout);
#endif
}

int n, d;
const int maxN = 26;
const int M = 100000000;
const int P = 8;
struct Big
{
  int a[maxN], len;
  In void N() {a[0] = len = 1;}
  friend In Big operator + (const Big& A, const Big& B)
  {
    Big ret; 
    int Len = min(maxN - 1, max(A.len, B.len));
    if(Len == maxN - 1) {ret.a[0] = 0, ret.len = 1; return ret;}
    Mem(ret.a, 0);
    for(int i = 0; i < Len; ++i)
      {
	ret.a[i] += A.a[i] + B.a[i];
	ret.a[i + 1] = ret.a[i] / M, ret.a[i] %= M;
      }
    if(Len < maxN - 1 && ret.a[Len]) ++Len;
    while(Len > 1 && !ret.a[Len - 1]) --Len;
    ret.len = Len;
    return ret;
  }
  friend In Big operator * (const Big& A, const Big& B)
  {
    Big ret;
    ret.len = min(maxN - 1, A.len + B.len - 1);
    if(ret.len == maxN - 1) {ret.a[0] = 0, ret.len = 1; return ret;}
    Mem(ret.a, 0);
    for(int i = 0; i < A.len; ++i)
      for(int j = 0; j < B.len; ++j)
	{
	  ll tp = 1LL * A.a[i] * B.a[j];
	  if(i + j < maxN) ret.a[i + j] += tp % M;
	  if(i + j + 1 < maxN) ret.a[i + j + 1] += tp / M;
	}
    if(ret.len < maxN - 1 && ret.a[ret.len]) ++ret.len;
    while(ret.len > 1 && !ret.a[ret.len - 1]) --ret.len;
    return ret;
  }
  In void out()
  {
    write(a[len - 1]);
    for(int i = len - 2; i >= 0; --i) printf("%0*d", P, a[i]);
  }
}C[N + 2][N + 2], dp[2][N + 2];

ll Pow[N + 2];
In void init()
{
  for(int i = 0; i <= N; ++i)
    for(int j = 0; j <= N; ++j) C[i][j].a[0] = 0, C[i][j].len = 1;
  C[0][0].N();
  for(int i = 1; i <= N; ++i)
    {
      C[i][0].N();
      for(int j = 1; j <= i; ++j) C[i][j] = C[i - 1][j] + C[i - 1][j - 1];
    }
  Pow[0] = 1;
  for(int i = 1; i <= N; ++i) Pow[i] = Pow[i - 1] * n;
}

int main()
{
  MYFILE();
  n = read(), d = read();
  if(!d) {puts("1"); return 0;}
  init();
  dp[0][1].N(), dp[1][1].N();
  int o = 0;
  for(int i = 2; i <= d; ++i, o ^= 1)
    {
      for(int j = 1; j <= Pow[i - 1]; ++j)
	{
	  Mem(dp[o][j].a, 0); dp[o][j].len = 1;
	  for(int k = 1; k <= Pow[i - 2]; ++k)
	    dp[o][j] = dp[o][j] + (dp[o ^ 1][k] * C[n * k][j]);
	}
    }
  Big ans; Mem(ans.a, 0), ans.len = 1;
  for(int i = 1; i <= Pow[d - 1]; ++i) ans = ans + dp[o ^ 1][i];
  ans.out(), enter;
  return 0;
}
posted @ 2019-06-02 15:20  mrclr  阅读(186)  评论(0编辑  收藏  举报