题解 Kutulu

传送门

一个想法是在 \(h\) 的时间内能出发多少个怪是固定的
那么最后主角停留的位置就是能让最多怪在路上的位置
经过题解证明一些人类智慧发现最优策略可以简化为直接走到该位置
考虑主角与怪的移动速度相同,可以证明只要能走到 \(u\),在 \(u\) 处的答案就是从一开始就在 \(u\) 的答案
那么答案就是所有与 1 距离 \(\leqslant h\) 的点处的最小值
于是可以枚举 \(u\),一个 \(v\) 的贡献是 \(a_v(\lfloor\frac{h-max(dis(u, v), 1)}{k}\rfloor+1)\)
但是我赛时考虑能出发的最后一个怪出发后还能走多久推了一个更麻烦的式子
那么现在可以做到 \(O(n^2)\)
然后 suffle+卡时可以得到 90 pts,多交几遍说不定能……

然后正解:题解做法巨难写,参考沈老师做法
对于一个 \(u\),考虑每个 \(v\) 出来的每个怪,第 \(i\) 个出发后的剩余时间是 \(h-(i-1)k\)
那么它能走到 \(u\) 的条件是 \(dis_{u, v}\leqslant h-(i-1)k\)
于是尝试对每个 \(i\in[1, \lfloor\frac{h}{k}\rfloor+1]\) 计算满足上式的 \(v\)\(\sum a_v\)
这个 \(\leqslant\) 的限制可以用一个经典的方法解决掉
那么现在可以做到 \(O(\frac{h}{k}n\log n)\)
再进一步的,发现所有取值是公差为 \(k\) 的等差数列,那么可以按对 \(k\) 取模后的余数分组计算
额外计算一下 取值 \(>\) 子树内最长链 的个数即可优化到 \(O(n\log n)\)

点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f3f3f3f3f
#define N 100010
#define fir first
#define sec second
#define pb push_back
#define ll long long
//#define int long long

char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
inline int read() {
	int ans=0, f=1; char c=getchar();
	while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
	while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
	return ans*f;
}

int n, k;
ll a[N], h;
int head[N], dep[N], fa[21][N], lg[N], ecnt;
struct edge{int to, next;}e[N<<1];
inline void add(int s, int t) {e[++ecnt]={t, head[s]}; head[s]=ecnt;}

namespace force{
	ll ans=INF;
	void dfs1(int tim, int pos, vector<int> cnt, ll sum);
	void dfs3(int u, int fa, vector<int>& cnt) {
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v==fa) continue;
			cnt[u]+=cnt[v]; cnt[v]=0;
			dfs3(v, u, cnt);
		}
	}
	void dfs2(int tim, int pos, vector<int> cnt, ll sum) {
		dfs3(pos, 0, cnt); sum+=cnt[pos]; cnt[pos]=0;
		dfs1(tim+1, pos, cnt, sum);
	}
	void dfs1(int tim, int pos, vector<int> cnt, ll sum) {
		if (tim>h) {ans=min(ans, sum); return ;}
		if (tim%k==1%k) {for (int i=1; i<=n; ++i) cnt[i]+=a[i];}
		// cout<<"cnt: "; for (int i=1; i<=n; ++i) cout<<cnt[i]<<' '; cout<<endl;
		dfs2(tim, pos, cnt, sum);
		for (int i=head[pos],v; ~i; i=e[i].next) dfs2(tim, e[i].to, cnt, sum);
	}
	void solve() {
		dfs1(1, 1, vector<int>(n+1), 0);
		cout<<ans<<endl;
	}
}

