codeforces 842C Ilya And The Tree (01背包+dfs)
题目分析
题意:在一个树中,有n个结点,记为 1~n ,其中根结点编号为1,每个结点都有一个值val[i],问从根结点到各个结点的路径中所有结点的值的gcd(最大公约数)最大是多少,其中,我们可以将路径中某一个结点的值变为0,也可以选择不变。
思路:注意到对于每个结点,我们可以选择这个结点,或者不选这个结点(将权值记为0),因而有点01背包的感觉,而我们求gcd的时候需要取所有情况中的最大值
那么我们从根结点开始,每经过一个结点,就从其父节点的所有情况转移得到当前结点的状态,而对于每个结点所含有的状态,有如下三种
1)从根结点到当前结点,没有将任一结点的值变为0的情况下得到的gcd
2)从根结点到当前结点,选取除当前结点外的所有结点得到的gcd,也就是将当前结点的值当作0
3)由其父节点的所有情况转移而来,由于父节点的情况中一定会有不选取从根结点到父节点途中某点的情况,那么我们由父结点向当前结点转移得到的gcd(当前结点的值不当作0),代表不选取根结点到当前结点途中某点的情况。
这三种情况实际上整合为两种情况,即是由父结点转移到当前点的过程中,是否选择当前结点的两种情况,不过有的情况下我们必须选取当前结点,所以为了避免这种麻烦的判断,写为三种情况。
最后,我们取每个点所有情况下的最大值即可,不过为了节省时间和空间,我们用set对每个结点的状态去重。
(吐槽:个人感觉这种写法有TLE的嫌疑,又或者这个题的数据有点水....)
代码区
#include<iostream> #include<cstdio> #include<algorithm> #include<cstring> #include<queue> #include<string> #include<fstream> #include<vector> #include<stack> #include <map> #include <iomanip> #include<set> #include<cmath> #define bug cout << "**********" << endl #define show(x, y) cout<<"["<<x<<","<<y<<"] " #define LOCAL = 1; using namespace std; typedef long long ll; const int inf = 0x3f3f3f3f; const ll mod = 998244353; const int Max = 2e5 + 10; const int Max2 = 1e3 + 10; struct Edge { int to, next; } edge[Max << 1]; int n; int head[Max], tot; int val[Max], beauty[Max]; set<int> s[Max]; void init() { memset(head, -1, sizeof(head)); tot = 0; for (int i = 1; i <= n; i++) s[i].clear(); } void add(int u, int v) { edge[tot].to = v; edge[tot].next = head[u]; head[u] = tot++; } int gcd(int a, int b) { if (a < b) swap(a, b); if (b == 0) return a; return gcd(b, a % b); } void dfs(int u, int fa, int now) { for (auto it:s[fa]) { s[u].insert(gcd(it, val[u])); //选取当前数,更新由父结点转移来的所有情况 } s[u].insert(now); //不选当前结点 now = gcd(now,val[u]); //now代表的是没有删除任何点的情况下的gcd s[u].insert(now); //不删除任何点 for (int i = head[u]; i != -1; i = edge[i].next) { int v = edge[i].to; if (v == fa) continue; dfs(v, u, now); } } int main() { #ifdef LOCAL //freopen("input.txt", "r", stdin); //freopen("output.txt", "w", stdout); #endif while (scanf("%d", &n) != EOF) { init(); for (int i = 1; i <= n; i++) { scanf("%d", val + i); } for (int i = 1, u, v; i < n; i++) { scanf("%d%d", &u, &v); add(u, v); add(v, u); } dfs(1, 0, 0); for (int i = 1; i < n; i++) { printf("%d ", *s[i].rbegin()); } printf("%d\n",*s[n].rbegin()); } return 0; }