点分治做题小记
[COCI2020-2021#4] Janjetina
题目大意,统计点对 \((x,y)\) 满足 \(x\) 到 \(y\) 的路径上权值的最大值 \(w\) - \(x\) 到 \(y\) 的距离 \(l\) \(\ge k\) 的点对个数。
对于这种统计点对的问题,优先想到点分治。
首先,对于每个点到重心的路径,答案可以直接统计。
然后,对于经过重心的路径,我们分开考虑。
对于一棵子树中的节点 \(x\),我们要统计他和另外一棵子树的节点 \(y\) 的路径是否为答案。
是不是满足 \(max(w_x,w_y) - dist_x -dist_y \ge k\),则 \((x,y)\) 可以为答案。
我们分情况讨论:
- \(w_x > w_y\),则 \(w_x - dist_x -dist_y \ge k\),所以 \(k+dist_y \le w_x - dist_x\),这个可以用树状数组直接修改,查询前缀。
- \(w_x \le w_y\),则 \(w_y - dist_x -dist_y \ge k\),所以 \(w_y - dist_y \ge dist_x + k\),这似乎不是很好统计,我们取个相反数,是不是得到:\(dist_y-w_y \le -(dist_x+k)\),这就可以用树状数组维护了。
因为题目 \((x,y)\) 与 \((y,x)\) 算两种,所以答案要 \(\times 2\)
实现提示:二分一个 \(pos\) 为第一个大于等于 \(x\) 的位置,然后 \(pos\) 之前就是情况1,\(pos\) 以后的就是情况2。
讲的不清楚,还是看代码吧
#include<bits/stdc++.h>
#define int long long
#define inf 0x7f7f7f7f
using namespace std;
inline int read() {
int s=1,a=0;
char c=getchar();
while(!isdigit(c)) {
if(c=='-') s=-s;
c=getchar();
}
while(isdigit(c)) {
a=a*10+c-'0';
c=getchar();
}
return s*a;
}
int n;
const int N=2e5+8,Y=5e5+8;
struct edge {
int nxt,to,dis;
} e[N];
int head[N],idx;
int lowbit(int x) {
return x&(-x);
}
struct ta {
int c[Y*3];
void insert(int x,int v) {
x+=Y;
for(int i=x; i<=Y<<1; i+=lowbit(i)) {
c[i]+=v;
}
}
int query(int x) {
int ret=0;
x+=Y;
for(int i=x; i; i-=lowbit(i)) {
ret+=c[i];
}
return ret;
}
} wgj;
struct bi {
int c[Y*3];
void insert(int x,int v) {
x=Y-x;
for(int i=x; i<=Y<<1; i+=lowbit(i)) {
c[i]+=v;
}
}
int query(int x) {
int ret=0;
x=Y-x;
for(int i=x; i; i-=lowbit(i)) {
ret+=c[i];
}
return ret;
}
} lsp;
void add(int u,int v,int d) {
e[++idx].nxt=head[u];
e[idx].to=v;
e[idx].dis=d;
head[u]=idx;
}
int siz[N],maxn[N],rt,vis[N],qk,ans,sum;
void calcsiz(int u,int fa) {
siz[u]=1;
maxn[u]=0;
for(int i=head[u]; i; i=e[i].nxt) {
int v=e[i].to;
if(vis[v]||v==fa) continue;
calcsiz(v,u);
siz[u]+=siz[v];
maxn[u]=max(maxn[u],siz[v]);
}
maxn[u]=max(maxn[u],sum-siz[u]);
if(maxn[u]<maxn[rt]) rt=u;
}
int dist[N],dep[N],cnt,tot1;
struct node {
int mx,de,cnt;
bool operator < (const node b) const {
return mx<b.mx;
}
} tt[N],all[N],q0[N];
void calcdis(int u,int fa) {
tt[++cnt]=(node) {
dist[u],dep[u],dist[u]-dep[u]
};
if(dist[u]-dep[u]>=qk) ans++;
for(int i=head[u]; i; i=e[i].nxt) {
int v=e[i].to;
if(vis[v]||v==fa) continue;
dist[v]=max(dist[u],e[i].dis);
dep[v]=dep[u]+1;
calcdis(v,u);
}
}
void dfz(int u,int fa) {
vis[u]=1;
tot1=0;
for(int i=head[u]; i; i=e[i].nxt) {
int v=e[i].to;
if(vis[v]||v==fa) continue;
dist[v]=e[i].dis;
dep[v]=1;
calcdis(v,u);
sort(tt+1,tt+cnt+1);
int last=0;
for(int j=1; j<=cnt; j++) {
int pos=lower_bound(all+1,all+tot1+1,tt[j])-all-1;
for(int l=last+1; l<=pos; l++) {
wgj.insert(all[l].de+qk,1);
}
ans+=wgj.query(tt[j].cnt);
last=pos;
}
for(int j=1; j<=last; j++) wgj.insert(all[j].de+qk,-1);
last=tot1+1;
for(int j=cnt; j>=1; j--) {
int pos=lower_bound(all+1,all+tot1+1,tt[j])-all;
for(int l=last-1; l>=pos; l--) {
lsp.insert(all[l].cnt,1);
}
ans+=lsp.query(qk+tt[j].de);
last=pos;
}
for(int j=tot1; j>=last; j--) lsp.insert(all[j].cnt,-1);
merge(all+1,all+tot1+1,tt+1,tt+cnt+1,q0+1);
tot1+=cnt;
for(int j=1; j<=tot1; j++) all[j]=q0[j];
cnt=0;
}
// cout<<u<<" "<<ans<<endl;
for(int i=head[u]; i; i=e[i].nxt) {
int v=e[i].to;
if(vis[v]||v==fa) continue;
rt=0;
maxn[rt]=inf;
sum=siz[v];
calcsiz(v,u);
calcsiz(rt,-1);
dfz(rt,u);
}
}
signed main() {
n=read();
qk=read();
for(int i=1; i<n; i++) {
int u=read(),v=read(),d=read();
add(u,v,d);
add(v,u,d);
}
rt=0;
sum=n;
maxn[rt]=inf;
calcsiz(1,-1);
calcsiz(rt,-1);
// cout<<rt<<endl;
dfz(rt,-1);
printf("%lld\n",ans<<1);
return 0;
}
CF1174F Ehab and the Big Finale
这是一道交互题。
你有一棵树,你要找到一个点。
你可以如下询问:
- \(d \ u \ (1 \le u \le n)\),此询问可以告诉你 \(x\) 到 \(u\) 的距离。
- \(s \ u \ (1 \le u \le n)\),此询问可以告诉你 \(u \to x\) 的路径上的第二个点,注意,你询问的 \(u\) 必须是 \(x\) 的祖先,否则会 Wrong Answer。
你要在 36 次询问以内问出答案。
我们考虑点分治。
先询问 \(x\) 与 \(1\) 的距离。
对于一个重心 \(rt\),我们先询问它与 \(x\) 的距离。
那么我们可以分为两种情况:
- 第一种是 \(dep_x > dep_u\),我们再用一次 \(s\),找到 \(x\) 所在的子树,向子树分治即可。
- 另一种是 \(dep_x < dep_v\),我们就直接找到 \(u\) 的父亲,向上递归即可。
至于怎么判两种情况呢?
如果 \(dist_{u,v} + dep[u] = dist_{1,x}\) 的话,\(x\) 一定在 \(u\) 的下面。
否则就在 \(u\) 的上面。
这里有种特殊情况,就是当 \(dist_u = 0\) 的时候,你应该果断地输出答案。
还有就是 \(dist_u = 1\) 的时候,这时你也应该果断地使用 \(s\)(是第二种情况的时候),输出答案。
#include<bits/stdc++.h>
using namespace std;
inline int read() {
int s=1,a=0;
char c=getchar();
while(!isdigit(c)) {
if(c=='-') s=-s;
c=getchar();
}
while(isdigit(c)) {
a=a*10+c-'0';
c=getchar();
}
return s*a;
}
const int N=3e5+8;
int n;
vector <int> G[N];
int siz[N],maxn[N],vis[N],rt,d1,dep[N],sum;
void dfs(int u,int fa) {
for(auto v:G[u]) {
if(v==fa) continue;
dep[v]=dep[u]+1;
dfs(v,u);
}
}
void calcsiz(int u,int fa) {
siz[u]=1;
maxn[u]=0;
for(auto v:G[u]) {
if(vis[v]||v==fa) continue;
calcsiz(v,u);
siz[u]+=siz[v];
maxn[u]=max(maxn[u],siz[v]);
}
maxn[u]=max(maxn[u],sum-siz[u]);
if(maxn[u]<maxn[rt]) rt=u;
}
int dist[N];
void myloglast(int u,int fa) {
vis[u]=1;
printf("d %d\n",u);
fflush(stdout);
int di;
scanf("%d",&di);
if(di==0) {
printf("! %d\n",u);
// fflush(stdout);
exit(0);
}
if(dep[u]+di==d1&&dep[u]<d1) {
printf("s %d\n",u);
fflush(stdout);
int ss;
scanf("%d",&ss);
if(di==1) {
printf("! %d\n",ss);
// fflush(stdout);
exit(0);
}
rt=0;
maxn[rt]=n+9;
sum=siz[ss];
calcsiz(ss,u);
calcsiz(rt,-1);
myloglast(rt,u);
return;
}
else {
for(auto v:G[u]) {
if(vis[v]) continue;
if(dep[v]<dep[u]) {
rt=0;
maxn[rt]=n+9;
sum=siz[v];
calcsiz(v,u);
calcsiz(rt,-1);
myloglast(rt,u);
return;
}
}
}
}
int main() {
n=read();
for(int i=1; i<n; i++) {
int u=read(),v=read();
G[u].push_back(v),G[v].push_back(u);
}
rt=0;
dfs(1,0);
maxn[rt]=n+9;
sum=n;
printf("d %d\n",1);
fflush(stdout);
scanf("%d",&d1);
if(d1==0) {
puts("! 1");
return 0;
}
calcsiz(1,-1);
calcsiz(rt,-1);
myloglast(rt,1);
return 0;
}