天依给了你一个函数然后让你求矩阵函数和

不会线段树历史版本区间和,学了好几个小时再加因为数组开小的几个小时的虚空调试终于过了

这篇题解会更详细的讲解如何实现线段树历史版本求和。

也推荐一篇讲解线段树历史版本区间和的博客,我是从这里面学的。


观察函数 \(g(i, j)\) 的定义,会发现 \(j\) 这个右端点是不动的,于是考虑右端点确定的时候左端点(也就是答案)在哪里。

如果我们从第 \(j\) 个数开始从后往前数,计算每种数的出现个数,那么可以发现左端点一定是落在 \(i\) 个数右边(包含)第一个出现次数为 \(1\) 的位置

知道了这一点再来考虑把第 \(j + 1\) 个数加进来的时候的变化,很明显第 \(j + 1\) 个数只会引起上一个 \(a_{k} = a_{j + 1}\) 的位置 \(k\) 的变化,于是记录一个 \(last_{i}\) 表示数 \(i\) 上一次出现的位置在哪里。

加入了 \(a_{j + 1}\) 后,\(a_{last_{a_{j + 1}}}\) 的出现次数也就从 \(1\) 变到了 \(2\),那么 \(last_{a_{j + 1}}\) 前面的 \(g(l, j)\) 也会变化,具体来说在 \(k\) 的上一个出现次数为 \(1\) 的位置到 \(last_{a_{j + 1}}\) 这一段的函数值都会变化。这样我们可以用线段树区间推平来维护函数值,而记录前一个和后一个出现次数为 \(1\) 的位置可以使用链表维护。

但是区间推平的标记并不如区间加法容易记录和计算,所以我们维护 \(g(i, j) - i\) 的值,这样我们就可以把区间赋值变成区间加法了。

(上面虽然说是维护 \(g(i, j) - i\) 的值,但是实际上在代码中我为了偷懒会把 \(i\) 加进去,这一点其实无伤大雅。)

如果我们把 \(j = 1\) 时的函数值写下来,排成一行,再把 \(j = 2\) 时的函数值写下来,排成新的一行,以此类推,那么题目其实就是要我们矩阵求和。

但这样是不可做的,我们没有注意到第 \(i\) 行是由第 \(i - 1\) 行推过来的这一性质。如果把每一行看作线段树的一个版本,那么这道题是让我们做线段树历史版本区间求和。

因为这道题不要求强制在线,那么可以先把询问离线下来,差分成 \(\sum\limits_{j = 1}^{x - 1}\sum\limits_{i = l}^{r}g(i, j)\)\(\sum\limits_{j = 1}^{y}\sum\limits_{i = l}^{r}g(i, j)\),按照所需的版本编号排序就可以了。

如果你会这个东西那么接下来的讲解就可以不用看了,如果不会的话那么建议你继续看下去。

暴力保存和查询每个版本的线段树慢的原因在于:直接保存了每个函数值最终的结果,没有利用到区间加法可以打标记维护的性质。

考虑一次单点加法对某个函数值以及接下来的一段时间的影响:假设这次加法的时间为 \(t\),加的值为 \(v\),当前时间为 \(tim\),那么这次加法对于 \(t \sim tim\) 这段时间的每个线段树都有贡献,每个贡献都是 \(v\),那么对这一段版本和的贡献就是 \((tim - t)v\)

拓展到区间加法也是一样的道理:只需要再乘上一个区间长度即可。

那么再拓展到有许多次加法操作,那么对于 \(tim\) 及其之前的历史版本和的贡献就是:

\[\sum(tim - t_{i} + 1)v_{i} \]

对于这种式子有一个很常见的套路就是用乘法分配律将式子展开分别维护,于是贡献就是:

\[(tim + 1)\sum v_{i} - \sum t_{i}v_{i} \]

(上面是对于单点的情况,对于区间那么在外面乘上区间长度就好。)

于是分别维护某个区间的标记 \(\sum v_{i}\)\(\sum t_{i}v_{i}\) 即可,把这两个标记分别称为 \(add\)\(del\),查询的时候计算 \((tim + 1)add - del\) 即可。

同时也要维护一个区间的和,也就是线段树递归到询问区间完全包含当前区间时直接返回,这样才能保证时间复杂度正确。

