题解 传染
首先一个暴力是枚举从每个点向它能感染的所有点连边
答案为缩点后零入度点的数量,正确性显然
考虑优化建边和 tarjan 的过程
考虑 tarjan 的搜索树
尝试构造一个序列 \(q\),使得对于一个 \(q_i\),在搜索树上它子树中的点 \(q_j\) 满足 \(j<i\)
构造方法是从每个未感染的点递归感染所有它能感染且未被感染的点,然后将这个点加入序列
直接做这个过程还是 \(O(n^2)\) 的
但是现在我们在找的实际上是离这个点距离 \(\leqslant r_i\) 的且未被感染的点
这个东西如果在点分树上找的话可以单调指针优化
将 \(i\) 到 \(j\) 在点分树上的路径拆成 \((i, lca), (lca, j)\) 两段
第一段可以枚举,第二段可以提前在每个lca预处理并排序
然后枚举第二段的过程就可以单调指针优化了
复杂度为 \(O(n\log^2 n)\),瓶颈在于给第二段排序
点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f3f3f3f3f
#define N 300010
#define ll long long
#define fir first
#define sec second
#define pb push_back
#define int long long
char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
inline int read() {
int ans=0, f=1; char c=getchar();
while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
return ans*f;
}
int n;
int head[N], ecnt;
bool del[N], vis[N];
vector<int> seq, seq2;
vector<pair<int, int>> cs[N], vs[N];
int r[N], siz[N], msiz[N], is[N], rot, ans;
struct edge{int to, next, val;}e[N<<1];
inline void add(int s, int t, int w) {e[++ecnt]={t, head[s], w}; head[s]=ecnt;}
void getrt(int u, int fa, int tot) {
siz[u]=1; msiz[u]=0;
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
if (del[v] || v==fa) continue;
getrt(v, u, tot);
siz[u]+=siz[v];
msiz[u]=max(msiz[u], siz[v]);
}
msiz[u]=max(msiz[u], tot-siz[u]);
if (msiz[u]<msiz[rot]) rot=u;
}
void dfs(int u, int fa, int dis) {
vs[u].pb({rot, dis});
cs[rot].pb({dis, u});
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
if (del[v]||v==fa) continue;
dfs(v, u, dis+e[i].val);
}
}
void calc(int u) {
dfs(u, 0, 0);
sort(cs[u].begin(), cs[u].end());
}
void solve(int u) {
// cout<<"solve: "<<u<<endl;
del[u]=1;
calc(u);
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
if (del[v]) continue;
rot=0;
getrt(v, u, siz[v]);
solve(rot);
}
}
void infect(int u) {
vis[u]=1;
for (auto it:vs[u]) {
while (is[it.fir]<cs[it.fir].size()) {
pair<int, int> t=cs[it.fir][is[it.fir]++];
if (vis[t.sec]) continue;
if (it.sec+t.fir>r[u]) {--is[it.fir]; break;}
infect(t.sec);
}
}
seq.pb(u);
}
signed main()
{
freopen("infect.in", "r", stdin);
freopen("infect.out", "w", stdout);
n=read();
memset(head, -1, sizeof(head));
for (int i=1; i<=n; ++i) r[i]=read();
for (int i=1,u,v,w; i<n; ++i) {
u=read(); v=read(); w=read();
add(u, v, w); add(v, u, w);
}
msiz[0]=n;
getrt(1, 0, n);
solve(rot);
#if 0
cout<<"---cs---"<<endl;
for (int i=1; i<=n; ++i) {cout<<i<<": "; for (auto it:cs[i]) cout<<"("<<it.fir<<','<<it.sec<<")"<<' '; cout<<endl;}
cout<<"---vs---"<<endl;
for (int i=1; i<=n; ++i) {cout<<i<<": "; for (auto it:vs[i]) cout<<"("<<it.fir<<','<<it.sec<<")"<<' '; cout<<endl;}
#endif
for (int i=1; i<=n; ++i) if (!vis[i]) infect(i);
// cout<<"seq: "; for (auto it:seq) cout<<it<<' '; cout<<endl;
seq2=seq; seq.clear();
memset(vis, 0, sizeof(vis));
memset(is, 0, sizeof(is));
reverse(seq2.begin(), seq2.end());
for (auto it:seq2) if (!vis[it]) infect(it), ++ans;
printf("%lld\n", ans);
return 0;
}