题解 第负一题

传送门

因为时间分配的问题这题几乎没碰,而且只会 \(O(n^2logn)\) 的线段树做法
yysy,这题正解思路来源其实貌似是链上点分治

实际上有个很SB的 \(n^2\) DP
首先枚举一个左端点 \(i\),对每个左端点令 \(dp[j]\) 为区间 \([i, j]\) 的答案
转移

\[dp[i] = max(dp[i-1], dp[i-2]+dp[i]) \]

就没了
我当时瞎想了个什么「可以有相邻的点都不被选」什么的自以为把它hack了
但实际上它能处理这种情况

然后考虑正解:分治
这题要是挂树上就是典型的点分治
对于一个区间,我们直接统计过区间中点的所有区间的答案
还是用上面的DP,处理 \(f[0/1][i]\) 代表强制不选/选mid,区间 \([i, mid]\) 的答案
\(g[0/1][i]\) 同理,不过是向右处理的
那大致思路是枚举左端点,设法直接算出这个左端点对应的所有跨过mid的右端点的贡献
如果不选mid和mid+1,右边的 \(g[0][i]\) 可以前缀和
接下来考虑mid和mid+1的贡献
它们两个可以让 \(f\)\(g\) 的第一维变成1
所以可以预处理出 \(g[1][i]-g[0][i]\)
对于区间 \([i, j], \ ans=f[0][i]+g[0][j]+max(f[1][i]-f[0][i], g[1][i]-g[0][i], 0)\)
现在问题转化为「给定一个数组,每次将所有元素与一个给定值取max并求和」
因为数组固定且询问范围固定,可以直接将数组排序,预处理前缀和
每次二分出第一个比给定值小的位置,用下标计算有多少个数被取max替代了及其贡献
没有被替代的数的和直接用前缀和求就行了
复杂度 \(O(nlog^2n)\),瓶颈在于排序(二分可以用双指针替代)

Code:
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f3f3f3f
#define N 200010
#define ll long long
#define reg register int
//#define int long long 

char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
inline int read() {
	int ans=0, f=1; char c=getchar();
	while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
	while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
	return ans*f;
}

int n;
ll a[N];
const ll mod=998244353;
inline void md(ll& a, ll b) {a+=b; a=a>=mod?a-mod:a;}

namespace force{
	ll ans;
	void solve() {
		for (int i=1; i<=n; ++i) {
			for (int j=i; j<=n; ++j) {
				int len=j-i+1, lim=1<<len;
				ll tem=0;
				for (int s=1; s<lim; ++s) {
					ll sum=0;
					if (s&(s>>1) || s&(s<<1)) continue;
					for (int k=0; k<len; ++k) if (s&(1<<k)) {
						sum+=a[i+k];
					}
					tem=max(tem, sum);
				}
				md(ans, tem%mod);
			}
		}
		printf("%lld\n", ans);
		exit(0);
	}
}

namespace task{
	ll f[2][N], g[2][N], sta[N], top, ans, sum, sum2[N];
	void calc(int l, int r) {
		//cout<<"calc: "<<l<<' '<<r<<' '<<ans<<endl;
		if (l==r) {ans=(ans+a[l])%mod; return ;}
		int mid=(l+r)>>1; top=0; sum=0;
		for (reg i=l; i<=r; ++i) f[0][i]=0, f[1][i]=0, g[0][i]=0, g[1][i]=0;
		
		f[0][mid-1]=a[mid-1]; f[1][mid-1]=f[1][mid]=a[mid];
		for (int i=mid-2; i>=l; --i) {
			f[0][i]=max(f[0][i+1], f[0][i+2]+a[i]);
			f[1][i]=max(f[1][i+1], f[1][i+2]+a[i]);
		}
		g[1][mid+1]=a[mid+1]; sta[++top]=a[mid+1];
		if (mid+2<=r) {g[0][mid+2]=a[mid+2]; g[1][mid+2]=a[mid+1]; sum=a[mid+2]; sta[++top]=a[mid+1]-a[mid+2];}
		for (reg i=mid+3; i<=r; ++i) {
			g[0][i]=max(g[0][i-1], g[0][i-2]+a[i]);
			g[1][i]=max(g[1][i-1], g[1][i-2]+a[i]);
			sta[++top]=g[1][i]-g[0][i];
			sum=(sum+g[0][i])%mod;
		}
		sort(sta+1, sta+top+1, [](int a, int b) {return a>b;});
		#if 0
		cout<<"f0: "; for (int i=1; i<=n; ++i) cout<<f[0][i]<<' '; cout<<endl;
		cout<<"f1: "; for (int i=1; i<=n; ++i) cout<<f[1][i]<<' '; cout<<endl;
		cout<<"g0: "; for (int i=1; i<=n; ++i) cout<<g[0][i]<<' '; cout<<endl;
		cout<<"g1: "; for (int i=1; i<=n; ++i) cout<<g[1][i]<<' '; cout<<endl;
		cout<<"sta: "; for (int i=1; i<=top; ++i) cout<<sta[i]<<' '; cout<<endl;
		#endif
		
		//cout<<"top: "<<top<<endl<<endl;
		for (reg i=1; i<=top; ++i) sum2[i]=(sum2[i-1]+sta[i])%mod;
		for (reg i=l; i<=mid; ++i) {
			//cout<<"i: "<<i<<endl;
			ll dlt=max(f[1][i]-f[0][i], 0ll);
			//cout<<"dlt: "<<dlt<<endl;
			int l=1, r=top, mid;
			while (l<=r) {
				mid=(l+r)>>1;
				if (sta[mid]>=dlt) l=mid+1;
				else r=mid-1;
			}
			//cout<<"l: "<<l-1<<endl;
			ans=((ans+sum+sum2[l-1])%mod+(top-l+1)*dlt%mod+f[0][i]*top%mod)%mod;
			//cout<<"ans: "<<ans<<endl;
			//cout<<endl;
		}
		//cout<<"return "<<ans<<endl;
		calc(l, mid); calc(mid+1, r);
	}
	void solve() {
		calc(1, n);
		printf("%lld\n", ans);
		exit(0);
	}
}

signed main()
{
	n=read();
	for (int i=1; i<=n; ++i) a[i]=read();
	//force::solve();
	task::solve();
	
	return 0;
}
posted @ 2021-09-11 07:32  Administrator-09  阅读(10)  评论(0编辑  收藏  举报