题解:树点购买

传送门

一眼第一问是普及 DP
一眼第二问:DP 时顺便记录方案就好了,稍微麻烦一点,算联赛难度的 DP
一眼第三问:方案数直接在 DP 时顺便计算,比第二问还好写
开码!

一小时后

MD 我哪里写错了
欸我写个拍
欸这数有点离谱
欸这应该是个 \(t\) 啊为啥我写的 \(v\)
欸过拍了
欸测极限数据
欸 0.5 s
欸我交

12:10:05

MD 这也能挂?
被卡常了?不像
dfs 写假了?一眼没假
记录方案假了?一眼……WC,我的记忆化搜索根本没写记忆(忘了打访问标记)!

12:12:31

喵的它过了

代码如下:
#include <bits/stdc++.h>
using namespace std;
// 64-bit integer shorthand used for costs and modular arithmetic.
using ll = long long;
// "Infinity" sentinel for costs; 2 * INF still fits in a signed 64-bit value.
constexpr ll INF = 0x3f3f3f3f3f3f3f3fLL;
// Static array capacity (presumably the maximum n plus some slack).
constexpr int N = 1000010;
// These must stay macros: they are spelled as member names (p.fir, v.pb(x)).
#define fir first
#define sec second
#define pb push_back
// #define int long long

// Buffered byte source: refills `buf` from stdin in 2 MiB chunks and yields one
// byte per call. Bytes come back as unsigned char values (0..255), so they are
// always valid arguments for <cctype>, and they can never collide with EOF,
// which is returned once stdin is exhausted.
char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:(unsigned char)*p1++)

// Reads the next decimal integer from the buffered input.
// Every '-' seen before the first digit flips the sign (so "-5" -> -5).
// If input ends before any digit is found, returns 0 instead of spinning
// forever waiting for one.
inline int read() {
	int ans=0, f=1;
	int c=getchar();   // int, not char: must be able to hold EOF distinctly
	while (c!=EOF && !isdigit(c)) {if (c=='-') f=-f; c=getchar();}
	while (c!=EOF && isdigit(c)) {ans=ans*10+(c-'0'); c=getchar();}
	return ans*f;
}

// n: vertex count; k: how many of the answers to print; c[i]: cost of vertex i.
int n, k;
ll c[N];
// Adjacency lists as singly linked edge chains: head[u] is the id of u's most
// recently added edge (main() initialises it to -1 = "no edge"), e[id].next is
// the next edge id of the same vertex, ecnt the last id used (ids start at 1).
// deg[u] counts the undirected edges incident to u.
int head[N], deg[N], ecnt;
const ll mod=998244353;
struct edge{int to, next;}e[N<<1];

// Prepends the directed edge s -> t to s's edge chain.
inline void add(int s, int t) {
	const int id=++ecnt;
	e[id].to=t;
	e[id].next=head[s];
	head[s]=id;
}

// Returns a^b mod `mod` by square-and-multiply.
// NOTE: `a` is not reduced first, so callers keep 0 <= a < mod to avoid overflow.
inline ll qpow(ll a, ll b) {
	ll result=1;
	while (b!=0) {
		if (b&1) result=result*a%mod;
		a=a*a%mod;
		b>>=1;
	}
	return result;
}

namespace force{
	ll minn=INF, met;
	bool able[50][2], ever[N], buy[N];
	void dfs(int u, int fa) {
		if (u!=1 && deg[u]==1) {
			able[u][0]=buy[u];
			able[u][1]=1;
			return ;
		}
		ll cnt=0;
		bool none=0;
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v==fa) continue;
			dfs(v, u);
			if (!able[v][0]) {
				++cnt;
				if (!able[v][1]) none=1;
			}
		}
		if (buy[u]) able[u][0]=cnt<=1;
		else able[u][0]=cnt==0;
		able[u][1]=cnt<=1;
		if (none) able[u][0]=able[u][1]=0;
	}
	void solve() {
		int lim=1<<n;
		for (int s=0; s<lim; ++s) {
			ll sum=0;
			for (int i=1; i<=n; ++i)
				if (s&(1<<(i-1))) buy[i]=1, sum+=c[i];
				else buy[i]=0;
			memset(able, 0, sizeof(able));
			dfs(1, 0);
			if (able[1][0]) minn=min(minn, sum);
		}
		for (int s=0; s<lim; ++s) {
			ll sum=0;
			for (int i=1; i<=n; ++i)
				if (s&(1<<(i-1))) buy[i]=1, sum+=c[i];
				else buy[i]=0;
			memset(able, 0, sizeof(able));
			dfs(1, 0);
			if (able[1][0] && sum==minn) {
				++met;
				for (int i=1; i<=n; ++i) if (buy[i]) ever[i]=1;
			}
		}
		cout<<minn<<endl;
		for (int i=1; i<=n; ++i) if (ever[i]) cout<<i<<' '; cout<<endl;
		cout<<met<<endl;
	}
}

