模拟赛:树和森林(lct.cpp) (树形DP,换根DP好题)

题面

屏幕快照 2020-10-31 下午2.07.11.png

屏幕快照 2020-10-31 下午2.07.40.png

屏幕快照 2020-10-31 下午2.07.51.png

题解

先解决第一个子问题吧,它才是难点

Subtask_1

我们可以先用一个简单的树形DP处理出每棵树内部的dis和,记为dp0[i],

然后再用一个换根的树形DP处理出每棵树内点 i 到树内每个点的距离和,记为dp[i],

好,现在分两个连通块跟三个连通块两种情况讨论

两个连通块

把两棵树A,B合并到一起,我们得先确定两个连接的点,

若其分别为 i,j,不难发现答案就是 dp0[A] + dp0[B] + dp[i] * size[B] + dp[j] * size[A] + size[A] * size[B]

其中,dp0[A],dp0[B],size[A],size[B]都是确定的,那么当答案取最大的时候,dp[i]、dp[j]一定分别都取最大

所以在两棵树中找dp值最大的两个点 i , j 就行了。

三个连通块

两条边连接三棵树,幸运的是,大的情况只有三种(A-B-C , A-C-B , B-A-C)

其中一棵树一定连了两条边,且不一定是同一个点连出去,不妨设中间那棵为B

注意到中间那条绿色的路径没?也就是说中间是可能有两个点的

若连接的两个点分别为 i,j,那么仔细开动一下脑筋,会发现总的贡献是

dp[A]_{max}*(size[B]+size[C])+dp[C]_{max}*(size[A]+size[B]) \\ + dp[i]*size[A]+dp[j]*size[C]+size[A]*size[C]*(dis(i,j)+2) \\ +(size[A]+size[C])*size[B](red \;line)

其中随着 i,j 变化的只有  dp[i]*size[A]+dp[j]*size[C]+size[A]*size[C]*(dis(i,j)+2)

因此,若以B树上每个节点为 i 考虑,我们可以用一个简单的换根DP求出上式的最大值,记为dp2[i],然后再求出B树中最大的dp2[i],加到上面三排的式子中,答案就出来了。

Subtask_2

子问题二其实更简单,不用管它的第二个条件,因为它只有0~1个解。

原图是个森林,是很多树组成,所以先考虑叶子结点,叶子结点如果是黑的,它的父边就不能删,如果是白的,它的父边就必须删;然后消除它父边的影响,再把叶子删去。这样一来,又有新的叶子,重复考虑……可以发现,最后要么无解,要么只有一个解,而且很好输出。

CODE

