CEOI2020 星际迷航
Statement
有 \(d+1\) 棵点数为 \(n\) 的一样的树编号为 \([0,d]\)。
对所有 \(i \in [0,d-1]\) 你将选择第 \(i\) 棵树和第 \(i+1\) 棵树上任意两个节点,并连一条边。
连完边后,一个人在第 \(0\) 棵树的 \(1\) 号节点放一个棋子和你开始玩游戏,你先手,玩家每次可以向与这个节点连边的点移动,不能移动到曾经移动过的点,不能移动者输。
你可以随便选择连边的方案数,求你可以胜利的连边的方案数。
对于 \(100 \%\) 的数据满足,\(n \le 10^5\),\(d \le 10^{18}\)。
Solution
将第 \(i\) 棵树连出去的点作为第 \(i+1\) 棵树的根,也即对于除了第 \(0\) 棵树之外的树,树根都不固定。
对于任意一棵树的叶子节点,无论谁从树根走到这里都会输掉游戏,我们称叶子节点为必输节点。
并且对于点 \(u\),如果它有一个儿子为必输节点,则 \(u\) 为必胜节点,因为此时可以选择走向必输节点,那么下一个人必输。
否则,如果没有一个儿子为必输节点,则 \(u\) 为必输节点。
第一棵树的根是固定的,所以对于第一棵树每个点的情况可以方便的求出。
考虑连边 \((u,v)\) 带来的影响,会给儿子增加一个必输节点 \(v\) 或者必胜节点 \(v\),那么连边出去的节点的值也会因此发生变化。
如果 \(u\) 是必胜节点,那么无论怎么连接,都不会改变状态。
如果 \(u\) 是必输节点,如果增加了一个必输节点,那么他将会变成必胜节点。
设 \(1\) 为第 \(0\) 棵树的根,暴力根据这个进行树上 \(\texttt{dp}\) 之后判断 \(dp_1\) 是否为 \(1\) 即可。
\(d=1\) 时,拥有 \(n^3\) 的无脑暴力,可以得到 \(8\) 分。
考虑优化过程,对于必胜节点,不会带来影响,这部分的贡献很好计算,就是 \(n \times cnt\),\(cnt\) 为必胜节点个数,所以只用关注必输节点。
根节点的值在必胜节点 \(u\) 连接到一个必输节点 \(v\) 之后,不会有变化。
根节点的值在必输节点 \(u\) 连接到一个必输节点 \(v\) 之后,\(u\) 将会变成必胜节点,可能会发生变化。
\(u\) 变化之后想要对根造成影响的充要条件是,\(fa_u\) 到根的链上每个点的儿子都为必胜节点。
那么我们执行一遍换根 \(\texttt{dp}\) ,求出每个点作为根时是否会成为必输节点,然后这些点是可以作为被连接的点 \(v\) 的。
其次,考虑哪些点可以作为合法的 \(u\),遍历一遍即可。
这样在 \(O(n)\) 的时间复杂度内做完了 \(d=1\) 的情况。
答案如下:
- 如果根节点为必输节点,答案加上必胜节点个数乘 \(n\)。
- 答案加上必输节点乘合法的必输节点个数。
考虑拓展。
考虑加入一棵树之后贡献答案方式不变,\(k\) 又很大,写个矩阵快速幂,完了。
// 德丽莎你好可爱德丽莎你好可爱德丽莎你好可爱德丽莎你好可爱德丽莎你好可爱
// 德丽莎的可爱在于德丽莎很可爱,德丽莎为什么很可爱呢,这是因为德丽莎很可爱!
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define FOR(i, l, r) for(int i = (l); i <= r; ++i)
#define REP(i, l, r) for(int i = (l); i < r; ++i)
#define DFR(i, l, r) for(int i = (l); i >= r; --i)
#define DRP(i, l, r) for(int i = (l); i > r; --i)
#define FORV(i, ver) for(int i = 0; i < ver.size(); i++)
#define FORP(i, ver) for(auto i : ver)
#define Fedge(i, x) for (int i = first[x]; i; i = nex[i])
template<class T>T wmin(const T &a, const T &b) {return a < b ? a : b;}
template<class T>T wmax(const T &a, const T &b) {return a > b ? a : b;}
template<class T>bool chkmin(T &a, const T &b) {return a > b ? a = b, 1 : 0;}
template<class T>bool chkmax(T &a, const T &b) {return a < b ? a = b, 1 : 0;}
inline int read() {
int x = 0, f = 1; char ch = getchar();
while( !isdigit(ch) ) { if(ch == '-') f = -1; ch = getchar(); }
while( isdigit(ch) ) { x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar(); }
return x * f;
}
inline void write(int x) {
if (x < 0) putchar('-'), x = -x;
if (x > 9) write(x / 10);
putchar(x % 10 + '0');
return ;
}
const int N = 5e6, mod = 1e9 + 7;
int n, d, num = 1, nex[N], first[N], v[N];
void add(int from,int to) {
nex[++num] = first[from]; first[from] = num; v[num] = to;
}
void ins(int x,int y) { add(x, y), add(y, x); }
int tx, ty, cnt, dep[N];
int dp[N], s[N][2], r[N], s0[N], dp2[N], r2[N];
void dfs(int u,int fa) {
Fedge (i, u) {
int to = v[i];
if (to == fa) continue; dfs(to, u);
s0[u] += !dp[to];
s[u][ dp[to] ] += r[to];
}
dp[u] = (s0[u] > 0);
if (s0[u] == 1) r[u] = s[u][0];
else if (s0[u] == 0) r[u] = s[u][1] + 1;
}
void dfs2(int u,int fa) {
if (!dp[u]) ++cnt; r2[u] = r[u]; dp2[u] = dp[u];
Fedge (i, u) {
int to = v[i];
if (to == fa) continue;
int s0u = s0[u], dpu = dp[u], ru = r[u], su0 = s[u][0], su1 = s[u][1];
int s0v = s0[to], dpv = dp[to], rv = r[to], sv0 = s[to][0], sv1 = s[to][1];
s0[u] -= !dp[to]; dp[u] = (s0[u] > 0); s[u][dp[to]] -= r[to];
if (s0[u] == 1) r[u] = s[u][0];
else if (s0[u] == 0) r[u] = s[u][1] + 1;
else r[u] = 0;
dp[to] |= !dp[u]; s0[to] += !dp[u]; s[to][dp[u]] += r[u];
if (s0[to] == 1) r[to] = s[to][0];
else if(s0[to] == 0) r[to] = s[to][1] + 1;
else r[to] = 0;
dfs2(to, u);
s0[u] = s0u; dp[u] = dpu; r[u] = ru; s[u][0] = su0; s[u][1] = su1;
s0[to] = s0v; dp[to] = dpv; r[to] = rv; s[to][0] = sv0; s[to][1] = sv1;
}
}
int m = 3;
struct node {
int a[5][5];
void clear() { memset(a, 0, sizeof(a)); }
node operator * (node const &b) const {
node res; res.clear();
FOR (i, 0, m) {
FOR (k, 0, m) {
FOR (j, 0, m) {
res.a[i][j] += a[i][k] * b.a[k][j] % mod;
res.a[i][j] %= mod;
}
}
}
return res;
}
node operator + (node const &b) const {
node res; res.clear();
FOR (i, 0, m) {
FOR (j, 0, m) {
res.a[i][j] += a[i][j] + b.a[i][j];
res.a[i][j] %= mod;
}
}
return res;
}
}a, b;
int tag[N];
node ksm (node x, int k) {
node ans; ans.clear();
FOR (i, 0, m) ans.a[i][i] = 1;
while (k) {
if (k & 1) ans = ans * x;
x = x * x; k >>= 1;
}
return ans;
}
signed main () {
n = read(), d = read();
REP (i, 1, n) { int x = read(), y = read(); ins(x, y); }
dfs(1, 0); dfs2(1, 0);
if (d == 1) {
if (dp2[1] == 1) {
int ans = (n - r2[1]) * cnt % mod + (n - cnt) * n % mod; ans %= mod; cout << ans << "\n";
} else cout << r2[1] * cnt % mod; return 0;
}
FOR (i, 1, n) {
if (!dp2[i]) {
b.a[0][0] += n - r2[i]; b.a[0][0] %= mod;
b.a[1][0] += n; b.a[1][0] %= mod;
b.a[0][1] += r2[i]; b.a[0][1] %= mod;
} else if (dp2[i] == 1) {
b.a[0][0] += r2[i]; b.a[0][0] %= mod;
b.a[0][1] += n - r2[i]; b.a[0][1] %= mod;
b.a[1][1] += n; b.a[1][1] %= mod;
}
}
a.a[0][0] = cnt, a.a[0][1] = n - cnt;
b = ksm(b, d - 1);
a = a * b;
int f0 = a.a[0][0], f1 = a.a[0][1];
if (dp[1]) {
int ans = (n - r2[1]) * f0 % mod + f1 * n % mod; ans %= mod; cout << ans << "\n";
} else cout << r2[1] * f0 % mod;
return 0;
}