namespace task1{
	vector<int> lst[N];
	ll f[310][310], ans=INF;
	vector<ll> g[310][310], cnt;
	void dfs3(int u, int fa) {
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v==fa) continue;
			cnt[u]+=cnt[v]; cnt[v]=0;
			dfs3(v, u);
		}
	}
	void solve() {
		memset(f, 0x3f, sizeof(f));
		for (int i=1; i<=n; ++i) f[0][i]=0, g[0][i]=vector<ll>(n+1);
		for (int i=1; i<=n; ++i) {
			lst[i].pb(i);
			for (int j=head[i],v; ~j; j=e[j].next) lst[i].pb(e[j].to);
		}
		for (int i=1; i<=h; ++i) {
			for (int j=1; j<=n; ++j) {
				for (auto v:lst[j]) {
					cnt=g[i-1][v];
					if (i%k==1%k) {for (int t=1; t<=n; ++t) cnt[t]+=a[t];}
					dfs3(j, 0);
					if (f[i][j]>f[i-1][v]+cnt[j]) {
						f[i][j]=f[i-1][v]+cnt[j];
						cnt[j]=0; g[i][j]=cnt;
					}
				}
			}
		}
		for (int i=1; i<=n; ++i) ans=min(ans, f[h][i]);
		cout<<ans<<endl;
	}
}

namespace task2{
	int sta[N], top;
	ll dis[N], ans, tem;
	void dfs(int u, int fa, int d) {
		dis[u]=d;
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v==fa) continue;
			dfs(v, u, d+1);
		}
	}
	ll solve(int u) {
		// cout<<"solve: "<<u<<endl;	
		// 其实只有 dis(1, u)<=h 的 u 能计算贡献
		dfs(u, 0, 0);
		// cout<<"dis: "; for (int i=1; i<=n; ++i) cout<<dis[i]<<' '; cout<<endl;
		ll t=(h-1)/k*k+1;
		// cout<<"t: "<<t<<endl;
		ll rest=h-t, ans=0;
		for (int i=1; i<=n; ++i) dis[i]=max(dis[i]-rest-1, 0ll);
		for (int i=1; i<=n; ++i) if (dis[i]) {
			ans+=1ll*((dis[i]-1)/k+1)*a[i];
		}
		// cout<<ans<<endl;
		return ans;
	}
	void solve() {
		random_device seed;
		mt19937 rand(seed());
		for (int i=1; i<=n; ++i) ans+=(1ll*(h-1)/k+1)*a[i];
		// for (int i=1; i<=n; ++i) tem=max(tem, solve(i));
		for (int i=1; i<=n; ++i) sta[++top]=i;
		shuffle(sta+1, sta+top+1, rand);
		for (int i=1; i<=top&&clock()<1900000; ++i) tem=max(tem, solve(sta[i]));
		cout<<ans-tem<<endl;
	}
}

