题解 rfrqwq
神必卡常毒瘤出题人是屑
从复杂度猜测需要根号做法,那么考虑分块
发现对于一个询问 \([l, r]\),每个满足 \(i\in[l, r)\land a_i\neq a_{i+1}\) 的位置 \(i\) 产生 \(cnt_i\times rcnt_{i+1}\) 的贡献
这里 \(cnt, rcnt\) 为询问区间前/后缀询问颜色的出现次数
那么关键性质是一个询问可以拆分成若干个块,相邻块间是可以 \(O(1)\) 合并的
因为一个位置 \(i\) 的贡献实际上是 \((pre+cnt_i)(rcnt_{i+1}+suf)\)
那么拆一下贡献会发现需要维护块内每种颜色的 \(\sum [a_i\neq a_{i+1}]\) 和 \(\sum cnt_i\) 和 \(\sum rcnt_{i+1}\) 和 \(\sum cnt_i\times rcnt_{i+1}\)
(这里均要加上 \(a_i\neq a_{i+1}\) 的限制)
然后合并的话按照定义合并即可
发现 \(r\) 处的 \(a_{i+1}\) 不在这个块内,需要特别保证 \(a_{i+1}\) 的值是正确的
修改用一个 \(O(\sqrt n)-O(1)\) 的分块维护
查询暴力重构散块即可
离线之后将每个询问挂到与之相关的修改上可以省掉 hash 表,否则卡不过去
但是离线我也没把握卡过去所以就不卡了
然后离线的话空间复杂度比较大
那么发现每个块之间是独立的,对 \(O(\sqrt n)\) 个块每个做一次就好了
复杂度 \(O(m\sqrt n)\)
点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 500010
#define ll long long
//#define int long long
// Buffered stdin reader: buf holds the current chunk, [p1, p2) is the unread part.
char buf[1<<21], *p1=buf, *p2=buf;
// Refills buf from stdin when it is exhausted and yields the next byte, or EOF
// when nothing more can be read. The byte is widened through unsigned char so
// that values >= 0x80 stay non-negative and can never be mistaken for EOF.
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:static_cast<unsigned char>(*p1++))
// Reads the next (optionally negative) decimal integer from the buffered input.
// Every '-' met while skipping non-digit characters flips the sign.
// Returns 0 when the input ends before any digit is found.
inline int read() {
    int ans=0, f=1;
    // Keep the character as int: EOF must stay distinguishable and isdigit()
    // requires an argument that is EOF or representable as unsigned char.
    int c=getchar();
    // Stop skipping at EOF; otherwise a truncated input would spin forever.
    while (c!=EOF && !isdigit(c)) {if (c=='-') f=-f; c=getchar();}
    // isdigit(EOF) is false, so this loop also terminates at end of input.
    while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
    return ans*f;
}
// n: length of the sequence, m: number of operations; both are read in main().
int n, m;
// a[1..n]: the sequence of values (1-based), filled in main() and modified by the solvers.
int a[N];
namespace force{
ll cnt[N];
void solve() {
for (int i=1,l,r,x; i<=m; ++i) {
if (read()&1) {
l=read(); r=read(); x=read();
for (int j=l; j<=r; ++j) a[j]=x;
}
else {
l=read(); r=read(); x=read();
ll ans=0; int sum=0;
cnt[l-1]=0;
for (int j=l; j<=r; ++j) sum+=(cnt[j]=(a[j]==x));
for (int j=l; j<=r; ++j) cnt[j]+=cnt[j-1];
for (int j=l; j<r; ++j) if (a[j]!=a[j+1])
ans+=cnt[j]*(sum-cnt[j]);
printf("%lld\n", ans);
}
}
}
}
namespace task{
// Square-root decomposition solver.
//
// For a segment of the array and a value c, a `data` record stores
//   len  : number of positions i in the segment with a[i] != a[i+1] ("boundaries"),
//   siz  : number of occurrences of c in the segment,
//   cnt  : sum over boundaries i of (#c in the segment up to and including i),
//   rcnt : sum over boundaries i of (#c in the segment strictly after i),
//   prod : sum over boundaries i of the product of those two counts.
// Two adjacent segments are merged in O(1) by operator+. The answer to a query is
// the `prod` of the merged record for the query range.
//
// NOTE(review): values are used as array indices (pos[], sum[], vis[]) and 0 is used
// as "no tag" / "no occurrence", so values are assumed to lie in [1, N) - confirm
// against the problem statement.

bool vis[N];                    // scratch: vis[c] marks values already stored during rebuild()
ll cnt[N], rcnt[N], ans[N];     // scratch: per-position occurrence ranks; ans[] is not used in this file
struct ques{int op, l, r, x;}que[N];   // the m operations, read up front in solve()
// bel[i]  : block id of position i (blocks are numbered from 1; bel[0] and bel[n+1] stay 0)
// ls/rs   : first / last position of each block
// lst/nxt : scratch, previous / next position holding the same value inside the scanned range (0 = none)
// pos[c]  : scratch, last seen position of value c during a scan
// pre[i]  : scratch, running count of boundaries inside the scanned range
// diff[b] : number of boundaries in block b
// tag[b]  : pending whole-block assignment for block b (0 = none)
// sqr     : block length
int bel[N], ls[N], rs[N], lst[N], nxt[N], pos[N], pre[N], diff[N], tag[N], sqr;
struct data{ll len, siz, cnt, rcnt, prod;}sum[N];   // sum[c]: scratch accumulator for value c

// Concatenates the records of two adjacent segments (a on the left, b on the right).
// Every boundary in b sees a.siz extra occurrences to its left, every boundary in a
// sees b.siz extra occurrences to its right; prod picks up the matching cross terms.
inline data operator + (data a, data b) {
    return {a.len+b.len, a.siz+b.siz, a.cnt+b.cnt+a.siz*b.len, a.rcnt+b.rcnt+a.len*b.siz, a.prod+b.prod+a.cnt*b.siz+a.siz*b.rcnt};
}

// mp[b][c]: record of block b for value c; only values present in the block are stored.
unordered_map<int, data> mp[710];

// Current value at position i, taking the block's pending assignment into account.
inline int qval(int i) {return tag[bel[i]]?tag[bel[i]]:a[i];}

// Writes the pending assignment of block id into a[] and clears the tag.
inline void spread(int id) {
    if (!tag[id]) return ;
    for (int i=ls[id]; i<=rs[id]; ++i) a[i]=tag[id];
    tag[id]=0;
}

// Assigns x to a[l..r]: the two end blocks are written element by element, the
// blocks strictly between them only receive a tag.
inline void upd(int l, int r, int x) {
    int sid=bel[l], eid=bel[r];
    spread(sid); spread(eid);
    if (sid==eid) {for (int i=l; i<=r; ++i) a[i]=x;}
    else {
        for (int i=l; bel[i]==sid; ++i) a[i]=x;
        for (int i=sid+1; i<eid; ++i) tag[i]=x;
        for (int i=r; bel[i]==eid; --i) a[i]=x;
    }
}

// Recomputes mp[id] for every value occurring in block id. A no-op for id == 0.
// All scratch arrays touched are reset before returning.
void rebuild(int id) {
    if (!id) return ;
    mp[id].clear();
    // The boundary test at the last position needs the real value just past the block.
    spread(id); a[rs[id]+1]=qval(rs[id]+1);
    // cnt[i]: rank of position i among equal values, counted from the left of the block.
    for (int i=ls[id]; i<=rs[id]; ++i) cnt[i]=cnt[lst[i]=pos[a[i]]]+1, pos[a[i]]=i;
    for (int i=ls[id]; i<=rs[id]; ++i) pos[a[i]]=0;
    // rcnt[i]: the same rank counted from the right of the block.
    for (int i=rs[id]; i>=ls[id]; --i) rcnt[i]=rcnt[nxt[i]=pos[a[i]]]+1, pos[a[i]]=i;
    // pre relies on pre[ls[id]-1] being 0, which the reset loops maintain.
    for (int i=ls[id]; i<=rs[id]; ++i) pre[i]=pre[i-1]+(a[i]!=a[i+1]), ++sum[a[i]].siz;
    for (int i=ls[id]; i<=rs[id]; ++i) {
        // Boundaries from i up to just before the next equal value (or the block end):
        // for all of them the left count is cnt[i] and the right count is rcnt[i]-1.
        ll len=pre[(nxt[i]?nxt[i]:rs[id]+1)-1]-pre[i-1];
        if (!len) continue;
        sum[a[i]].cnt+=cnt[i]*len;
        sum[a[i]].rcnt+=(rcnt[i]-1)*len;
        sum[a[i]].prod+=cnt[i]*(rcnt[i]-1)*len;
    }
    // First occurrence of each value: boundaries before it have left count 0 and
    // right count rcnt[i]; then publish the finished record.
    for (int i=ls[id]; i<=rs[id]; ++i) if (!vis[a[i]]) {
        sum[a[i]].len=pre[rs[id]]-pre[ls[id]-1];
        sum[a[i]].rcnt+=1ll*rcnt[i]*(pre[i-1]-pre[ls[id]-1]);
        mp[id][a[i]]=sum[a[i]]; vis[a[i]]=1;
    }
    for (int i=ls[id]; i<=rs[id]; ++i) {
        cnt[i]=rcnt[i]=lst[i]=nxt[i]=pos[a[i]]=pre[i]=0;
        vis[a[i]]=0; sum[a[i]]={0, 0, 0, 0, 0};
    }
}

// Computes the record of value col over a[l..r] by direct scan (used for the partial
// blocks of a query). The caller must have spread the blocks covering [l, r].
// op == 1: position r counts as a possible boundary (the segment continues to the right);
// op == 0: position r is the end of the query, so only positions l..r-1 can be boundaries.
// Returns an all-zero record when l > r.
data build(int l, int r, int col, int op) {
    if (l>r) return {0, 0, 0, 0, 0};
    a[r+1]=qval(r+1);
    int rec=INF;   // leftmost position holding col; only read when col occurs in [l, r]
    for (int i=l; i<=r; ++i) cnt[i]=cnt[lst[i]=pos[a[i]]]+1, pos[a[i]]=i;
    for (int i=l; i<=r; ++i) pos[a[i]]=0;
    for (int i=r; i>=l; --i) {
        rcnt[i]=rcnt[nxt[i]=pos[a[i]]]+1, pos[a[i]]=i;
        if (a[i]==col) rec=i;
    }
    for (int i=l; i<=r; ++i) pre[i]=pre[i-1]+(a[i]!=a[i+1]), ++sum[a[i]].siz;
    for (int i=l; i<r+op; ++i) {
        ll len=pre[(nxt[i]?nxt[i]:r+op)-1]-pre[i-1];
        if (!len) continue;
        sum[a[i]].cnt+=cnt[i]*len;
        sum[a[i]].rcnt+=(rcnt[i]-1)*len;
        sum[a[i]].prod+=cnt[i]*(rcnt[i]-1)*len;
    }
    // Boundaries in front of the first occurrence of col see all rcnt[rec] occurrences to their right.
    if (sum[col].siz) sum[col].rcnt+=1ll*rcnt[rec]*(pre[rec-1]-pre[l-1]);
    data ans=sum[col]; ans.len=pre[r-(op^1)]-pre[l-1];
    for (int i=l; i<=r; ++i) {
        cnt[i]=rcnt[i]=lst[i]=nxt[i]=pos[a[i]]=pre[i]=0;
        sum[a[i]]={0, 0, 0, 0, 0};
    }
    return ans;
}

// Recounts diff[id], the number of boundaries in block id. A no-op for id == 0.
void qdif(int id) {
    if (!id) return ;
    diff[id]=0; spread(id); a[rs[id]+1]=qval(rs[id]+1);
    for (int j=ls[id]; j<=rs[id]; ++j) if (a[j]!=a[j+1]) ++diff[id];
}

// Reads the m operations and processes them in order. Odd op: assign x to a[l..r].
// Even op: print the answer for value x on [l, r].
void solve() {
    sqr=sqrt(n);
    for (int i=1; i<=n; ++i) bel[i]=(i-1)/sqr+1;
    for (int i=1; i<=n; ++i) rs[bel[i]]=i;
    for (int i=n; i; --i) ls[bel[i]]=i;
    for (int i=1; i<=bel[n]; ++i) rebuild(i);
    for (int i=1; i<n; ++i) if (a[i]!=a[i+1]) ++diff[bel[i]];
    for (int i=1; i<=m; ++i) que[i].op=read(), que[i].l=read(), que[i].r=read(), que[i].x=read();
    for (int i=1,l,r,x; i<=m; ++i) {
        if (que[i].op&1) {
            l=que[i].l; r=que[i].r; x=que[i].x;
            upd(l, r, x);
            int sid=bel[l], eid=bel[r];
            // A block's record depends on the value right after it, so when the update
            // starts exactly at a block border the previous block must be refreshed;
            // otherwise the partially covered first block is.
            rebuild(l==ls[sid]?sid-1:sid), rebuild(eid);
            // Blocks entirely overwritten with x (and followed by x): no boundaries, only x present.
            for (int j=sid+(l!=ls[sid]); j<eid; ++j) mp[j].clear(), mp[j][x]={0, rs[j]-ls[j]+1, 0, 0, 0};
            qdif(sid-(l==ls[sid])), qdif(eid);
            for (int j=sid+(l!=ls[sid]); j<eid; ++j) diff[j]=0;
        }
        else {
            l=que[i].l; r=que[i].r; x=que[i].x;
            int sid=bel[l], eid=bel[r];
            spread(sid), spread(eid);
            if (sid==eid) printf("%lld\n", build(l, r, x, 0).prod);
            else {
                data ans=build(l, rs[sid], x, 1);
                for (int b=sid+1; b<eid; ++b) {
                    // One hash lookup per block instead of find() followed by operator[].
                    const auto it=mp[b].find(x);
                    if (it!=mp[b].end()) ans=ans+it->second;
                    // x does not occur in this block: it only contributes its boundaries.
                    else ans=ans+data{diff[b], 0, 0, 0, 0};
                }
                ans=ans+build(ls[eid], r, x, 0);
                printf("%lld\n", ans.prod);
            }
        }
    }
}
}
signed main()
{
freopen("c.in", "r", stdin);
freopen("c.out", "w", stdout);
n=read(); m=read();
for (int i=1; i<=n; ++i) a[i]=read();
// force::solve();
task::solve();
return 0;
}