【csp模拟赛6】树上统计-启发式合并,线段树合并
30%:暴力
40%:枚举L,R从L~n枚举,R每增大一个,更新需要的边(bfs实现)60%:枚举每条边,
计算每条边的贡献另外20%的数据:枚举每条边,计算每条边的贡献100%:对于每一条边统计
有多少个区间跨过这条边即可统计这一问题的对偶问题,有多少个区间没跨过会更方便使用启发式合并+
并查集统计子树内的,使用启发式合并+set统计子树外的
代码:
#include<cstdio> #include<cstdlib> #include<set> #include<vector> #include<iostream> #define LL long long #define int long long using namespace std; const int MAXN = 1e5 + 10; inline int read() { char c = getchar(); int x = 0, f = 1; while(c < '0' || c > '9') {if(c == '-') f = -1; c = getchar();} while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar(); return x * f; } int N, siz[MAXN], son[MAXN], dsu[MAXN], Son, vis[MAXN], ds[MAXN]; vector<int> v[MAXN]; set<int> s; LL outt, inn, ans; void dfs(int x, int _fa) { siz[x] = 1; for(int i = 0; i < v[x].size(); i++) { int to = v[x][i]; if(to == _fa) continue; dfs(to, x); siz[x] += siz[to]; if(siz[to] > siz[son[x]]) son[x] = to; } } LL calc(LL x) { return x * (x - 1) / 2; } void Clear() { s.clear(); outt = calc(N); inn = 0; s.insert(0); s.insert(N + 1); } int find(int x) { return dsu[x] == x ? dsu[x] : dsu[x] = find(dsu[x]); } void solve(int x) { s.insert(x); set<int>::iterator s1, s2, it; s1 = s2 = it = s.find(x); s1--; s2++; outt -= calc((*s2) - (*s1) - 1); outt += calc((*s2) - (*it) - 1) + calc((*it) - (*s1) - 1); vis[x] = 1; if(vis[x - 1]) { int fx = find(x - 1), fy = find(x); inn += ds[fx] * ds[fy]; dsu[fx] = fy; ds[fy] += ds[fx]; } if(vis[x + 1]) { int fx = find(x + 1), fy = find(x); inn += ds[fx] * ds[fy]; dsu[fx] = fy; ds[fy] += ds[fx]; } } void Add(int x, int fa) { solve(x); for(int i = 0; i < v[x].size(); i++) { int to = v[x][i]; if(to == fa || to == Son) continue; Add(to, x); } } void Delet(int x, int fa) { vis[x] = 0; ds[x] = 1; dsu[x] = x; for(int i = 0; i < v[x].size(); i++) { int to = v[x][i]; if(to == fa) continue; Delet(to, x); } } void dfs2(int x, int fa, int opt) { for(int i = 0; i < v[x].size(); i++) { int to = v[x][i]; if(to == fa || (to == son[x])) continue; dfs2(to, x, 0); } if(son[x]) dfs2(son[x], x, 1); Son = son[x]; Add(x, fa); ans += calc(N) - inn - outt; if(opt == 0) Delet(x, fa), Clear(), Son = 0; } signed main() { #ifdef yilnr #else freopen("treecnt.in","r",stdin); freopen("treecnt.out","w",stdout); #endif N = read(); for(int i = 1; i <= N - 1; i++) { int x = read(), y = read(); v[x].push_back(y); v[y].push_back(x); } for(int i = 1; i <= N; i++) dsu[i] = i,ds[i] = 1; dfs(1,0); Clear(); dfs2(1,0,0); printf("%lld\n",ans); return 0; } /* 4 1 4 1 3 2 4 */