P10342 [THUSC 2019] 数列 题解
形式化题面:
求
其中 \(f(l,r)\) 为 \(a_l,...,a_r\) 中有多少个不同的数字。
注意到,除了 Sub2,其余数据点都有 \(\max f\le 800\),这启发我们考虑 \(O(nm)\) 的算法。
套路地,扫描线枚举右端点,则现在只需要考虑其对所有左端点的贡献。
设 \(pre_i\) 表示 \(a_i\) 上一次出现的位置,维护一个 ODT 状物,即所有 \(f\) 的连续段。每次 \(r-1\to r\) 就相当于在 \(pre_{r}+1\) 处 split 一下,后面位置的 \(f\) 值 \(+1\),然后新 push 进去一个 \(([r,r],1)\) 的段,最后合并一些段。注意到只会有 \(O(m)\) 段,所以可以直接用数组维护,每次直接重构都是可以接受的。
然后注意到,只有每一段的右端点才有可能贡献到答案。记这些点为“关键点”。
考虑关键点 \(i\) 对答案的贡献:\((i-l+1)\times f\)。其中 \(f\) 是定值。拆项可得 \(-f\cdot l+(i+1)f\)。不难发现这是一个一次函数的形式。
考虑从左到右加入关键点,那么可以注意到每次加入的一次函数斜率递增,那么其一定会更新一段后缀的答案。考虑将每个位置的最优点描出来,不难发现其构成了一个下凸壳。于是插入线段也是简单的,先 pop 掉那些被完全覆盖的线段,然后 \(O(1)\) 求出两条线段的交点即可。最后再扫描凸壳计算答案即可。这形如若干等差数列求和,容易 \(O(1)\) 计算。
注意到每条线段只会被 pop 一次,且求交点复杂度为 \(O(1)\),所以总的时间复杂度为 \(O(nm)\)。
至于 Sub2,根据基础不等式知识不难发现 \(i\) 取区间中点最优,因此直接枚举区间长度计算即可,复杂度 \(O(n)\)。
代码:
#include<bits/stdc++.h>
#include<ext/pb_ds/assoc_container.hpp>
#include<ext/pb_ds/tree_policy.hpp>
#include<ext/pb_ds/hash_policy.hpp>
#define gt getchar
#define pt putchar
#define fst first
#define scd second
#define SZ(s) ((int)s.size())
#define all(s) s.begin(),s.end()
#define pb push_back
#define eb emplace_back
typedef long long ll;
typedef double db;
typedef long double ld;
typedef unsigned long long ull;
typedef unsigned int uint;
const int N=1e5+5;
const int mod=998244353;
using namespace std;
using namespace __gnu_pbds;
typedef pair<int,int> pii;
template<class T,class I> inline void chkmax(T &a,I b){a=max(a,(T)b);}
template<class T,class I> inline void chkmin(T &a,I b){a=min(a,(T)b);}
inline bool __(char ch){return ch>=48&&ch<=57;}
template<class T> inline void read(T &x){
x=0;bool sgn=0;static char ch=gt();
while(!__(ch)&&ch!=EOF) sgn|=(ch=='-'),ch=gt();
while(__(ch)) x=(x<<1)+(x<<3)+(ch&15),ch=gt();
if(sgn) x=-x;
}
template<class T,class ...I> inline void read(T &x,I &...x1){
read(x);
read(x1...);
}
template<class T> inline void print(T x){
static char stk[70];short top=0;
if(x<0) pt('-');
do{stk[++top]=x>=0?(x%10+48):(-(x%10)+48),x/=10;}while(x);
while(top) pt(stk[top--]);
}
template<class T> inline void printsp(T x){
print(x);
putchar(' ');
}
template<class T> inline void println(T x){
print(x);
putchar('\n');
}
int n,a[N],pre[N],pos[N],siz;
struct Seg{
int l,r,w;
Seg(int _l=0,int _r=0,int _w=0)
:l(_l),r(_r),w(_w)
{}
}odt[N];
inline bool in(int x,Seg seg){
return seg.l<=x&&x<=seg.r;
}
inline void split(int x){
auto upd=[&](int i){
if(odt[i+1].w==odt[i].w){
odt[i].r=odt[i+1].r;
for(int j=i+1;j<siz;++j) odt[j]=odt[j+1];
siz--;
}
};
for(int i=1;i<=siz;++i){
if(in(x,odt[i])){
if(x==odt[i].l){
for(int j=i;j<=siz;++j) odt[j].w++;
upd(i-1);
}else{
odt[++siz]=Seg(x,odt[i].r,odt[i].w);
odt[i].r=x-1;
for(int j=i+1;j<=siz;++j) odt[j].w++;
for(int j=siz;j>=i+2;--j) swap(odt[j],odt[j-1]);
upd(i);
}
return;
}
}
}
struct func{
int k,b;
func(int _k=0,int _b=0):k(_k),b(_b){}
inline int get(int x){
return k*x+b;
}
};
inline int cross(func a,func b){
// find min x, so that a.get(x) > b.get(x).
if(a.b>b.b) return 1;
int now=(b.b-a.b)/(a.k-b.k);
while(a.get(now)<=b.get(now)) now++;
while(a.get(now-1)>b.get(now-1)) now--;
return now;
}
inline func gen(int i,int w){
// (i-l+1)*w.
// (i+1)*w - l*w.
return func(-w,(i+1)*w);
}
inline int s(int x){
return (1ll*x*(x+1)/2)%mod;
}
struct Node{
func w;
int l,r;
Node(func _w=func(),int _l=0,int _r=0)
:w(_w),l(_l),r(_r)
{}
inline int val(){
return (((1ll*w.b*(r-l+1)%mod+1ll*w.k*(s(r)-s(l-1)+mod)%mod)%mod)+mod)%mod;
}
}conv[N];
int top;
inline int find(func w){
// find min l, so that w.get(l) > others.
while(top&&conv[top].w.get(conv[top].l)<=w.get(conv[top].l)) top--;
if(!top) return 1;
int x=cross(w,conv[top].w);
conv[top].r=x-1;
return x;
}
inline void add(int &a,int b){
a+=b;
if(a>=mod) a-=mod;
}
namespace corner_case{
bool vis[N];
inline bool check(){
int cnt=0;
for(int i=1;i<=n;++i){
if(!vis[a[i]]) cnt++;
vis[a[i]]=1;
}
return cnt>800;
}
inline void solve(){
int ans=0;
for(int i=1;i<=n;++i){
int len1=(i+1)/2,len2=i/2+1;
if(i&1) add(ans,1ll*len1*len1%mod*(n-i+1)%mod);
else add(ans,1ll*len1*len2%mod*(n-i+1)%mod);
}
println(ans);
}
}
signed main(){
read(n);
for(int i=1;i<=n;++i){
read(a[i]);
pre[i]=pos[a[i]];
pos[a[i]]=i;
}
if(corner_case::check()) return corner_case::solve(),0;
int ans=0;
for(int r=1;r<=n;++r){
split(pre[r]+1);
odt[++siz]=Seg(r,r,1);
if(odt[siz-1].w==1) odt[siz-1].r=r,siz--;
conv[top=1]=Node(gen(odt[1].r,odt[1].w),odt[1].l,odt[1].r);
for(int i=2;i<=siz;++i){
func qwq=gen(odt[i].r,odt[i].w);
Node now(qwq,odt[i].l,odt[i].r);
now.l=find(qwq);
conv[++top]=now;
}
for(int i=1;i<=top;++i) add(ans,conv[i].val());
}
println(ans);
return 0;
}