题解 签到题

传送门

一道完整的签到题应该包括一个小时的读错题和一个小时的写代码(

建出笛卡尔树(建普通树也行)
对每个点维护：它自身以及它到根的路径上所有右父亲(即该点位于其左子树中的祖先)的权值之和
查询时，最优的汇合点就是 lca 的第一个右父亲
然后对这些到根的和做树上差分相减即可
差分相减时需要用到某个点的精确权值
所以还要单独维护每个点的当前权值
修改时要对一条链做加法，所以还要树链剖分
复杂度为 \(O(n\log^2 n)\)，但卡常之后实际跑得比 \(O(n\log n)\) 的做法还快

点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 200010
#define fir first
#define sec second
#define ll long long
//#define int long long

// Fast input: stdin is pulled into buf in 2 MiB chunks and getchar() is
// redefined to serve bytes from that buffer. Each byte is returned as an
// unsigned char value (0..255), so isdigit() never receives a negative
// non-EOF argument (undefined behaviour); EOF (-1) marks end of input.
char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:(unsigned char)*p1++)
// Reads one decimal integer from stdin, skipping every non-digit character
// before it; each '-' seen while skipping flips the sign.
// Returns 0 if the input ends before any digit is found, instead of spinning
// forever on EOF.
inline int read() {
	int ans=0, f=1;
	int c=getchar();  // int, not char: must be able to hold EOF distinctly
	while (!isdigit(c)) {
		if (c==EOF) return 0;
		if (c=='-') f=-f;
		c=getchar();
	}
	while (isdigit(c)) {ans=ans*10+(c-'0'); c=getchar();}
	return ans*f;
}

// n: number of positions (arrays below are 1-indexed, 1..n);
// q: number of operations. Both are read first in main().
int n, q;
// w[i]: weight of position i, read after a[] in main(). The brute-force
// solver updates it in place; task1 keeps its live copy in Fenwick trees.
ll w[N];
// a[i]: value of position i; it decides the greedy increasing chains and
// the Cartesian-tree shape (larger values sit higher).
int a[N];

namespace force{
	int cnt[N];
	set<int> getset(int x) {
		set<int> ans;
		int pre=x;
		ans.insert(x);
		for (int i=x+1; i<=n; ++i) if (a[i]>a[pre])
			ans.insert(i), pre=i;
		return ans;
	}
	void solve() {
		for (int i=1,x,y,v; i<=q; ++i) {
			if (read()&1) {
				x=read(); v=read();
				for (int y=1; y<=n; ++y) {
					set<int> tem=getset(y);
					while (tem.size()>2) tem.erase(tem.find(*tem.rbegin()));
					if (tem.find(x)!=tem.end()) w[y]+=v;
				}
			}
			else {
				x=read(); y=read();
				for (int i=1; i<=n; ++i) cnt[i]=0;
				for (int i=max(x, y)+1; i<=n; ++i) ++cnt[i];
				set<int> tem;
				tem=getset(x); for (auto it:tem) ++cnt[it];
				tem=getset(y); for (auto it:tem) ++cnt[it];
				for (int z=1; z<=n; ++z) if (cnt[z]==3) {
					tem=getset(x);
					for (auto it:getset(y)) tem.insert(it);
					while (tem.size() && *tem.rbegin()>z) tem.erase(tem.find(*tem.rbegin()));
					ll ans=0;
					for (auto it:tem) ans+=w[it];
					printf("%lld\n", ans);
					goto jump;
				}
				printf("?\n");
				jump: ;
			}
		}
	}
}

