题解 签到题
一道完整的签到题应该包括一个小时的读错题和一个小时的写代码(
建出笛卡尔树(建普通树也行)
对每个点维护它到根的路径上所有右父亲(即该点位于其左子树中的那些祖先)的权值之和
查询时,两条右父亲链第一次汇合的点就是 lca 的第一个右父亲
然后差分相减即可
差分相减时需要用到单个点的精确权值
所以还要额外维护每个点自身的权值
修改时涉及链加操作,所以要用树链剖分维护
复杂度是 \(O(n\log^2 n)\) 的,通过卡常完成了对 \(O(n\log n)\) 做法的一个吊打
点击查看代码
#include <bits/stdc++.h>
using namespace std;
// Shared constants and aliases.
constexpr int INF = 0x3f3f3f3f;  // named "infinity" sentinel; not referenced in the code shown
constexpr int N = 200010;        // capacity of every 1-indexed per-position array
using ll = long long;
// Short names for std::pair members.  These must remain macros because they
// are used after '.', e.g. sta[k].fir / sta[k].sec.
#define fir first
#define sec second
//#define int long long
// 2 MiB input buffer; [p1, p2) is the unread part of it.
char buf[1<<21], *p1=buf, *p2=buf;
// Buffered replacement for getchar(): refills buf from stdin when it is empty,
// yields EOF when fread returns nothing, and otherwise returns the next byte
// as an unsigned value (0..255), so no data byte can be mistaken for EOF.
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:static_cast<unsigned char>(*p1++))
// Reads the next decimal integer through the buffered getchar() above.
// Every '-' seen before the first digit flips the sign; any other non-digit
// is skipped.  Returns 0 if the input ends before a digit is found (the loop
// used to spin forever at end of input).  The character after the number is
// consumed.
inline int read() {
    int ans=0, f=1;
    int c=getchar();  // int, so EOF stays distinguishable from data bytes
    while (c<'0' || c>'9') {
        if (c==EOF) return 0;
        if (c=='-') f=-f;
        c=getchar();
    }
    // Explicit range checks instead of isdigit(): no locale dependence and no
    // undefined behaviour for values outside unsigned char range.
    while (c>='0' && c<='9') {
        ans=ans*10+(c-'0');
        c=getchar();
    }
    return ans*f;
}
// Input sizes: n positions and q operations (read first in main).
int n;
int q;
// Per-position data, 1-indexed: weights w[1..n] and values a[1..n].
ll w[N];
int a[N];
namespace force{
int cnt[N];
set<int> getset(int x) {
set<int> ans;
int pre=x;
ans.insert(x);
for (int i=x+1; i<=n; ++i) if (a[i]>a[pre])
ans.insert(i), pre=i;
return ans;
}
void solve() {
for (int i=1,x,y,v; i<=q; ++i) {
if (read()&1) {
x=read(); v=read();
for (int y=1; y<=n; ++y) {
set<int> tem=getset(y);
while (tem.size()>2) tem.erase(tem.find(*tem.rbegin()));
if (tem.find(x)!=tem.end()) w[y]+=v;
}
}
else {
x=read(); y=read();
for (int i=1; i<=n; ++i) cnt[i]=0;
for (int i=max(x, y)+1; i<=n; ++i) ++cnt[i];
set<int> tem;
tem=getset(x); for (auto it:tem) ++cnt[it];
tem=getset(y); for (auto it:tem) ++cnt[it];
for (int z=1; z<=n; ++z) if (cnt[z]==3) {
tem=getset(x);
for (auto it:getset(y)) tem.insert(it);
while (tem.size() && *tem.rbegin()>z) tem.erase(tem.find(*tem.rbegin()));
ll ans=0;
for (auto it:tem) ans+=w[it];
printf("%lld\n", ans);
goto jump;
}
printf("?\n");
jump: ;
}
}
}
}
namespace task1{
// Main solution.
//  * solve() builds a max-Cartesian tree of a[1..n] (children ls/rs, root rot).
//  * dfs1 numbers nodes in pre-order (id[]) and sets top[u] to the lowest
//    ancestor v such that u lies in v's LEFT subtree (0 if there is none).
//    Following top[] from u gives the chain u, top[u], top[top[u]], ...
//  * `bit` (range add / point query over id[]) holds, for every node u, the
//    sum of the current weights on u's chain.
//  * trdiv holds the current weight of every single node.
ll bit[N];
pair<int, int> sta[N];  // monotonic stack used while building the tree
int ls[N], rs[N], id[N], siz[N], dep[N], fa[N], top[N], btm[N], mson[N], stop, rot, tot;

// Point add on `bit` (not used by solve(); kept for completeness).
inline void add(int i, ll dat) {for (; i<=n; i+=i&-i) bit[i]+=dat;}

// Adds dat to every position in [l, r] (1 <= l <= r <= n).  It stores +dat at l
// and -dat at r+1 in the difference array kept by `bit`.  The Fenwick update
// paths from l and r+1 meet at a common node and cancel from there on, so
// each walk stops at the meeting point.  The walk from r+1 can climb past n
// before meeting.  Those cells are never read by query() and can lie beyond
// the end of bit[] when n is close to N, so that walk is cut off at n.
inline void add(int l, int r, ll dat) {
    ++r;
    while (l<r) bit[l]+=dat, l+=l&-l;
    while (r<l && r<=n) bit[r]-=dat, r+=r&-r;
}

// Prefix sum of `bit`, i.e. the current value at position i.
inline ll query(int i) {ll ans=0; for (; i; i-=i&-i) ans+=bit[i]; return ans;}

// Pre-order DFS over the Cartesian tree.  t is the value top[u] receives: a
// left child gets its parent, a right child inherits its parent's top.
// Also fills id (pre-order index), siz (subtree size), dep/fa of the children,
// btm[u] (last node of the right spine starting at u) and mson[u] (the child
// with the larger subtree, ties to the right child, 0 for a leaf).
// NOTE: recursion depth equals the tree height, which reaches n for monotone a.
void dfs1(int u, int t) {
    top[u]=t; id[u]=++tot; siz[u]=1; btm[u]=u;
    if (ls[u]) dep[ls[u]]=dep[u]+1, fa[ls[u]]=u, dfs1(ls[u], u), siz[u]+=siz[ls[u]];
    if (rs[u]) dep[rs[u]]=dep[u]+1, fa[rs[u]]=u, dfs1(rs[u], t), siz[u]+=siz[rs[u]], btm[u]=btm[rs[u]];
    mson[u]=siz[ls[u]]>siz[rs[u]]?ls[u]:rs[u];
}

// Heavy-light decomposition of the same tree (heavy child = mson from dfs1).
// It has its own Fenwick tree, which stores the current weight of each node
// at position id[u].
namespace trdiv{
ll bit[N];
int top[N], id[N], tot;  // heavy-path head and HLD position of each node

inline void add(int i, ll dat) {for (; i<=n; i+=i&-i) bit[i]+=dat;}

// Range add on [l, r]; same scheme (and the same cut-off at n) as task1::add.
inline void add(int l, int r, ll dat) {
    ++r;
    while (l<r) bit[l]+=dat, l+=l&-l;
    while (r<l && r<=n) bit[r]-=dat, r+=r&-r;
}

inline ll query(int i) {ll ans=0; for (; i; i-=i&-i) ans+=bit[i]; return ans;}

// Heavy child first, so every heavy path occupies a contiguous id range
// starting at its head t.
void dfs2(int u, int t) {
    top[u]=t; id[u]=++tot;
    if (!mson[u]) return ;
    dfs2(mson[u], t);
    if (ls[u]&&ls[u]!=mson[u]) dfs2(ls[u], ls[u]);
    if (rs[u]&&rs[u]!=mson[u]) dfs2(rs[u], rs[u]);
}

// Loads the initial weights: position id[i] holds w[i].
void init() {for (int i=1; i<=n; ++i) add(id[i], id[i], w[i]);}

// Adds val to every node on the tree path between x and y.
void upd(int x, int y, ll val) {
    while (top[x]!=top[y]) {
        if (dep[top[x]]<dep[top[y]]) swap(x, y);
        add(id[top[x]], id[x], val);
        x=fa[top[x]];
    }
    if (dep[x]>dep[y]) swap(x, y);
    add(id[x], id[y], val);
}

// Lowest common ancestor of x and y via the heavy-path heads.
int lca(int x, int y) {
    while (top[x]!=top[y]) {
        if (dep[top[x]]<dep[top[y]]) swap(x, y);
        x=fa[top[x]];
    }
    return dep[x]<dep[y]?x:y;
}

// Current weight of node x.
ll qval(int x) {return query(id[x]);}
}

// Builds all structures, then answers q operations read with read():
// an odd opcode is an update "x v", anything else a query "x y".
// Each query prints its answer, or "?" when there is none.
void solve() {
    // Max-Cartesian tree via the monotonic stack sta[1..stop].  Only strictly
    // smaller values are popped, so of two equal values the later index ends
    // up in the earlier one's right subtree.
    for (int i=1; i<=n; ++i) {
        pair<int, int> now={i, a[i]};
        int k=stop;
        while (k && sta[k].sec<now.sec) --k;
        if (k) rs[sta[k].fir]=now.fir;
        if (k!=stop) ls[now.fir]=sta[k+1].fir;
        sta[stop=++k]=now;
    }
    rot=sta[1].fir;
    dep[rot]=1; dfs1(rot, 0);
    trdiv::dfs2(rot, rot); trdiv::init();
    // w[i] counts for i itself and its left subtree (the nodes whose chain
    // passes through i): add it on i's subtree, then cancel it on rs[i]'s.
    for (int i=1; i<=n; ++i) {
        add(id[i], id[i]+siz[i]-1, w[i]);
        if (rs[i]) add(id[rs[i]], id[rs[i]]+siz[rs[i]]-1, -w[i]);
    }
    for (int i=1,x,y,v; i<=q; ++i) {
        if (read()&1) {
            // Update "x v" raises by v the weight of x and of every node on
            // the right spine of ls[x] (the nodes whose top is x).
            // Consequently chain sums change by +v at x and by +2v throughout
            // x's left subtree.
            x=read(); v=read();
            add(id[x], id[x]+siz[x]-1, v);
            if (rs[x]) add(id[rs[x]], id[rs[x]]+siz[rs[x]]-1, -v);
            if (ls[x]) add(id[ls[x]], id[ls[x]]+siz[ls[x]]-1, v);
            if (ls[x]) trdiv::upd(x, btm[ls[x]], v);
            else trdiv::upd(x, x, v);
        }
        else {
            // Query "x y": the two chains first meet at z = top[lca(x, y)].
            // Sum both chains, remove z's chain from each, then add z back once.
            x=read(); y=read();
            if (x>y) swap(x, y);
            ll ans=query(id[x])+query(id[y]);
            int t=trdiv::lca(x, y), z=top[t];
            if (!z) {puts("?"); continue;}
            ans=ans-2*query(id[z])+trdiv::qval(z);
            // t == y means y is an ancestor of x.  Since x < y, x is in y's
            // left subtree, so y lies on both chains and was counted twice.
            if (t==y) ans-=trdiv::qval(t);
            printf("%lld\n", ans);
        }
    }
}
}
signed main()
{
    // All input comes from set.in; all output goes to set.out.
    freopen("set.in", "r", stdin);
    freopen("set.out", "w", stdout);

    // Header: number of positions, then number of operations.
    n = read();
    q = read();

    // Then all of a[1..n], followed by all of w[1..n].
    for (int pos = 1; pos <= n; ++pos) a[pos] = read();
    for (int pos = 1; pos <= n; ++pos) w[pos] = read();

    // force::solve();  // brute-force reference, useful for cross-checking
    task1::solve();
    return 0;
}