已有的标记不太好直接维护这个,那么直接再开一些标记: \(sum\) 表示当前版本的区间和,\(t\) 表示这个区间上次更新历史版本区间和在什么时候,\(sumh\) 表示这个区间的历史版本区间和。(要注意区分这里的 \(t\) 标记和上面的修改时间 \(t\)。)

一个区间在要被更新的时候,计算这一段版本的区间和(因为在这一段版本中它没有被修改过,而在接下来它马上就要被修改了),更新 \(sumh\)\(t\)。具体地,令 \(sumh\) 加上 \((tim - t)sum\),令 \(t\) 设为 \(tim\),实际上是把这一段历史版本区间和加进到了 \(sumh\) 里以后就可以不管了,可以避免许多复杂麻烦的标记合并。

具体可以看代码实现(使用了标记永久化):

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
int n, q, l, r, x, y, cq, tim, now, pos, a[1000005], last[1000005], lt[1000005], rt[1000005];
ll ans[1000005];
struct query {
	int l, r, x, v, id;
	query(int L = 0, int R = 0, int X = 0, int V = 0, int Id = 0): l(L), r(R), x(X), v(V), id(Id) {}
	bool operator < (const query& _) const {
		return x < _.x;
	}
} qu[2000005];
struct Segment_Tree {
	struct segment {
		int l, r;
		ll sum, sumh, t, add, del;
		void operator += (const ll& _) {
			add += _, del += tim * _;
			sumh += (tim - t) * sum;
			sum += (r - l + 1) * _;
			t = tim;
		}
	} t[4000005];
	#define lc (k << 1)
	#define rc (lc | 1)
	#define mid ((t[k].l + t[k].r) >> 1)
	void build(int k) {
		if(t[k].l == t[k].r) return;
		t[lc].l = t[k].l, t[lc].r = mid, t[rc].l = mid + 1, t[rc].r = t[k].r;
		build(lc), build(rc);
	}
	void change(int k, const int& L, const int& R, const ll& v) {
		if(L <= t[k].l && t[k].r <= R) {
			t[k] += v;
			return;
		}
		t[k].sumh += t[k].sum * (tim - t[k].t);
		t[k].sum += (min(R, t[k].r) - max(L, t[k].l) + 1) * v;
		t[k].t = tim;
		if(L <= mid) change(lc, L, R, v);
		if(R > mid) change(rc, L, R, v);
	}
	ll ask(int k, const int& L, const int& R) {
		const ll len = min(R, t[k].r) - max(L, t[k].l) + 1;
		ll ret = t[k].add * len * (tim + 1) - t[k].del * len;
		if(L <= t[k].l && t[k].r <= R) return t[k].sumh + t[k].sum * (tim - t[k].t + 1);
		if(L <= mid) ret += ask(lc, L, R);
		if(R > mid) ret += ask(rc, L, R);
		return ret;
	}
} tree;
int main() {
	ios::sync_with_stdio(0);
	cin.tie(0), cout.tie(0);
	cin >> n >> q;
	for(int i = 1; i <= n; ++i) {
		cin >> a[i];
		lt[i] = i - 1, rt[i] = i + 1;
	}
	for(int i = 1; i <= q; ++i) {
		cin >> l >> r >> x >> y;
		if(x > 1) qu[++cq] = query(l, r, x - 1, -1, i);
		qu[++cq] = query(l, r, y, 1, i);
	}
	stable_sort(qu + 1, qu + 1 + cq);
	tree.t[1].l = 1, tree.t[1].r = n;
	tree.build(1);
	for(tim = 1, now = 1; tim <= n; ++tim) {
		pos = last[a[tim]];
		if(pos) {
			tree.change(1, lt[pos] + 1, pos, rt[pos] - pos);
			rt[lt[pos]] = rt[pos], lt[rt[pos]] = lt[pos];
		}
		last[a[tim]] = tim;
		tree.change(1, tim, tim, tim);
		while(now <= cq && qu[now].x == tim) {
			ans[qu[now].id] += qu[now].v * tree.ask(1, qu[now].l, qu[now].r);
			++now;
		}
	}
	for(int i = 1; i <= q; ++i) cout << ans[i] << '\n';
	return 0;
}