namespace task1{
	ll f[N][2], g[N][2];
	vector<int> from[N];
	bool vis[N][2], ans[N], use_self[N][2];
	vector<pair<int, int>> back[N][2];
	void dfs1(int u, int fa) {
		if (deg[u]==1 && u!=1) {
			f[u][0]=c[u]; f[u][1]=0;
			g[u][0]=g[u][1]=1;
			use_self[u][0]=1;
			return ;
		}
		ll sum=0, maxn=-INF;
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v==fa) continue;
			dfs1(v, u);
			sum+=f[v][0];
			maxn=max(maxn, f[v][0]-f[v][1]);
		}
		f[u][0]=min(sum, sum-maxn+c[u]);
		f[u][1]=sum-maxn;
		// assert(f[u][0]>f[u][1]);
		ll dont_buy=(f[u][0]==sum), buy_one=0;
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v==fa) continue;
			if (f[u][0]==sum) back[u][0].pb({v, 0}), dont_buy=dont_buy*g[v][0]%mod;
			if (f[v][0]-f[v][1]==maxn) from[u].pb(v);
		}
		if (f[u][0]==sum-maxn+c[u]) {
			use_self[u][0]=1;
			if (from[u].size()==1) {
				int v=from[u][0];
				back[u][0].pb({v, 1}); buy_one=g[v][1];
				// if (u==1) cout<<"buy_one: "<<buy_one<<endl;
				for (int i=head[u]; ~i; i=e[i].next) if (e[i].to!=fa&&e[i].to!=v) back[u][0].pb({e[i].to, 0}), buy_one=buy_one*g[e[i].to][0]%mod;
			}
			else {
				ll all=1;
				for (auto& it:from[u]) back[u][0].pb({it, 1});
				for (int i=head[u]; ~i; i=e[i].next) if (e[i].to!=fa) back[u][0].pb({e[i].to, 0}), all=all*g[e[i].to][0]%mod;
				for (auto& it:from[u]) buy_one=(buy_one+g[it][1]*all%mod*qpow(g[it][0], mod-2))%mod;
			}
		}
		// if (u==1) {
		// 	cout<<"val: "<<f[u][0]<<' '<<sum<<' '<<sum-maxn+c[u]<<endl;
		// 	cout<<"dont_buy: "<<dont_buy<<endl;
		// 	cout<<"buy_one: "<<buy_one<<endl;
		// 	cout<<"from_size: "<<from[u].size()<<endl;
		// 	cout<<"from: "; for (auto it:from[u]) cout<<it<<' '; cout<<endl;
		// 	// cout<<"g: "<<g[2][1]<<endl;
		// }
		g[u][0]=(dont_buy+buy_one)%mod;
		if (from[u].size()==1) {
			int v=from[u][0];
			back[u][1].pb({v, 1}); g[u][1]=g[v][1];
			for (int i=head[u]; ~i; i=e[i].next) if (e[i].to!=fa&&e[i].to!=v) back[u][1].pb({e[i].to, 0}), g[u][1]=g[u][1]*g[e[i].to][0]%mod;
		}
		else {
			ll all=1;
			for (auto& it:from[u]) back[u][1].pb({it, 1});
			for (int i=head[u]; ~i; i=e[i].next) if (e[i].to!=fa) back[u][1].pb({e[i].to, 0}), all=all*g[e[i].to][0]%mod;
			for (auto& it:from[u]) g[u][1]=(g[u][1]+g[it][1]*all%mod*qpow(g[it][0], mod-2))%mod;
		}
	}
	void dfs2(int u, int fa, int t) {
		if (vis[u][t]) return ;
		vis[u][t]=1;
		if (use_self[u][t]) ans[u]=1;
		for (auto& it:back[u][t]) dfs2(it.fir, u, it.sec);
	}
	void solve() {
		dfs1(1, 0); dfs2(1, 0, 0);
		if (k>=1) printf("%lld\n", f[1][0]);
		if (k>=2) {for (int i=1; i<=n; ++i) if (ans[i]) printf("%d ", i); printf("\n");}
		if (k>=3) printf("%lld\n", (g[1][0]%mod+mod)%mod);
	}
}

signed main()
{
	freopen("purtree.in", "r", stdin);
	freopen("purtree.out", "w", stdout);

	n=read();
	memset(head, -1, sizeof(head));
	for (int i=1; i<=n; ++i) c[i]=read();
	for (int i=1,u,v; i<n; ++i) {
		u=read(); v=read();
		add(u, v); add(v, u);
		++deg[u]; ++deg[v];
	}
	k=read();
	// force::solve();
	task1::solve();

	return 0;
}
posted @ 2022-04-02 19:08  Administrator-09  阅读(3)  评论(0)  编辑  收藏  举报