「KDOI-06-S」消除序列 题解
分享一个应该很少人想到的做法。
首先贪心地想,第一种操作肯定最多选择一次。比如如果选择了下标 \(i\) 和 \(j\) 进行第一种操作,那么就等价于在 \(\max\{i,j\}\) 进行了一次操作。由于代价是非负数,则我们最多只用执行一次。当然也可以不使用这种操作。
有了这个思路,我们先考虑不使用第一种操作的最优答案,很明显就是 \(\begin{aligned} \sum _{i \notin P}b_i\end{aligned}\)。
接着再考虑使用一次的情况。我们先假定在位置 \(j\) 处进行一次第一种操作,这就表示 \([1,j]\) 的所有数字都变成了 \(0\),我们将 \(i\in P \wedge i\in [1,j]\) 的所有 \(i\) 统计到集合 \(A\) 中,将 \(i \notin P \wedge i \notin [1,j]\) 的所有 \(i\) 统计到集合 \(B\) 中。这个时候不难发现 \(A\) 中的元素都是应当为 \(1\) 但是因为在 \(j\) 进行了第一种操作后变成 \(0\) 的下标,那么我们就要把它们全部再变为 \(1\);而在 \(B\) 中的元素则是不应当为 \(1\) 但是还保持初始状态 \(1\) 的下标,我们需要一个一个地将它们变为 \(0\)。则此时答案就是:
但是数据范围不允许我们枚举 \(j\)。其实我们不难发现当 \(j\in [p_i,p_{i+1}-1]\) 时 \(A\) 和 \(B\) 两个集合是不变的。则代入上述式子唯一改变的就是 \(a_j\) 的值。而且题目中提到 \(\sum m\leq 5\times 10^5\),这说明 \(|A|\) 会很小,但是 \(|B|\) 有可能会很大,因此我们提前预处理出 \(\begin{aligned}sum_i=\sum _{k=1}^ib_k \end{aligned}\),则 \(\begin{aligned} \sum_{i\in B}b_i=sum_n-sum_j-num \end{aligned}\),其中 \(\begin{aligned}num=\sum_{k\in P\wedge k\in [j+1,n]}b_k\end{aligned}\)。\(num\) 可以一步一步推得,因此计算量相对减少了很多。我们再代入进原式:
考虑到当前 \(j\) 枚举的区间保证 \(A\) 和 \(B\) 不变,转换一下得到:
不难发现 \(\min\{a_j-sum_j\}\) 中的 \(a_j\) 和 \(sum_j\) 都不会被修改,因此我们可以用 ST 表记录 \(a_j-sum_j\) 的最小值。然后我们就可以在规定的数据范围内通过了。
代码如下:
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int MAXN=5e5+5;
int n,m,q;
int p[MAXN];
int a[MAXN],b[MAXN],c[MAXN];
int sum[MAXN],st[MAXN][20];
int lg[MAXN];
int query(int l,int r)
{
int k=lg[r-l+1];
return min(st[l][k],st[r-(1<<k)+1][k]);
}
void read(int &x)
{
x=0;
short flag=1;
char ch=getchar();
while(ch<'0'||ch>'9')
{
if(ch=='-') flag=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
x=(x<<1)+(x<<3)+(ch^48);
ch=getchar();
}
x*=flag;
}
signed main()
{
read(n);
for(int i=2;i<=n;i++) lg[i]=lg[i>>1]+1;
for(int i=1;i<=n;i++)
{
for(int j=0;j<20;j++) st[i][j]=1e15;
}
for(int i=1;i<=n;i++) read(a[i]);
for(int i=1;i<=n;i++) read(b[i]),sum[i]=sum[i-1]+b[i];
for(int i=1;i<=n;i++) read(c[i]);
for(int i=1;i<=n;i++) st[i][0]=a[i]-sum[i];
for(int j=1;j<=19;j++)
{
for(int i=1;i+(1<<j)-1<=n;i++) st[i][j]=min(st[i][j-1],st[i+(1<<(j-1))][j-1]);
}
read(q);
while(q--)
{
read(m),p[m+1]=n+1;
for(int i=1;i<=m;i++) read(p[i]);
int minn=sum[n],num=0,rsum=0,minx=n+1;
for(int i=1;i<=m;i++) minn-=b[p[i]],rsum+=b[p[i]],minx=min(minx,p[i]);
if(minx>1) minn=min(minn,query(1,minx-1)+sum[n]-rsum);
for(int i=1;i<=m;i++) num+=c[p[i]],rsum-=b[p[i]],minn=min(minn,query(p[i],p[i+1]-1)+sum[n]+num-rsum);
cout<<minn<<endl;
}
return 0;
}