bzoj1494
这道题不仅得看俞华程的论文,还得看陈丹琦的论文,否则是不可能做出来的。因为难点在构造矩阵上。
构造矩阵困难在如何表示状态,因为树不能有环,也不能不连通,这里我们引入了最小表示法来表示连续k个点的连通性。
首先我们找出所有可能的状态,dfs一下就行了,最多只有53种。然后计算每种状态的形态,状态只是表示了连通性,但没有表示之间的形态。于是我们初始每种状态形态的数量作为列向量。然后就是构造转移矩阵。这个转移矩阵表示一个状态能够转移到另一个状态,其实是每次向前移动一位。每次向前移动一位也就是说对于一个状态我们要找出所有可以成为这个状态后移一位的合法的状态。每两个状态之间系数矩阵上的值为可能的形态数,这里的形态数和刚才不太一样,向后移一位说明把当前k个点去掉了第一个点,然后又添加了一个点。这里我们用二进制枚举连通性,也就是说新加入的点和之前k个点中哪些点是联通的。那么这样的连通性会有很多情况,比如说原来的最小表示是001,二进制枚举出来的是11,那么新的点既要和第一个联通块联通,也要和第二个联通块联通,也就是有两种情况,即是(设这个点为4,之前为123)(4->1, 4->3) (4->2,4->3)两种联通情况。
最后统计答案是这样做的,因为最后只有一种合法状态,即0....0,必须所有都联通,所以sigma(f[i][1]*ret[i][1]),i->1表示所有能转移到1状态的情况。
然后就可以矩阵快速幂了。。。
#include<bits/stdc++.h> using namespace std; typedef long long ll; const int N = 60, M = 50010, mod = 65521; struct mat { ll a[N][N]; } A, f; ll k, n; int mir[N], cnt[N], p[N], bit[10], vis[N]; ll power(ll x, ll t) { ll ret = 1; for(; t; t >>= 1, x = x * x % mod) if(t & 1) ret = ret * x % mod; return ret; } void collect(int x) { int cnt[10]; memset(cnt, 0, sizeof(cnt)); f.a[p[0]][1] = 1; mir[x] = p[0]; for(int i = 1; i <= k; ++i) ++cnt[x % 10], x /= 10; for(int i = 0; i < k; ++i) if(cnt[i] > 1) f.a[p[0]][1] = f.a[p[0]][1] * power(cnt[i], cnt[i] - 2) % mod; } void dfs(int num, int d, int bound, int x) { if(d == k) { p[++p[0]] = num; collect(p[p[0]]); return; } for(int i = 0; i <= bound; ++i) dfs(num + i * x, d + 1, max(i + 1, bound), x / 10); } void Init(int pos) { memset(bit, 0, sizeof(bit)); int x = p[pos], maxn = -1; bool flag = true; for(int j = 0, l = x; j < k; ++j) bit[k - j - 1] = l % 10, l /= 10; for(int i = 1; i < k; ++i) { if(bit[i] == bit[0]) flag = false; maxn = max(maxn, bit[i]); } int lim = 1 << (maxn + 1); for(int i = 0; i < lim; ++i) { //枚举连通性, 枚举和原来的每位是否联通 if(flag && !(i & 1)) continue; //不成立的情况 for(int j = 0, l = x; j < k; ++j) bit[k - j - 1] = l % 10, l /= 10; bit[k] = -1; int ans = 1, l = -1, tot = 0; ; for(int j = 0; j <= maxn; ++j) if(i & (1 << j)) { int t = 0; for(int x = 0; x < k; ++x) if(bit[x] == j) bit[x] = -1, ++t; ans = ans * t; } memset(vis, 0, sizeof(vis)); for(int j = 1; j <= k; ++j) if(!vis[j]) { int color = bit[j]; bit[j] = ++l; vis[j] = 1; for(int x = 0; x <= k; ++x) if(bit[x] == color && !vis[x]) bit[x] = l, vis[x] = 1; } for(int j = 1; j <= k; ++j) tot = tot * 10 + bit[j]; A.a[pos][mir[tot]] = ans % mod; } } mat operator * (mat A, mat B) { mat ret; memset(ret.a, 0, sizeof(ret.a)); for(int i = 1; i <= p[0]; ++i) for(int j = 1; j <= p[0]; ++j) for(int k = 1; k <= p[0]; ++k) ret.a[i][j] = (ret.a[i][j] + A.a[i][k] % mod * B.a[k][j] % mod) % mod; return ret; } int main() { scanf("%d%lld", &k, &n); if(k >= n) { printf("%lld\n", power(n, n - 2)); return 0; } int base = 1; for(int i = 1; i < k; ++i) base = base * 10; dfs(0, 0, 0, base); for(int i = 1; i <= p[0]; ++i) Init(i); mat ret; memset(ret.a, 0, sizeof(ret.a)); for(int i = 1; i <= p[0]; ++i) ret.a[i][i] = 1; for(ll t = n - k; t; t >>= 1, A = A * A) if(t & 1) ret = ret * A; ll ans = 0; for(int i = 1; i <= p[0]; ++i) ans = (ans + f.a[i][1] * ret.a[i][1]) % mod; printf("%lld\n", ans); return 0; }