namespace task{
	bool del[N];
	map<int, ll> sum;
	ll ans[N], sta[N];
	vector<int> buc[N];
	int siz[N], msiz[N], mson[N], top, rot;
	void dfs(int u, int pa) {
		for (int i=1; i<21; ++i)
			if (dep[u]>=1<<i) fa[i][u]=fa[i-1][fa[i-1][u]];
			else break;
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v==pa) continue;
			dep[v]=dep[u]+1;
			fa[0][v]=u;
			dfs(v, u);
		}
	}
	int lca(int a, int b) {
		if (dep[a]<dep[b]) swap(a, b);
		while (dep[a]>dep[b]) a=fa[lg[dep[a]-dep[b]]-1][a];
		if (a==b) return a;
		for (int i=lg[dep[a]]-1; ~i; --i)
			if (fa[i][a]!=fa[i][b])
				a=fa[i][a], b=fa[i][b];
		return fa[0][a];
	}
	int dis(int a, int b) {return dep[a]+dep[b]-2*dep[lca(a, b)];}
	ll debug(int u) {
		ll ans=0;
		// cout<<"debug: "<<u<<endl;
		for (int i=1; i<=n; ++i) if (h>=dis(u, i)) ans+=a[i]*((h-max(dis(u, i), 1))/k+1);
		// cout<<ans<<endl;
		return ans;
	}
	void getrt(int u, int fa, int tot) {
		siz[u]=1; msiz[u]=0;
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v==fa||del[v]) continue;
			getrt(v, u, tot);
			siz[u]+=siz[v];
			msiz[u]=max(msiz[u], siz[v]);
		}
		msiz[u]=max(msiz[u], tot-siz[u]);
		if (msiz[u]<msiz[rot]) rot=u;
	}
	void getdis(int u, int fa, int dis) {
		while (top<dis) sta[++top]=0, buc[top].clear();
		sta[dis]+=a[u]; buc[dis].pb(u);
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v==fa||del[v]) continue;
			getdis(v, u, dis+1);
		}
	}
	void calc(int u, int dlt, int op) {
		// cout<<"calc: "<<u<<' '<<dlt<<' '<<op<<endl;
		sta[top=0]=0; buc[0].clear();
		getdis(u, 0, dlt);
		// cout<<"sta: "; for (int i=0; i<=top; ++i) cout<<sta[i]<<' '; cout<<endl;
		for (int i=1; i<=top; ++i) sta[i]+=sta[i-1];
		// cout<<"pre: "; for (int i=0; i<=top; ++i) cout<<sta[i]<<' '; cout<<endl;
		// cout<<"buc: "; for (auto it:buc[0]) cout<<it<<' '; cout<<endl;
		// for (int i=0; i<k; ++i) sum[i]=0;
		sum.clear();
		for (int i=0; i<=top; ++i) sum[i%k]+=sta[i];
		// cout<<"sum: "; for (int i=0; i<k; ++i) cout<<sum[i]<<' '; cout<<endl;
		for (int i=0; i<=top; ++i) {
			// ll val=0;
			// for (int j=0; j<=(h-i)/k; ++j) val+=sta[min(h-i-j*k, (ll)top)];
			ll val=sum[(h-i)%k];
			if (h-i>top) val+=((h-i-top-1)/k+1)*sta[top];
			if (h-i<=top) sum[(h-i)%k]-=sta[h-i];
			// cout<<"val: "<<val<<endl;
			for (auto it:buc[i]) ans[it]+=val*op;
		}
	}
	void solve(int u) {
		// cout<<"solve: "<<u<<endl;
		del[u]=1;
		calc(u, 0, 1);
		for (int i=head[u],v; ~i; i=e[i].next) if (!del[e[i].to]) calc(e[i].to, 1, -1);
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (del[v]) continue;
			rot=0;
			getrt(v, 0, siz[v]);
			solve(rot);
		}
	}
	void solve() {
		for (int i=1; i<=n; ++i) lg[i]=lg[i-1]+(1<<lg[i-1]==i);
		dep[1]=1;
		dfs(1, 0);
		ll tem=INF;
		// for (int i=1; i<=n; ++i) if (dis(1, i)<=h) tem=min(tem, debug(i));
		msiz[0]=n;
		getrt(1, 0, n);
		solve(rot);
		for (int i=1; i<=n; ++i) ans[i]=ans[i]+a[i]*(-h/k+(h-1)/k);
		// cout<<"ans: "; for (int i=1; i<=n; ++i) cout<<ans[i]<<' '; cout<<endl;
		// cout<<"tem: "; for (int i=1; i<=n; ++i) cout<<debug(i)<<' '; cout<<endl;
		for (int i=1; i<=n; ++i) if (dis(1, i)<=h) tem=min(tem, ans[i]);
		cout<<tem<<endl;
	}
}

signed main()
{
	freopen("b.in", "r", stdin);
	freopen("b.out", "w", stdout);

	n=read(); k=read(); h=read();
	memset(head, -1, sizeof(head));
	for (int i=1; i<=n; ++i) a[i]=read();
	for (int i=1,u,v; i<n; ++i) {
		u=read(); v=read();
		add(u, v); add(v, u);
	}
	// if (n<=4&&h<=4) force::solve();
	// else if (n<=300&&h<=300) task1::solve();
	// else task2::solve();
	task::solve();

	return 0;
}
posted @ 2022-05-07 21:33  Administrator-09  阅读(2)  评论(0编辑  收藏  举报