// Main solution: Cartesian tree over (index, a[index]); a Fenwick tree over the
// DFS order stores, per node, the weight sum along its chain of right fathers,
// and a heavy-light decomposition (namespace trdiv) keeps each node's exact
// current weight. Overall O((n + q) log^2 n).
namespace task1{
	ll bit[N];
	pair<int, int> sta[N];
	int ls[N], rs[N], id[N], siz[N], lg[N], dep[N], fa[N], top[N], btm[N], mson[N], stop, rot, tot;
	// Point add on the Fenwick tree (not used by solve()).
	inline void add(int i, ll dat) {for (; i<=n; i+=i&-i) bit[i]+=dat;}
	// Adds dat to every position in [l, r]; a position's value is read back with
	// query(). Equivalent to +dat at l and -dat at r+1, but both update walks stop
	// where they meet, because above that node the two contributions cancel.
	// The -dat walk is also cut at n: query() never reads nodes above n, and the
	// walk from r+1 can otherwise climb past the end of bit[] (e.g. with
	// n = 200000 and r = n it visits 200016 > N). dat is ll so that ll weights
	// are not narrowed to int.
	inline void add(int l, int r, ll dat) {
		++r;
		while (l<r) bit[l]+=dat, l+=l&-l;
		while (r<l && r<=n) bit[r]-=dat, r+=r&-r;
	}
	// Value at DFS position i (prefix sum of the difference array).
	inline ll query(int i) {ll ans=0; for (; i; i-=i&-i) ans+=bit[i]; return ans;}
	// DFS over the Cartesian tree computing DFS order (id), subtree size (siz),
	// depth/parent, btm[u] (last node of u's right spine), mson[u] (child with the
	// larger subtree) and top[u]: the nearest ancestor having u in its left
	// subtree ("first right father"), 0 if none.
	// NOTE(review): recursion depth equals tree depth, which is n for monotone a[];
	// this relies on a large stack.
	void dfs1(int u, int t) {
		top[u]=t; id[u]=++tot; siz[u]=1; btm[u]=u;
		if (ls[u]) dep[ls[u]]=dep[u]+1, fa[ls[u]]=u, dfs1(ls[u], u), siz[u]+=siz[ls[u]];
		if (rs[u]) dep[rs[u]]=dep[u]+1, fa[rs[u]]=u, dfs1(rs[u], t), siz[u]+=siz[rs[u]], btm[u]=btm[rs[u]];
		mson[u]=siz[ls[u]]>siz[rs[u]]?ls[u]:rs[u];
	}
	// Heavy-light decomposition of the same tree, holding each node's exact
	// weight: range add along a tree path, point query per node.
	namespace trdiv{
		ll bit[N];
		int top[N], id[N], tot;
		// Point add on this Fenwick tree (not used).
		inline void add(int i, ll dat) {for (; i<=n; i+=i&-i) bit[i]+=dat;}
		// Range add of dat on positions [l, r]; same merged-walk scheme as
		// task1::add, with the -dat walk bounded by n to stay inside bit[].
		inline void add(int l, int r, ll dat) {
			++r;
			while (l<r) bit[l]+=dat, l+=l&-l;
			while (r<l && r<=n) bit[r]-=dat, r+=r&-r;
		}
		inline ll query(int i) {ll ans=0; for (; i; i-=i&-i) ans+=bit[i]; return ans;}
		// Assigns HLD positions, visiting the heavy child (mson) first so every
		// heavy chain occupies a contiguous range; top[u] is the chain head.
		void dfs2(int u, int t) {
			top[u]=t; id[u]=++tot;
			if (!mson[u]) return ;
			dfs2(mson[u], t);
			if (ls[u]&&ls[u]!=mson[u]) dfs2(ls[u], ls[u]);
			if (rs[u]&&rs[u]!=mson[u]) dfs2(rs[u], rs[u]);
		}
		// Loads the initial weight w[i] at each node's position.
		void init() {for (int i=1; i<=n; ++i) add(id[i], id[i], w[i]);}
		// Adds val to the weight of every node on the tree path x..y.
		void upd(int x, int y, ll val) {
			while (top[x]!=top[y]) {
				if (dep[top[x]]<dep[top[y]]) swap(x, y);
				add(id[top[x]], id[x], val);
				x=fa[top[x]];
			}
			if (dep[x]>dep[y]) swap(x, y);
			add(id[x], id[y], val);
		}
		// Lowest common ancestor of x and y by climbing heavy chains.
		int lca(int x, int y) {
			while (top[x]!=top[y]) {
				if (dep[top[x]]<dep[top[y]]) swap(x, y);
				x=fa[top[x]];
			}
			return dep[x]<dep[y]?x:y;
		}
		// Current exact weight of node x.
		ll qval(int x) {return query(id[x]);}
	}
	void solve() {
		// Build the Cartesian tree with a monotonic stack: strictly smaller entries
		// are popped and become the new node's left subtree; the new node becomes
		// the right child of the remaining top (so an equal value ends up as a
		// right descendant of the earlier one).
		for (int i=1; i<=n; ++i) {
			pair<int, int> now={i, a[i]};
			int k=stop;
			while (k && sta[k].sec<now.sec) --k;
			if (k) rs[sta[k].fir]=now.fir;
			if (k!=stop) ls[now.fir]=sta[k+1].fir;
			sta[stop=++k]=now;
		}
		rot=sta[1].fir;
		dep[rot]=1; dfs1(rot, 0);
		trdiv::dfs2(rot, rot); trdiv::init();
		// Node u's value = w[u] plus w of every ancestor having u in its left
		// subtree: add w[i] to i's subtree, then cancel it on i's right subtree.
		for (int i=1; i<=n; ++i) {
			add(id[i], id[i]+siz[i]-1, w[i]);
			if (rs[i]) add(id[rs[i]], id[rs[i]]+siz[rs[i]]-1, -w[i]);
		}
		for (int i=1,x,y,v; i<=q; ++i) {
			if (read()&1) {
				// Update "x v": the weight of x and of every node on the right spine
				// of ls[x] grows by v. For the chain sums this is +v on x's subtree
				// minus its right subtree, and +v on the whole subtree of ls[x].
				x=read(); v=read();
				add(id[x], id[x]+siz[x]-1, v);
				if (rs[x]) add(id[rs[x]], id[rs[x]]+siz[rs[x]]-1, -v);
				if (ls[x]) add(id[ls[x]], id[ls[x]]+siz[ls[x]]-1, v);
				if (ls[x]) trdiv::upd(x, btm[ls[x]], v);
				else trdiv::upd(x, x, v);
			}
			else {
				// Query "x y": combine the chain sums of x and y with that of
				// z = first right father of lca(x, y); "?" when lca has none.
				x=read(); y=read();
				if (x>y) swap(x, y);
				ll ans=query(id[x])+query(id[y]);
				int t=trdiv::lca(x, y), z=top[t];
				if (!z) {puts("?"); continue;}
				ans=ans-2*query(id[z])+trdiv::qval(z);
				if (t==y) ans-=trdiv::qval(t);
				printf("%lld\n", ans);
			}
		}
	}
}

signed main()
{
	// Redirect standard input/output to the files set.in / set.out.
	freopen("set.in", "r", stdin);
	freopen("set.out", "w", stdout);

	// Header: n positions, then q operations.
	n=read(), q=read();
	// Two 1-indexed arrays follow: all of a[] first, then all of w[].
	int k=0;
	while (++k<=n) a[k]=read();
	k=0;
	while (++k<=n) w[k]=read();

	// force::solve() is the brute-force checker; task1 is the real solution.
	task1::solve();
	return 0;
}
posted @ 2022-06-13 17:14  Administrator-09  阅读(3)  评论(0)  编辑  收藏  举报