诺艾尔小姐提供了一种树状数组做法,可是我不会,所以丢一个代码在这里:

//by TulipeNoire
//Luogu uid 407223
//https://www.luogu.com.cn/record/146683264
#include<bits/stdc++.h>
#define lowbit(x) ((x)&-(x))
using namespace std;
using LL=long long;
using uLL=unsigned long long;
template<typename T>inline void read(T&t){
    t=0;
    char c=getchar();
    int f=1;
    while (!isdigit(c)) f=c=='-'?-f:f,c=getchar();
    while (isdigit(c)) t=(t<<3)+(t<<1)+c-'0',c=getchar();
    t*=f;
    return;
}
template<typename T,typename...Args>inline void read(T&t,Args&...args){read(t),read(args...);}
const int N=1000005;
int n,m,a[N],lst[N],tag[N];
int pcnt,qcnt;
struct BIT {
    uLL dat[N];
    inline void add(int x,uLL v) {
        while (x<=n) dat[x]+=v,x+=lowbit(x);
    }
    inline uLL get(int x) {
        uLL res=0;
        while (x) res+=dat[x],x-=lowbit(x);
        return res;
    }
}T[4];
struct Modify {
    int a,b,v;
    inline bool operator<(const Modify o) {
        return a<o.a;
    }
}p[N<<3];
struct Query {
    int a,b,f,id;
    inline bool operator<(const Query o) {
        return a<o.a;
    }
}q[N<<2];
uLL ans[N];
inline void Insert(int al,int ar,int tl,int tr,int v) {
    // printf("*%d %d %d %d %d\n",al,ar,tl,tr,v);
    p[++pcnt]={al,tl,v};
    p[++pcnt]={ar+1,tl,-v};
    p[++pcnt]={al,tr+1,-v};
    p[++pcnt]={ar+1,tr+1,v};
}
set<int>s;
inline void Delete(int x,int t) {
    auto it=s.find(x);
    Insert(it==s.begin()?1:(*prev(it))+1,x,tag[x],t-1,x);
    Insert(x+1,*next(it),tag[*next(it)],t-1,*next(it));
    tag[*next(it)]=t;
    s.erase(it);
}
int main() {
    read(n,m);
    for (int i=1;i<=n;i++) read(a[i]);
    for (int i=1;i<=n;i++) {
        s.insert(i),tag[i]=i;
        if (lst[a[i]]) Delete(lst[a[i]],i);
        lst[a[i]]=i;
    }
    int pre=0;
    for (auto x:s) {
        Insert(pre+1,x,tag[x],n,x);
        pre=x;
    }
    for (int i=1;i<=m;i++) {
        int al,ar,tl,tr;
        read(al,ar,tl,tr);
        q[++qcnt]={ar,tr,1,i};
        q[++qcnt]={al-1,tr,-1,i};
        q[++qcnt]={ar,tl-1,-1,i};
        q[++qcnt]={al-1,tl-1,1,i};
    }
    sort(p+1,p+pcnt+1);
    sort(q+1,q+qcnt+1);
    int idx=1;
    for (int i=1;i<=qcnt;i++) {
        while (idx<=pcnt&&p[idx].a<=q[i].a) {
            int a=p[idx].a,b=p[idx].b,v=p[idx].v;
            T[0].add(b,v);
            T[1].add(b,1ll*b*v);
            T[2].add(b,1ll*a*v);
            T[3].add(b,1ll*a*b*v);
            idx++;
        }
        uLL res=0;
        int A=q[i].a+1,B=q[i].b+1;
        ans[q[i].id]+=1ll*A*B*T[0].get(q[i].b)*q[i].f;
        ans[q[i].id]-=1ll*A*T[1].get(q[i].b)*q[i].f;
        ans[q[i].id]-=1ll*B*T[2].get(q[i].b)*q[i].f;
        ans[q[i].id]+=T[3].get(q[i].b)*q[i].f;
    }
    for (int i=1;i<=m;i++) printf("%lld\n",(LL)ans[i]);
}
posted @ 2023-12-22 21:55  A_box_of_yogurt  阅读(4)  评论(0编辑  收藏  举报  来源
Document