[bzoj4504] k个串 kstring
题目
兔子们在玩k个串的游戏。首先,它们拿出了一个长度为n的数字序列,选出其中的一个连续子串,然后统计其子串中所有数字之和(注意这里重复出现的数字只被统计一次)。
兔子们想知道,在这个数字序列所有连续的子串中,按照以上方式统计其所有数字之和,第k大的和是多少。
题解
首先最简单的一个暴力想法就是枚举每个左端点,再枚举每个右端点,然后计算和并取最大值$O(n^3)$
但这样其实会有大量的重复计算,改进算法就是要尽量去除重复的计算。
考虑优化转移的过程,第一次的扫描(以1为左端点)肯定是无法避免的,我们要$arr[1~n]$保存结果
在计算以2为左端点的时候,观察下图,与第一次的区别就是少了一个1,但只要右端点跨过第二个1,后面的结果就跟第一次一样。
即,对于$l=2,r>=4$的区间,答案直接继承上一次的即可
否则就在原来的基础上减去1.
计算完之后,就把这一次的最大区间堆进一个堆里
之后的第三次,第四次...也是同理
我们可以用主席树来维护每次的结果数组,
即每次新建一个版本,对于要减的区间就进行区间修改,其他位置不变
另,为了节省空间,可以用永久标记,即不往下推,在更新最大值的时候相应地减去这个标记的值
在询问k的时候,每次把堆顶取出,设它是$[l~r]$这段区间
为了不让它再被选到,把l这棵树的第r个位置值设为-inf、
然后再在这颗树上再选一个最大的区间出来,塞进堆里。
反复k次,答案就出来了。
另外,为了知道最大区间的右端点,可以在主席树节点内追加一个pos,标记当前节点的最大值从何而来。
为了知道那一段区间要减去上一位的值,可以处理一个nxt[i]数组记录【第i位的值】下一次出现的位置
这个可以结合map $O(nlogn)$处理
代码
#include <iostream> #include <cstdio> #include <map> #include <queue> using namespace std; #define int long long #define N 10000000 #define mid (l+r)/2 #define inf 1e15 int val[N],nxt[N],ver[N],root[N],maxn[N],pos[N],tag[N],lc[N],rc[N],cnt,arr[N]; map<int,int> last,appear; struct data { int val,l,r; }; bool operator <(data a,data b) { return a.val<b.val; } priority_queue<data> q; void modify(int &id,int l,int r,int tl,int tr,int val,int v) { lc[++cnt]=lc[id],rc[cnt]=rc[id],maxn[cnt]=maxn[id],pos[cnt]=pos[id],tag[cnt]=tag[id]; ver[cnt]=v; id=cnt; if(cnt>=N) throw 1; if(l>=tl&&r<=tr) { maxn[id]-=val; tag[id]+=val; return; } if(tl<=mid) modify(lc[id],l,mid,tl,tr,val,v); if(tr>mid) modify(rc[id],mid+1,r,tl,tr,val,v); maxn[id]=max(maxn[lc[id]],maxn[rc[id]]); if(maxn[lc[id]]==maxn[id]) pos[id]=pos[lc[id]]; else pos[id]=pos[rc[id]]; maxn[id]-=tag[id]; } void build(int &id,int l,int r) { ver[id]=1; id=++cnt; if(l==r) { pos[id]=l; maxn[id]=arr[l]; return; } build(lc[id],l,mid); build(rc[id],mid+1,r); maxn[id]=max(maxn[lc[id]],maxn[rc[id]]); if(maxn[lc[id]]==maxn[id]) pos[id]=pos[lc[id]]; else pos[id]=pos[rc[id]]; } signed main() { int n,k; //freopen("data.txt","r",stdin); cin>>n>>k; for(int i=1;i<=n;i++) { scanf("%lld",&val[i]); if(last[val[i]]) nxt[last[val[i]]]=i; last[val[i]]=i; } for(int i=1,t=0;i<=n;i++) { if(!appear[val[i]]) t+=val[i],appear[val[i]]=true; arr[i]=t; } build(root[1],1,n); q.push((data){maxn[1],1,pos[1]}); for(int i=1;i<=n;i++) if(!nxt[i]) nxt[i]=n+1; for(int i=2;i<=n;i++) { modify(root[i]=root[i-1],1,n,i-1,i-1,inf,i); modify(root[i],1,n,i,nxt[i-1]-1,val[i-1],i); q.push((data){maxn[root[i]],i,pos[root[i]]}); } k--; while(k--) { data a=q.top(); q.pop(); //cout<<a.val<<" "<<a.l<<" "<<a.r<<endl; modify(root[a.l],1,n,a.r,a.r,inf,a.l); q.push((data){maxn[root[a.l]],a.l,pos[root[a.l]]}); } data ans=q.top(); cout<<ans.val; }
看都看了,顺手点个推荐呗 :)