#include<map>
#include<queue>
#include<cmath>
#include<vector>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
#define MAXN 100005
#define LL unsigned long long
#define DB double
#define ENDL putchar('\n')
#define lowbit(x) ((-x)&(x))
LL read() {
    LL f = 1,x = 0;char s = getchar();
    while(s < '0' || s > '9') {if(s=='-')f=-f;s = getchar();}
    while(s >= '0' && s <= '9') {x=x*10+(s-'0');s=getchar();}
    return f*x;
}
const int MOD = 998244353;
int n,m,i,j,s,o,k;
vector<int> g[MAXN];
vector<int> id[MAXN];
char cl[MAXN];
int U[MAXN],V[MAXN];
int siz[MAXN],rt[10],cnt;
LL dp1[MAXN],dp2[MAXN],ma[MAXN],sm[MAXN],dp[MAXN];
LL dpu[MAXN],dpd[MAXN],dpa[MAXN];
vector<LL> pre[MAXN],suf[MAXN];
bool vis[MAXN],f[MAXN],ad[MAXN];
void dfs1(int x,int fa) {
	vis[x] = 1;
	siz[x] = 1;
	dp1[x] = 0;
	for(int i = 0;i < g[x].size();i ++) {
		int y = g[x][i];
		if(y != fa) {
			dfs1(y,x);
			siz[x] += siz[y];
			dp1[x] += dp1[y] + (LL)siz[y];
		}
	}return ;
}
void dfs2(int x,int fa,int n) {
	dp2[x] = 0;
	if(fa) {
		dp2[x] = (dp2[fa] + (dp1[fa] - (dp1[x] + (LL)siz[x]))) + (n-siz[x]);
	}
	dp[x] = dp1[x] + dp2[x];
	ma[x] = dp[x];
	sm[x] = siz[x] * (n-siz[x]);
	for(int i = 0;i < g[x].size();i ++) {
		int y = g[x][i];
		if(y != fa) {
			dfs2(y,x,n);
			ma[x] = max(ma[x],ma[y]);
			sm[x] += sm[y];
		}
	}
	return ;
}
void dfs3(int x,int fa,int sizA,int sizB) {
	dpd[x] = dp[x] * (LL)sizB + sizA *2ll* sizB;
	LL pr = 0;
	for(int i = 0;i < g[x].size();i ++) {
		pre[x].push_back(pr);
		int y = g[x][i];
		if(y != fa) {
			dfs3(y,x,sizA,sizB);
			pr = max(pr,dpd[y] + sizA *1ll* sizB);
			dpd[x] = max(dpd[x],dpd[y] + sizA *1ll* sizB);
		}
	}
	pr = 0;
	for(int i = (int)g[x].size()-1;i >= 0;i --) {
		suf[x].push_back(pr);
		int y = g[x][i];
		if(y != fa) {
			pr = max(pr,dpd[y] + sizA *1ll* sizB);
		}
	}
	return ;
}
void dfs4(int x,int fa,int sizA,int sizB,int ad1,int ad2) {
	dpu[x] = dp[x] * (LL)sizB + sizA *2ll* sizB;
	if(fa) {
		dpu[x] = max(dpu[x],max(dpu[fa],max(pre[fa][ad1],suf[fa][ad2])) + sizA *1ll* sizB);
	}
	dpa[x] = max(dpd[x],dpu[x]) + dp[x] * (LL)sizA;
	for(int i = 0;i < g[x].size();i ++) {
		int y = g[x][i];
		if(y != fa) {
			dfs4(y,x,sizA,sizB,i,(int)g[x].size()-1-i);
			dpa[x] = max(dpa[x],dpa[y]);
		}
	}return ;
}
bool dfs5(int x,int fa,int ed) {
	for(int i = 0;i < g[x].size();i ++) {
		int y = g[x][i],idn = id[x][i];
		if(y != fa) {
			bool cg = dfs5(y,x,idn);
			f[x] ^= cg;
		}
	}
	if(f[x]) ad[ed] = 1;
	return f[x];
}
int main() {
//	freopen("lct.in","r",stdin);
//	freopen("lct.out","w",stdout);
	n = read();m = read();
	scanf("%s",cl + 1);
	for(int i = 1;i <= m;i ++) {
		s = read();o = read();
		U[i] = s;V[i] = o;
		g[s].push_back(o);
		g[o].push_back(s);
		id[s].push_back(i);
		id[o].push_back(i);
	}
	for(int i = 1;i <= n;i ++) {
		if(!vis[i]) {
			dfs1(i,0);
			dfs2(i,0,siz[i]);
			rt[++ cnt] = i;
		}
	}
	LL ans1 = 0;
	if(cnt == 2) {
		ans1 = ma[rt[1]] * siz[rt[2]] + ma[rt[2]] * siz[rt[1]] + siz[rt[2]] *1ll* siz[rt[1]];
		ans1 += sm[rt[1]] + sm[rt[2]];
	}
	else if(cnt == 3) {
		LL DP1 = ma[rt[1]] * (LL)(n-siz[rt[1]]),D1 = siz[rt[1]] *1ll* (n-siz[rt[1]]);
		LL DP2 = ma[rt[2]] * (LL)(n-siz[rt[2]]),D2 = siz[rt[2]] *1ll* (n-siz[rt[2]]);
		LL DP3 = ma[rt[3]] * (LL)(n-siz[rt[3]]),D3 = siz[rt[3]] *1ll* (n-siz[rt[3]]);
		dfs3(rt[1],0,siz[rt[2]],siz[rt[3]]);
		dfs4(rt[1],0,siz[rt[2]],siz[rt[3]],0,0);
		ans1 = max(ans1,dpa[rt[1]] + DP2 + DP3 + D1);
		
		dfs3(rt[2],0,siz[rt[1]],siz[rt[3]]);
		dfs4(rt[2],0,siz[rt[1]],siz[rt[3]],0,0);
		ans1 = max(ans1,dpa[rt[2]] + DP1 + DP3 + D2);
		
		dfs3(rt[3],0,siz[rt[1]],siz[rt[2]]);
		dfs4(rt[3],0,siz[rt[1]],siz[rt[2]],0,0);
		ans1 = max(ans1,dpa[rt[3]] + DP1 + DP2 + D3);
		
		ans1 += sm[rt[1]] + sm[rt[2]] + sm[rt[3]];
	}
	printf("%lld\n",ans1);
	for(int i = 1;i <= n;i ++) f[i] = (cl[i] == 'B' ? 1:0);
	bool flag = 0;
	for(int i = 1;i <= cnt;i ++) {
		flag |= dfs5(rt[i],0,0);
	}
	if(flag) {
		printf("-1\n");
	}
	else {
		int cn = 0;
		for(int i = 1;i <= m;i ++) {
			if(ad[i]) cn ++;
		}
		printf("%d\n",cn);
		for(int i = 1;i <= m;i ++) {
			if(ad[i]) printf("%d ",i);
		}
		ENDL;
	}
    return 0;
}

 

posted @ 2020-10-31 19:44  DD_XYX  阅读(56)  评论(0编辑  收藏  举报