题解 第负一题
因为时间分配的问题这题几乎没碰,而且只会 \(O(n^2logn)\) 的线段树做法
yysy,这题正解思路来源其实貌似是链上点分治
实际上有个很SB的 \(n^2\) DP
首先枚举一个左端点 \(i\),对每个左端点令 \(dp[j]\) 为区间 \([i, j]\) 的答案
转移
\[dp[i] = max(dp[i-1], dp[i-2]+dp[i])
\]
就没了
我当时瞎想了个什么「可以有相邻的点都不被选」什么的自以为把它hack了
但实际上它能处理这种情况
然后考虑正解:分治
这题要是挂树上就是典型的点分治
对于一个区间,我们直接统计过区间中点的所有区间的答案
还是用上面的DP,处理 \(f[0/1][i]\) 代表强制不选/选mid,区间 \([i, mid]\) 的答案
\(g[0/1][i]\) 同理,不过是向右处理的
那大致思路是枚举左端点,设法直接算出这个左端点对应的所有跨过mid的右端点的贡献
如果不选mid和mid+1,右边的 \(g[0][i]\) 可以前缀和
接下来考虑mid和mid+1的贡献
它们两个可以让 \(f\) 或 \(g\) 的第一维变成1
所以可以预处理出 \(g[1][i]-g[0][i]\)
对于区间 \([i, j], \ ans=f[0][i]+g[0][j]+max(f[1][i]-f[0][i], g[1][i]-g[0][i], 0)\)
现在问题转化为「给定一个数组,每次将所有元素与一个给定值取max并求和」
因为数组固定且询问范围固定,可以直接将数组排序,预处理前缀和
每次二分出第一个比给定值小的位置,用下标计算有多少个数被取max替代了及其贡献
没有被替代的数的和直接用前缀和求就行了
复杂度 \(O(nlog^2n)\),瓶颈在于排序(二分可以用双指针替代)
Code:
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f3f3f3f
#define N 200010
#define ll long long
#define reg register int
//#define int long long
char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
inline int read() {
int ans=0, f=1; char c=getchar();
while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
return ans*f;
}
int n;
ll a[N];
const ll mod=998244353;
inline void md(ll& a, ll b) {a+=b; a=a>=mod?a-mod:a;}
namespace force{
ll ans;
void solve() {
for (int i=1; i<=n; ++i) {
for (int j=i; j<=n; ++j) {
int len=j-i+1, lim=1<<len;
ll tem=0;
for (int s=1; s<lim; ++s) {
ll sum=0;
if (s&(s>>1) || s&(s<<1)) continue;
for (int k=0; k<len; ++k) if (s&(1<<k)) {
sum+=a[i+k];
}
tem=max(tem, sum);
}
md(ans, tem%mod);
}
}
printf("%lld\n", ans);
exit(0);
}
}
namespace task{
ll f[2][N], g[2][N], sta[N], top, ans, sum, sum2[N];
void calc(int l, int r) {
//cout<<"calc: "<<l<<' '<<r<<' '<<ans<<endl;
if (l==r) {ans=(ans+a[l])%mod; return ;}
int mid=(l+r)>>1; top=0; sum=0;
for (reg i=l; i<=r; ++i) f[0][i]=0, f[1][i]=0, g[0][i]=0, g[1][i]=0;
f[0][mid-1]=a[mid-1]; f[1][mid-1]=f[1][mid]=a[mid];
for (int i=mid-2; i>=l; --i) {
f[0][i]=max(f[0][i+1], f[0][i+2]+a[i]);
f[1][i]=max(f[1][i+1], f[1][i+2]+a[i]);
}
g[1][mid+1]=a[mid+1]; sta[++top]=a[mid+1];
if (mid+2<=r) {g[0][mid+2]=a[mid+2]; g[1][mid+2]=a[mid+1]; sum=a[mid+2]; sta[++top]=a[mid+1]-a[mid+2];}
for (reg i=mid+3; i<=r; ++i) {
g[0][i]=max(g[0][i-1], g[0][i-2]+a[i]);
g[1][i]=max(g[1][i-1], g[1][i-2]+a[i]);
sta[++top]=g[1][i]-g[0][i];
sum=(sum+g[0][i])%mod;
}
sort(sta+1, sta+top+1, [](int a, int b) {return a>b;});
#if 0
cout<<"f0: "; for (int i=1; i<=n; ++i) cout<<f[0][i]<<' '; cout<<endl;
cout<<"f1: "; for (int i=1; i<=n; ++i) cout<<f[1][i]<<' '; cout<<endl;
cout<<"g0: "; for (int i=1; i<=n; ++i) cout<<g[0][i]<<' '; cout<<endl;
cout<<"g1: "; for (int i=1; i<=n; ++i) cout<<g[1][i]<<' '; cout<<endl;
cout<<"sta: "; for (int i=1; i<=top; ++i) cout<<sta[i]<<' '; cout<<endl;
#endif
//cout<<"top: "<<top<<endl<<endl;
for (reg i=1; i<=top; ++i) sum2[i]=(sum2[i-1]+sta[i])%mod;
for (reg i=l; i<=mid; ++i) {
//cout<<"i: "<<i<<endl;
ll dlt=max(f[1][i]-f[0][i], 0ll);
//cout<<"dlt: "<<dlt<<endl;
int l=1, r=top, mid;
while (l<=r) {
mid=(l+r)>>1;
if (sta[mid]>=dlt) l=mid+1;
else r=mid-1;
}
//cout<<"l: "<<l-1<<endl;
ans=((ans+sum+sum2[l-1])%mod+(top-l+1)*dlt%mod+f[0][i]*top%mod)%mod;
//cout<<"ans: "<<ans<<endl;
//cout<<endl;
}
//cout<<"return "<<ans<<endl;
calc(l, mid); calc(mid+1, r);
}
void solve() {
calc(1, n);
printf("%lld\n", ans);
exit(0);
}
}
signed main()
{
n=read();
for (int i=1; i<=n; ++i) a[i]=read();
//force::solve();
task::solve();
return 0;
}