编译 树形DP
[Description]
山山是 2017 级信奥班的成员,因为喜欢玩 Android 系统而出名。
山山写出了一个伟大的 C++工程,一共包含 N 个源文件。在山山的脑海中,N 个源文
件构成一个树形结构。每一个源文件是树上的一个节点,其中 1 号节点是树根。
现在,山山开始编译这个工程。每次他会从树上选择一条链(包含两个端点)进行编译。
由于编译器的特性,要求这条链的一个端点必须是另一个端点的祖先。一条链可以退化成一
个点。每个源文件都需要被编译恰好一次。
每一个源文件都有一个两位十六进制数的标志值(范围从 00 到 ff)。对于每一条选择的
链,把该上面所有源文件的标志值异或起来,得到这条链的特征值。把所有选择的链的特征
值相加,得到这次编译的代价。现在山山想知道至少选择几条链才能编译所有文件。在选择
的链数目最小的时候,编译的代价最小是多少。
[Input]
第一行一个整数 N。
以下一行,N 个两位十六进制数,表示第 1 号源文件到第 N 号源文件的特征值。
(十六进制
数中的字母采取小写,不足两位的在前面补零。亦即 C/C++中使用”%02x”输出的格式。
)
以下(N - 1)行,每行两个整数,给出树上的一条边所连接的两个顶点。
[Output]
一行两个整数。依次为,选择的链的最小数目、编译的最小代价。两个数均以十进制形式输
出。
[Sample]
说明:最优方案为(1, 3), (2)。
说明:最优方案为(1, 3), (2, 4), (5)或(1, 3), (2, 5), (4)。
[Tips]
0 ≤ N ≤ 20,000。
这其实是到比较难的树形DP了。
首先我们可以证明链的数量最少是叶子节点数目(每条链必须有一个是个i另一个的爸爸)。
得到这个结论之后就好办很多,我们只需要求出最后代价最小就可以了。
现在我们来讨论怎么让代价最小。
解法写在了代码里:
#include<iostream> #include<cstdio> #include<cstdlib> #include<cstring> #include<algorithm> #define ll long long #define il inline #define db double #define max(a,b) ((a)>(b)?(a):(b)) #define min(a,b) ((a)<(b)?(a):(b)) using namespace std; il int gi() { int x=0,y=1; char ch=getchar(); while(ch<'0'||ch>'9') { if(ch=='-') y=-1; ch=getchar(); } while(ch>='0'&&ch<='9') { x=x*10+ch-'0'; ch=getchar(); } return x*y; } int head[100045],cnt; struct edge { int next,to; }e[100045]; il void add(int from,int to) { e[++cnt].next=head[from]; e[cnt].to=to; head[from]=cnt; } int num[100045]; int du[100045]; int best[100045];//记录以i为根形成的子树异或和最优值 int f[100045][1<<8];//记录以i为根形成的一条链,该链异或和为j时,子树其他点的异或和最优值 void dfs(int x,int fa) { int r=head[x]; int ans=0; bool flag=0; while(r!=-1) { int now=e[r].to; if(now!=fa) { flag=1; dfs(now,x); ans+=best[now];//统计所有儿子最优异或和 } r=e[r].next; } if(!flag) { best[x]=num[x]; f[x][num[x]]=0; return; } r=head[x]; while(r!=-1) { int now=e[r].to; if(now!=fa) { for(int j=0;j<1<<8;j++)//模拟把它儿子连到它上面 { int r1=j^num[x]; int r2=ans-best[now]+f[now][j];//连上之后异或和,总的减去该儿子的值,加上该儿子链的值 f[x][r1]=min(f[x][r1],r2); } } r=e[r].next; } for(int i=0;i<1<<8;i++) best[x]=min(best[x],f[x][i]+i); } int main() { freopen("compiler.in","r",stdin); freopen("compiler.out","w",stdout); memset(best,127/3,sizeof(best)); memset(f,127/3,sizeof(f)); memset(head,-1,sizeof(head)); int n=gi(); char ch[5]; for(int i=1;i<=n;i++) { scanf("%s",ch); if(ch[0]>='0'&&ch[0]<='9') num[i]+=16*(ch[0]-'0'); if(ch[0]>='a'&&ch[0]<='f') num[i]+=16*(ch[0]-87); if(ch[1]>='0'&&ch[1]<='9') num[i]+=ch[1]-'0'; if(ch[1]>='a'&&ch[1]<='f') num[i]+=ch[1]-87; } int x,y; for(int i=1;i<n;i++) { x=gi(),y=gi(); du[x]++; du[y]++; add(x,y); add(y,x); } int tot=0; for(int i=1;i<=n;i++) if(du[i]==1&&i!=1) tot++; printf("%d ",tot); dfs(1,0); printf("%d\n",best[1]); return 0; }