“学佳澳杯”湖北文理学院第一届程序设计竞赛 H 戏团演出 OR 洛谷P4556 [Vani有约会]雨天的尾巴 (数据可加强)
做法一:树上差分+线段树合并
大致思路
如果不会线段树合并可以先看这篇博客
不会树上差分可以看这篇博客
因为线段树合并可以通过dfs回溯合并掉所有的子树的信息,而树上差分正好需要合并子树信息,而这题只需要把所有操作做完之后统计一下每个点的答案,很明显可以离线,所以考虑线段树合并+树上差分。虽然我也是看题解才知道,线段树合并也是现学了好久。
对于每次操作的[a, b, c],我们求出a,b的LCA(下文称为pa),然后在点a,b用权值线段树动态开点把c插进去(对于湖北文理学院的H题因为权值到了1e9所以需要先进行离散化,洛谷的P4556只需要把读入读进来存一下最大值(sz)就行了),对于pa和pa的父节点依次开权值线段树插入-c(如果pa的父节点是0号点的话就不需要往pa的父节点插了),我们发现对于每个操作我们最多在四个点用动态开点权值线段树,然后每个点最多需要开log(值域)个点,所以空间要求为N * 4 * log(1e5),开到N * 70就差不多了。线段树合并的时间复杂度是插入的点数(K)×log(K),也就是4e5 * log(4e5) 约等于 7.4e6 pushup之类的啥的常数不算了,这个 时间复杂度加一些小常数也是能过掉的。把所有操作读完之后进行一次dfs回溯合并子树上的权值线段树即可,边合并边记录答案。
偶然间在洛谷提交的时候发现了一个点可以卡掉自己代码但是提交了还能AC
反正我是把自己卡掉自己然后提交了还能AC
卡掉自己代码的输入如下
2 1
2 1
2 1
正确输出应为
0
1
我的输出为
1
1
那么为什么会导致这种错误呢
在使用线段树合并的时候如果只对两个儿子进行处理的话,假设我们只有一个点那么就会没有左右儿子,这么一来只从左右儿子合并的话就会出现错误。
这个错误是在交牛客这题时发现的错误,我把这题的ac代码粘到洛谷居然wa掉了,然后我调了一个多小时终于发现了问题所在。
先发在洛谷上AC的代码再发在牛客上AC的代码吧,都是修改完之后的代码
如下代码去掉第161行在洛谷上还是可以AC的,虽然我已经发私信给这个题的出题人了,但是不知道他看不看得见。
#include <iostream>
#include <cstring>
#include <iomanip>
#include <algorithm>
#include <stack>
#include <queue>
#include <numeric>
#include <cassert>
#include <bitset>
#include <cstdio>
#include <vector>
#include <unordered_set>
#include <cmath>
#include <map>
#include <unordered_map>
#include <set>
#include <deque>
#include <tuple>
#include <array>
#define all(a) a.begin(), a.end()
#define cnt0(x) __builtin_ctz(x)
#define endl '\n'
#define ll long long
#define ull unsigned long long
#define cntone(x) __builtin_popcount(x)
#define db double
#define fs first
#define ls(x) (tr[x].l)
#define rs(x) (tr[x].r)
#define cnt(x) (tr[x].cnt)
#define mx(x) (tr[x].v)
#define se second
#define AC main(void)
#define HYS std::ios::sync_with_stdio(false);std::cin.tie(0);std::cout.tie(0);
typedef std::pair<int, int > PII;
typedef std::pair<int, std::pair<int, int>> PIII;
typedef std::pair<ll, ll> Pll;
typedef std::pair<double, double> PDD;
using ld = double long;
const long double eps = 1e-9;
const int INF = 0x3f3f3f3f;
const int N = 1e5 + 10, M = 2e5 + 10;
int n , m, p;
int d1[] = {0, 0, 1, -1};
int d2[] = {1, -1, 0, 0};
int h[N], ne[N << 1], w[N << 1], e[N << 1], root[N];
int tot, idx, sz;
int rt = 1, depth[N], fa[N][19], bw = 18, q[N], cnt[N];
int ans[N];
struct node{
int l, r;
int v;
int cnt;
}tr[N * 70];
void bfs(){//初始化
memset(depth, 0x3f, sizeof depth);
int hh = 0, tt = 0;
q[0] = rt;//根节点入队
depth[rt] = 1, depth[0] = 0;
while(hh <= tt){
int f = q[hh ++];//取出队头
for(int i = h[f]; ~i; i = ne[i]){//遍历邻边
int j = e[i];
if(depth[j] > depth[f] + 1){
depth[j] = depth[f] + 1;
q[++ tt] = j;
fa[j][0] = f;
for(int k = 1; k <= bw; k ++)
fa[j][k] = fa[fa[j][k-1]][k-1];
}
}
}
}
int LCA(int a, int b){
if(depth[a] < depth[b]) std::swap(a, b);//保持a的是深层
//让a走到和b一样深的地方去
for(int i = bw; i >= 0; i --)
if(depth[fa[a][i]] >= depth[b])
a = fa[a][i];
if(a == b) return a;//说明b是a和b的公共祖先(不一定是最近公共祖先)
//让a和b一起往上跳到最近公共祖先的先前一个点
for(int i = bw; i >= 0; i -- ){
if(fa[a][i] != fa[b][i]){
a = fa[a][i];
b = fa[b][i];
}
}
return fa[a][0];
}
void pushup(int u){
if(tr[ls(u)].cnt >= tr[rs(u)].cnt){
if(!tr[ls(u)].cnt){
tr[u].cnt = tr[u].v = 0;
return ;
}
tr[u].cnt = tr[ls(u)].cnt;
tr[u].v = tr[ls(u)].v;
}
else{
tr[u].cnt = tr[rs(u)].cnt;
tr[u].v = tr[rs(u)].v;
}
}
inline void add(int a, int b){
ne[idx] = h[a], e[idx] = b, h[a] = idx ++;
}
//线段树合并操作 当没有一个儿子的时候会O1否则logn
inline int merge(int p, int q, int L, int R){
if(!p) return q;
if(!q) return p;
if(L == R){
tr[p].cnt += tr[q].cnt;
tr[p].v = L;
return p;
}
int mid = L + R >> 1;
tr[p].l = merge(tr[p].l, tr[q].l, L, mid);
tr[p].r = merge(tr[p].r, tr[q].r, mid + 1, R);
pushup(p);
return p;
}
inline void update(int u, int L, int R, int x, int sum){
if(L == R){
tr[u].cnt += sum;
tr[u].v = L;
return ;
}
int mid = L + R >> 1;
if(x <= mid){
if(!tr[u].l) tr[u].l = ++ tot;
update(tr[u].l, L, mid, x, sum);
}
else{
if(!tr[u].r) tr[u].r = ++ tot;
update(tr[u].r, mid + 1, R ,x, sum);
}
pushup(u);
}
inline void dfs(int u, int fa){
for(int i = h[u]; ~i; i = ne[i]){
int j = e[i];
if(fa == j) continue;
dfs(j, u);
merge(root[u], root[j], 1, sz);
}
if(tr[root[u]].cnt)
ans[u] = tr[root[u]].v;
}
inline void solve(){
std::cin >> n >> m;
memset(h, -1, sizeof h);
for(int i = 1; i < n; i ++){
int a, b;
std::cin >> a >> b;
add(a, b);
add(b, a);
}
bfs();
std::vector<std::array<int, 3>> q(m);
for(int i = 0; i < m; i ++){
int a, b, c;
std::cin >> a >> b >> c;
q[i] = {a, b, c};
sz = std::max(c, sz);
}
for(int i = 1; i <= n; i ++) root[i] = ++ tot;
for(int i = 0; i < m; i ++){
auto &[a, b, v] = q[i];
int pa = LCA(a, b);
//如果a和b是同一个点就只对自己本身加一次,和自己的父亲减一次就行了
if(a == b){
update(root[a], 1, sz, v, 1);
if(fa[a][0])//如果父亲是0号点就不用进行操作了
update(root[fa[a][0]], 1, sz, v, -1);
continue;
}
update(root[a], 1, sz, v, 1);
update(root[b], 1, sz, v, 1);
update(root[pa], 1, sz, v, -1);
if(fa[pa][0]) update(root[fa[pa][0]], 1, sz, v, -1);
}
dfs(1, -1);
for(int i = 1; i <= n; i ++) std::cout << ans[i]<< '\n';
}
signed AC{
HYS
int _ = 1;
//std::cin >> _;
while(_ --)
solve();
return 0;
}
这个是在牛客上AC的代码
也不知道牛客高校比赛要找谁反馈数据
#include <iostream>
#include <cstring>
#include <iomanip>
#include <algorithm>
#include <stack>
#include <queue>
#include <numeric>
#include <cassert>
#include <bitset>
#include <cstdio>
#include <vector>
#include <unordered_set>
#include <cmath>
#include <map>
#include <unordered_map>
#include <set>
#include <deque>
#include <tuple>
#include <array>
#define all(a) a.begin(), a.end()
#define cnt0(x) __builtin_ctz(x)
#define endl '\n'
#define ll long long
#define ull unsigned long long
#define cntone(x) __builtin_popcount(x)
#define db double
#define fs first
#define ls(x) (tr[x].l)
#define rs(x) (tr[x].r)
#define cnt(x) (tr[x].cnt)
#define mx(x) (tr[x].v)
#define se second
#define AC main(void)
#define HYS std::ios::sync_with_stdio(false);std::cin.tie(0);std::cout.tie(0);
typedef std::pair<int, int > PII;
typedef std::pair<int, std::pair<int, int>> PIII;
typedef std::pair<ll, ll> Pll;
typedef std::pair<double, double> PDD;
using ld = double long;
const long double eps = 1e-9;
const int INF = 0x3f3f3f3f;
const int N = 2e5 + 10, M = 4e5 + 10;
int n , m, p;
int d1[] = {0, 0, 1, -1};
int d2[] = {1, -1, 0, 0};
int h[N], ne[N << 1], w[N << 1], e[N << 1], root[N];
int tot, idx, sz;
int rt = 1, depth[N], fa[N][19], bw = 18, q[N], cnt[N];
ll ans[N];
struct node{
int l, r;
int v;
int cnt;
}tr[N * 70];
void bfs(){//初始化
memset(depth, 0x3f, sizeof depth);
int hh = 0, tt = 0;
q[0] = rt;//根节点入队
depth[rt] = 1, depth[0] = 0;
while(hh <= tt){
int f = q[hh ++];//取出队头
for(int i = h[f]; ~i; i = ne[i]){//遍历邻边
int j = e[i];
if(depth[j] > depth[f] + 1){
depth[j] = depth[f] + 1;
q[++ tt] = j;
fa[j][0] = f;
for(int k = 1; k <= bw; k ++)
fa[j][k] = fa[fa[j][k-1]][k-1];
}
}
}
}
int LCA(int a, int b){
if(depth[a] < depth[b]) std::swap(a, b);//保持a的是深层
//让a走到和b一样深的地方去
for(int i = bw; i >= 0; i --)
if(depth[fa[a][i]] >= depth[b])
a = fa[a][i];
if(a == b) return a;//说明b是a和b的公共祖先(不一定是最近公共祖先)
//让a和b一起往上跳到最近公共祖先的先前一个点
for(int i = bw; i >= 0; i -- ){
if(fa[a][i] != fa[b][i]){
a = fa[a][i];
b = fa[b][i];
}
}
return fa[a][0];
}
void pushup(int u){
if(tr[ls(u)].cnt >= tr[rs(u)].cnt){
if(!tr[ls(u)].cnt){
tr[u].cnt = tr[u].v = 0;
return ;
}
tr[u].cnt = tr[ls(u)].cnt;
tr[u].v = tr[ls(u)].v;
}
else{
tr[u].cnt = tr[rs(u)].cnt;
tr[u].v = tr[rs(u)].v;
}
}
inline void add(int a, int b){
ne[idx] = h[a], e[idx] = b, h[a] = idx ++;
}
//线段树合并操作 当没有一个儿子的时候会O1否则logn
inline int merge(int p, int q, int L, int R){
if(!p) return q;
if(!q) return p;
if(L == R){
tr[p].cnt += tr[q].cnt;
tr[p].v = L;
return p;
}
int mid = L + R >> 1;
tr[p].l = merge(tr[p].l, tr[q].l, L, mid);
tr[p].r = merge(tr[p].r, tr[q].r, mid + 1, R);
pushup(p);
return p;
}
inline void update(int u, int L, int R, int x, int sum){
if(L == R){
tr[u].cnt += sum;
tr[u].v = L;
return ;
}
int mid = L + R >> 1;
if(x <= mid){
if(!tr[u].l) tr[u].l = ++ tot;
update(tr[u].l, L, mid, x, sum);
}
else{
if(!tr[u].r) tr[u].r = ++ tot;
update(tr[u].r, mid + 1, R ,x, sum);
}
pushup(u);
}
inline void dfs(int u, int fa){
for(int i = h[u]; ~i; i = ne[i]){
int j = e[i];
if(fa == j) continue;
dfs(j, u);
merge(root[u], root[j], 1, sz);
}
if(tr[root[u]].cnt)
ans[u] = tr[root[u]].v;
}
inline void solve(){
std::cin >> n >> m;
memset(h, -1, sizeof h);
for(int i = 1; i < n; i ++){
int a, b;
std::cin >> a >> b;
add(a, b);
add(b, a);
}
bfs();
std::vector<std::array<int, 3>> q(m);
std::vector<int> nums;
for(int i = 0; i < m; i ++){
int a, b, c;
std::cin >> a >> b >> c;
q[i] = {a, b, c};
nums.push_back(c);
}
std::sort(all(nums));
nums.erase(std::unique(all(nums)), nums.end());
sz = nums.size();
for(int i = 1; i <= n; i ++) root[i] = ++ tot;
for(int i = 0; i < m; i ++){
auto &[a, b, v] = q[i];
int pa = LCA(a, b);
v = std::lower_bound(all(nums), v) - nums.begin() + 1;
update(root[a], 1, sz, v, 1);
update(root[b], 1, sz, v, 1);
update(root[pa], 1, sz, v, -1);
if(fa[pa][0]) update(root[fa[pa][0]], 1, sz, v, -1);
}
dfs(1, -1);
for(int i = 1; i <= n; i ++){
if(!ans[i]) std::cout << 0 << '\n';
else std::cout << nums[ans[i] - 1]<< '\n';
}
}
signed AC{
HYS
int _ = 1;
//std::cin >> _;
while(_ --)
solve();
return 0;
}
做法二:树链剖分+树上差分+动态开点权值线段树
大致思路:权值线段树值域为救济粮的最大编号。树剖分成logn个连续的区间,并且记录每个dfs序对应的原本的编号(若dfs序时id[u] = ++ cnt,那么rid[cnt] = u,rid就是反向标记的数组,后面需要用到),然后对在爬树剖的区间的时候给每个区间dfs序最小的地方和dfs序最大的地方加一动态开点权值线段树维护差分的值和数量,最后for遍历dfs序为1的位置到dfs序为n的位置把所有差分的值和数量取出来用一个权值线段树维护最大数量和对应的编号,每做完一个点就可以通过rid把答案存进对应的编号里面,最后for 1 -> n遍历输出答案。
因为是做完全部操作统一输出答案,所以可以离线做,这个做法的写法与上面自底向上进行线段树合并不同,因为dfs序是父亲的编号比儿子编号小,所以树链剖分需要在一段连续的区间里给dfs序最小和dfs序最大的值+1进行差分。树剖爬区间的时候给区间dfs序最小的点(父节点)和dfs序最大的点+1(差分思想就不赘述了)动态开点权值线段树进行单点修改,每个单点维护数量即可(不用维护对应的救济粮编号,权值线段树递归的时候L==R的时候L和R就是救济粮编号)。在取每个dfs序对应的所有维护差分的值和数量的时候我们可以自己开一个stk数组维护救济粮编号,sum数组维护每个救济粮的数量,这么一来就可以避免每次用vector进行emplace_back和clear操作,把每个dfs序对应的所有值从权值树取出来之后加入维护信息的权值线段树里面,每做完一个点就可以通过rid把答案存进对应的编号里面,最后for 1 -> n遍历输出答案。
因为直接粘的以前的树剖板子有些出现过的数组可能并没有用处
因为每个dfs序都对应一个权值线段树的根,所以在for读入的时候顺便给每个根开一个点,最后用root[n + 2]表示最后维护信息的权值线段树的根。
#include <iostream>
#include <cstring>
#include <iomanip>
#include <algorithm>
#include <stack>
#include <queue>
#include <numeric>
#include <cassert>
#include <bitset>
#include <cstdio>
#include <vector>
#include <unordered_set>
#include <cmath>
#include <map>
#include <unordered_map>
#include <set>
#include <deque>
#include <tuple>
#include <array>
#define all(a) a.begin(), a.end()
#define cnt0(x) __builtin_ctz(x)
#define endl '\n'
#define ll long long
#define ull unsigned long long
#define cntone(x) __builtin_popcount(x)
#define db double
#define fs first
#define ls(x) (tr[x].l)
#define rs(x) (tr[x].r)
#define cnt(x) (tr[x].cnt)
#define v(x) (tr[x].v)
#define se second
#define AC main(void)
#define HYS std::ios::sync_with_stdio(false);std::cin.tie(0);std::cout.tie(0);
typedef std::pair<int, int > PII;
typedef std::pair<int, std::pair<int, int>> PIII;
typedef std::pair<ll, ll> Pll;
typedef std::pair<double, double> PDD;
using ld = double long;
const long double eps = 1e-9;
const int INF = 0x3f3f3f3f;
const int N = 1e5 + 10, M = 2e5 + 10;
int n , m, p;
int d1[] = {0, 0, 1, -1};
int d2[] = {1, -1, 0, 0};
int w[N], h[N], e[M], ne[M], idx;
int id[N], nw[N], cnt;
int dep[N], sz[N], top[N], fa[N], son[N], mx, rid[N];
int root[N], tot;
int stk[N], scnt, sum[N];
int ans[N];
struct Node{
int l, r;
int v;
int cnt;
}tr[N * 80];
inline void init(){
memset(h, -1, sizeof h);
}
inline void add(int a, int b){
e[idx] = b, ne[idx] = h[a], h[a] = idx ++ ;
}
//统计子树大小 并且 找出重儿子
inline void dfs1(int u, int father, int depth){
dep[u] = depth, fa[u] = father, sz[u] = 1;
for (int i = h[u]; ~i; i = ne[i]){
int j = e[i];
if (j == father) continue;
dfs1(j, u, depth + 1);
sz[u] += sz[j];
if (sz[son[u]] < sz[j]) son[u] = j;
}
}
//dfs序 nw标记dfs为cnt的时候对应的树上点的权值 top标记父亲是谁
inline void dfs2(int u, int t){
id[u] = ++ cnt, nw[cnt] = w[u], top[u] = t, rid[cnt] = u;
if (!son[u]) return;
dfs2(son[u], t);
for (int i = h[u]; ~i; i = ne[i]){
int j = e[i];
if (j == fa[u] || j == son[u]) continue;
dfs2(j, j);
}
}
void update(int u, int L, int R, int x, int k){
if(L == R){
cnt(u) += k;
return ;
}
int mid = L + R >> 1;
if(x <= mid){
if(!ls(u)) ls(u) = ++ tot;
update(ls(u), L, mid, x, k);
}else{
if(!rs(u)) rs(u) = ++ tot;
update(rs(u), mid + 1, R, x, k);
}
}
//传进来想要爬的两个点
inline void update_path(int u, int v, int k){
while (top[u] != top[v]){
if (dep[top[u]] < dep[top[v]]) std::swap(u, v);
update(root[id[top[u]]], 1, mx, k, 1);
if(id[u] + 1 <= n)
update(root[id[u] + 1], 1, mx, k, -1);
u = fa[top[u]];
}
if (dep[u] < dep[v]) std::swap(u, v);
update(root[id[v]], 1, mx, k, 1);
if(id[u] + 1 <= n)
update(root[id[u] + 1], 1, mx , k, -1);
}
inline void Init(){
dfs1(1, -1, 1);
dfs2(1, 1);
}
inline void pushup(int u){
if(cnt(ls(u)) >= cnt(rs(u))){
if(!cnt(ls(u))) cnt(u) = v(u) = 0;
else{
cnt(u) = cnt(ls(u));
v(u) = v(ls(u));
}
}else{
cnt(u) = cnt(rs(u));
v(u) = v(rs(u));
}
}
inline void query(int u, int L, int R){
if(!u) return ;
if(L == R){
sum[scnt] = cnt(u);
stk[scnt ++] = L;
return ;
}
int mid = L + R >> 1;
query(ls(u), L, mid);
query(rs(u), mid + 1, R);
}
inline void modify(int u, int L, int R, int x, int k){
if(L == R){
cnt(u) += k;
v(u) = L;
return ;
}
int mid = L + R >> 1;
if(x <= mid){
if(!ls(u)) ls(u) = ++ tot;
modify(ls(u), L, mid, x, k);
}else{
if(!rs(u)) rs(u) = ++ tot;
modify(rs(u), mid + 1, R, x, k);
}
pushup(u);
}
inline void solve(){
std::cin >> n >> m;
memset(h, -1, sizeof h);
for(int i = 1; i < n; i ++){
root[i] = ++ tot;
int a, b;
std::cin >> a >> b;
add(a, b);
add(b, a);
}
root[n] = ++ tot;
Init();
std::vector<std::array<int, 3>> q(m);
for(int i = 0; i < m; i ++){
int a, b, c;
std::cin >> a >> b >> c;
q[i] = {a, b, c};
//找到值域范围
mx = std::max(c, mx);
}
for(int i = 0; i < m; i ++){
auto &[l, r, v] = q[i];
update_path(l, r, v);
}
auto &rt = root[n + 2];
rt = ++ tot;
for(int i = 1; i <= n; i ++){
scnt = 0;
query(root[i], 1, mx);
for(int j = 0; j < scnt; j ++){
modify(rt, 1, mx, stk[j], sum[j]);
}
if(cnt(rt)) ans[rid[i]] = v(rt);
}
for(int i = 1; i <= n; i ++)
std::cout << ans[i] << '\n';
}
signed AC{
HYS
int _ = 1;
//std::cin >> _;
while(_ --)
solve();
return 0;
}