题解 技术情报局

传送门

题意即为:对所有区间求「区间乘积 × 区间最大值」之和,相当于每个区间的最大值要乘两次。
所以建出以最大值为根的笛卡尔树,每个区间在其最大值对应的节点处统计;合并左右子树信息时,根节点的值乘两次即可。

Code:
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 10000010
#define ll long long
#define fir first
#define sec second
#define make make_pair
//#define int long long

// Read one decimal integer from stdin, skipping leading non-digit characters.
// Each '-' seen while skipping flips the sign.
// Returns 0 if end of input is reached before any digit.
// `c` is an int (not char) so EOF is distinguishable and bytes >= 0x80 never reach
// isdigit() as negative values (undefined behaviour).
inline int read() {
	int ans=0, f=1;
	int c=getchar();
	while (c!=EOF && !isdigit(c)) {if (c=='-') f=-f; c=getchar();}
	while (c!=EOF && isdigit(c)) {ans=ans*10+(c-'0'); c=getchar();}
	return ans*f;
}

// Input sequence, 1-indexed: get() fills a[1..n].
ll a[N];
// Pseudo-random generator used to produce the input. Its shifts and masks match
// L'Ecuyer's LFSR113 combined Tausworthe generator. The generated data depends on every
// constant and on the statement order, so do not alter anything here.
namespace GenHelper {
	unsigned z1, z2, z3, z4, b;  // four state words (seeded in get()) and a scratch word
	// Advance each of the four state words by its own shift/mask/xor step and return
	// the XOR of all four.
	unsigned rand_() {
		b = ((z1 << 6) ^ z1) >> 13;
		z1 = ((z1 & 4294967294U) << 18) ^ b;
		b = ((z2 << 2) ^ z2) >> 27;
		z2 = ((z2 & 4294967288U) << 2) ^ b;
		b = ((z3 << 13) ^ z3) >> 21;
		z3 = ((z3 & 4294967280U) << 7) ^ b;
		b = ((z4 << 3) ^ z4) >> 12;
		z4 = ((z4 & 4294967168U) << 13) ^ b;
		return (z1 ^ z2 ^ z3 ^ z4);
	}
} // namespace GenHelper
void get (int n, unsigned s, int l, int r) {
	using namespace GenHelper;
	z1 = s;
	z2 = unsigned((~s) ^ 0x233333333U);
	z3 = unsigned(s ^ 0x1234598766U);
	z4 = (~s) + 51;
	for (int i = 1; i <= n; i ++) {
		int x = rand_() & 32767;
		int y = rand_() & 32767;
		a[i] = l + (x * 32768 + y) % (r - l + 1);
	}
}

// Values read in main():
//   n      - sequence length
//   s      - generator seed
//   [l, r] - value range (n, s, l, r are all passed to get())
//   mod    - modulus for the answer
int n, s, l, r;
ll mod;

namespace force{
	ll ans;
	ll calc(int l, int r) {
		ll ans=1, maxn=0;
		for (int i=l; i<=r; ++i) {
			ans=ans*a[i]%mod;
			maxn=max(maxn, a[i]);
		}
		return ans*maxn%mod;
	}
	void solve() {
		for (int i=1; i<=n; ++i) for (int j=i; j<=n; ++j) ans=(ans+calc(i, j))%mod;
		printf("%lld\n", ans);
		exit(0);
	}
}

namespace task1{
	int ls[N], rs[N], top;
	pair<int, int> sta[N];
	ll tl[N], tr[N], prod[N], ans;
	void dfs(int u) {
		ans=(ans+a[u]*a[u])%mod;
		if (ls[u]&&rs[u]) {
			dfs(ls[u]); dfs(rs[u]);
			ans=(ans+tl[ls[u]]*a[u]%mod*tr[rs[u]]%mod*a[u]%mod)%mod;
			ans=(ans+tl[ls[u]]*a[u]%mod*a[u]%mod)%mod;
			ans=(ans+a[u]*tr[rs[u]]%mod*a[u]%mod)%mod;
			tl[u]=(tl[ls[u]]*a[u]%mod*prod[rs[u]]%mod + a[u]*prod[rs[u]]%mod + tl[rs[u]])%mod;
			tr[u]=(tr[ls[u]] + prod[ls[u]]*a[u]%mod + prod[ls[u]]*a[u]%mod*tr[rs[u]]%mod)%mod;
			prod[u]=prod[ls[u]]*a[u]%mod*prod[rs[u]]%mod;
		}
		else if (ls[u]) {
			dfs(ls[u]);
			ans=(ans+tl[ls[u]]*a[u]%mod*a[u]%mod)%mod;
			tl[u]=(tl[ls[u]]*a[u]%mod + a[u])%mod;
			tr[u]=(tr[ls[u]] + prod[ls[u]]*a[u]%mod)%mod;
			prod[u]=prod[ls[u]]*a[u]%mod;
		}
		else if (rs[u]) {
			dfs(rs[u]);
			ans=(ans+a[u]*tr[rs[u]]%mod*a[u]%mod)%mod;
			tl[u]=(a[u]*prod[rs[u]]%mod + tl[rs[u]])%mod;
			tr[u]=(a[u] + a[u]%mod*tr[rs[u]]%mod)%mod;
			prod[u]=a[u]*prod[rs[u]]%mod;
		}
		else {
			tl[u]=tr[u]=prod[u]=a[u];
		}
	}
	void solve() {
		// cout<<double(sizeof(a)*4+sizeof(ls)*2+sizeof(sta))/1024/1024<<endl;
		for (int i=1; i<=n; ++i) {
			int k=top;
			while (k && sta[k].sec<a[i]) --k;
			if (k) rs[sta[k].fir]=i;
			if (k<top) ls[i]=sta[k+1].fir;
			sta[++k]=make(i, a[i]);
			top=k;
		}
		dfs(sta[1].fir);
		printf("%lld\n", ans);
		exit(0);
	}
}

signed main()
{
	// Read from tio.in and write to tio.out.
	freopen("tio.in", "r", stdin);
	freopen("tio.out", "w", stdout);

	// Parameters arrive in this fixed order: n, seed, l, r, modulus.
	n = read();
	s = read();
	l = read();
	r = read();
	mod = read();
	get(n, s, l, r);

	// For small inputs, force::solve() can be called here instead to cross-check.
	task1::solve();
	return 0;
}
posted @ 2021-11-01 14:29  Administrator-09  阅读(0)  评论(0)  编辑  收藏  举报