题解 树点购买
一眼第一问是普及 DP
一眼第二问记录方案就好了,有点麻烦,是联赛 DP
一眼第三问方案数直接在 DP 时算,比第二问还好写
开码!
一小时后
MD 我哪里写错了
欸我写个拍
欸这数有点离谱
欸这应该是个 \(t\) 啊,为啥我写的是 \(v\) 啊
欸过拍了
欸测极限数据
欸 0.5 s
欸我交
12:10:05
MD 这也能挂?
被卡常了?不像
dfs 写假了?一眼没假
记录方案假了?一眼……WC 我记忆化搜索没记录已访问的状态!
12:12:31
喵的它过了
点击查看代码
#include <bits/stdc++.h>
using namespace std;
// Large 64-bit sentinel: used as the initial value of a minimum (and negated as the
// initial value of a maximum) in the solvers below.
#define INF 0x3f3f3f3f3f3f3f3f
// Capacity of every per-node array in this file.
#define N 1000010
// Shorthands for the members of std::pair.
#define fir first
#define sec second
// Shorthand for push_back on containers.
#define pb push_back
// 64-bit signed integer alias.
#define ll long long
// Disabled switch that would turn every `int` in the file into a 64-bit integer.
// #define int long long
// Input buffer of the fast reader: p1 is the read cursor, p2 is one past the last valid byte.
char buf[1<<21], *p1=buf, *p2=buf;
// Buffered replacement for getchar(): refills buf from stdin once it is drained and
// yields EOF when fread delivers nothing. The byte is widened through unsigned char
// so that a 0xFF byte can never be mistaken for EOF.
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:static_cast<int>(static_cast<unsigned char>(*p1++)))
// Reads the next decimal integer from the buffered input.
// Non-digit characters before the number are skipped; every '-' among them flips the sign.
// The first character after the digits is consumed as well.
// Returns 0 when the input ends before any digit is found (instead of looping forever).
inline int read() {
    int ans=0, f=1;
    // Kept as int so EOF stays distinguishable and isdigit() never sees a negative char.
    int c=getchar();
    while (c!=EOF && !isdigit(c)) {
        if (c=='-') f=-f;
        c=getchar();
    }
    while (c!=EOF && isdigit(c)) {
        ans=ans*10+(c-'0');
        c=getchar();
    }
    return ans*f;
}
// n: first integer of the input, used as the node count (nodes are 1..n).
int n;
// k: last integer of the input; task1::solve prints one more answer line for each of k>=1, k>=2, k>=3.
int k;
// c[i]: value read for node i, used as its cost by the solvers.
ll c[N];
// head[u]: index in e[] of the most recently added edge leaving u; the list ends at -1.
int head[N];
// deg[u]: number of edges incident to u.
int deg[N];
// Number of directed edges stored so far; edges occupy e[1..ecnt].
int ecnt;
// Modulus used for the counting answer.
const ll mod = 998244353;
// One directed edge of the linked adjacency lists.
struct edge {
    int to;    // endpoint the edge leads to
    int next;  // index of the next edge in the same list
};
// Two directed edges are stored per undirected input edge.
edge e[N << 1];
// Prepends the directed edge s -> t to the adjacency list of s.
inline void add(int s, int t) {
    const int id = ++ecnt;
    e[id].to = t;
    e[id].next = head[s];
    head[s] = id;
}
// Binary exponentiation: returns a^b modulo `mod`, consuming b one bit at a time.
inline ll qpow(ll a, ll b) {
    ll result = 1;
    while (b) {
        if (b & 1) {
            result = result * a % mod;
        }
        a = a * a % mod;
        b >>= 1;
    }
    return result;
}
namespace force{
ll minn=INF, met;
bool able[50][2], ever[N], buy[N];
void dfs(int u, int fa) {
if (u!=1 && deg[u]==1) {
able[u][0]=buy[u];
able[u][1]=1;
return ;
}
ll cnt=0;
bool none=0;
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
if (v==fa) continue;
dfs(v, u);
if (!able[v][0]) {
++cnt;
if (!able[v][1]) none=1;
}
}
if (buy[u]) able[u][0]=cnt<=1;
else able[u][0]=cnt==0;
able[u][1]=cnt<=1;
if (none) able[u][0]=able[u][1]=0;
}
void solve() {
int lim=1<<n;
for (int s=0; s<lim; ++s) {
ll sum=0;
for (int i=1; i<=n; ++i)
if (s&(1<<(i-1))) buy[i]=1, sum+=c[i];
else buy[i]=0;
memset(able, 0, sizeof(able));
dfs(1, 0);
if (able[1][0]) minn=min(minn, sum);
}
for (int s=0; s<lim; ++s) {
ll sum=0;
for (int i=1; i<=n; ++i)
if (s&(1<<(i-1))) buy[i]=1, sum+=c[i];
else buy[i]=0;
memset(able, 0, sizeof(able));
dfs(1, 0);
if (able[1][0] && sum==minn) {
++met;
for (int i=1; i<=n; ++i) if (buy[i]) ever[i]=1;
}
}
cout<<minn<<endl;
for (int i=1; i<=n; ++i) if (ever[i]) cout<<i<<' '; cout<<endl;
cout<<met<<endl;
}
}
namespace task1{
ll f[N][2], g[N][2];
vector<int> from[N];
bool vis[N][2], ans[N], use_self[N][2];
vector<pair<int, int>> back[N][2];
void dfs1(int u, int fa) {
if (deg[u]==1 && u!=1) {
f[u][0]=c[u]; f[u][1]=0;
g[u][0]=g[u][1]=1;
use_self[u][0]=1;
return ;
}
ll sum=0, maxn=-INF;
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
if (v==fa) continue;
dfs1(v, u);
sum+=f[v][0];
maxn=max(maxn, f[v][0]-f[v][1]);
}
f[u][0]=min(sum, sum-maxn+c[u]);
f[u][1]=sum-maxn;
// assert(f[u][0]>f[u][1]);
ll dont_buy=(f[u][0]==sum), buy_one=0;
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
if (v==fa) continue;
if (f[u][0]==sum) back[u][0].pb({v, 0}), dont_buy=dont_buy*g[v][0]%mod;
if (f[v][0]-f[v][1]==maxn) from[u].pb(v);
}
if (f[u][0]==sum-maxn+c[u]) {
use_self[u][0]=1;
if (from[u].size()==1) {
int v=from[u][0];
back[u][0].pb({v, 1}); buy_one=g[v][1];
// if (u==1) cout<<"buy_one: "<<buy_one<<endl;
for (int i=head[u]; ~i; i=e[i].next) if (e[i].to!=fa&&e[i].to!=v) back[u][0].pb({e[i].to, 0}), buy_one=buy_one*g[e[i].to][0]%mod;
}
else {
ll all=1;
for (auto& it:from[u]) back[u][0].pb({it, 1});
for (int i=head[u]; ~i; i=e[i].next) if (e[i].to!=fa) back[u][0].pb({e[i].to, 0}), all=all*g[e[i].to][0]%mod;
for (auto& it:from[u]) buy_one=(buy_one+g[it][1]*all%mod*qpow(g[it][0], mod-2))%mod;
}
}
// if (u==1) {
// cout<<"val: "<<f[u][0]<<' '<<sum<<' '<<sum-maxn+c[u]<<endl;
// cout<<"dont_buy: "<<dont_buy<<endl;
// cout<<"buy_one: "<<buy_one<<endl;
// cout<<"from_size: "<<from[u].size()<<endl;
// cout<<"from: "; for (auto it:from[u]) cout<<it<<' '; cout<<endl;
// // cout<<"g: "<<g[2][1]<<endl;
// }
g[u][0]=(dont_buy+buy_one)%mod;
if (from[u].size()==1) {
int v=from[u][0];
back[u][1].pb({v, 1}); g[u][1]=g[v][1];
for (int i=head[u]; ~i; i=e[i].next) if (e[i].to!=fa&&e[i].to!=v) back[u][1].pb({e[i].to, 0}), g[u][1]=g[u][1]*g[e[i].to][0]%mod;
}
else {
ll all=1;
for (auto& it:from[u]) back[u][1].pb({it, 1});
for (int i=head[u]; ~i; i=e[i].next) if (e[i].to!=fa) back[u][1].pb({e[i].to, 0}), all=all*g[e[i].to][0]%mod;
for (auto& it:from[u]) g[u][1]=(g[u][1]+g[it][1]*all%mod*qpow(g[it][0], mod-2))%mod;
}
}
void dfs2(int u, int fa, int t) {
if (vis[u][t]) return ;
vis[u][t]=1;
if (use_self[u][t]) ans[u]=1;
for (auto& it:back[u][t]) dfs2(it.fir, u, it.sec);
}
void solve() {
dfs1(1, 0); dfs2(1, 0, 0);
if (k>=1) printf("%lld\n", f[1][0]);
if (k>=2) {for (int i=1; i<=n; ++i) if (ans[i]) printf("%d ", i); printf("\n");}
if (k>=3) printf("%lld\n", (g[1][0]%mod+mod)%mod);
}
}
signed main()
{
freopen("purtree.in", "r", stdin);
freopen("purtree.out", "w", stdout);
n=read();
memset(head, -1, sizeof(head));
for (int i=1; i<=n; ++i) c[i]=read();
for (int i=1,u,v; i<n; ++i) {
u=read(); v=read();
add(u, v); add(v, u);
++deg[u]; ++deg[v];
}
k=read();
// force::solve();
task1::solve();
return 0;
}