题解 润
hash<bitset<N>>
虽然复杂度正确,但冲突率十分感人,不建议使用
暴力可以 bitset + 记忆化
然后这个东西看起来就是要能想办法加一个或者合并两段
考虑区间 \([l, r]\) 的贡献
发现在(靠右)第一个 1 之后的部分是无用的
在第一个 1 和第二个 1 之间最低位有没有 +1 会影响 \(w(i)\)
在第二个 1 之前有没有 +1 不影响 \(w(i)\)
那么第二个 1 之前的贡献为
\[(r-i+2)2^{r-i+1}\times2^{i-l}
\]
前面部分是 \(w(i)\),后面是这一长度的小段数量
在第一个 1 和第二个 1 之间的贡献:
长度为 \(\lfloor\ \rfloor\) 的个数为 \(cnt_1\),长为 \(\lceil\ \rceil\) 的个数为 \(cnt_2\)
那么两种贡献分别为
\[(r-i+1)2^{r-i}\times cnt_1
\]
\[(r-i+2)2^{r-i+1}\times cnt_2
\]
这个 \(cnt\) 怎么求呢?
考虑较长的那种小段有多少个:
\[len-\lfloor\frac{len}{cnt}\rfloor*cnt=len\bmod{cnt}
\]
所以这个东西实际上就是 \([l, i-1]\) 中的 01 串构成的数字
然后给贡献化化式子发现是等比数列求和
发现最后只需要找到前两个 1 的位置,还要支持查询一段 01 串构成的数字
容易使用线段树实现
这里给出一份涵盖了除线段树外核心代码的 \(\require{cancel}\enclose{horizontalstrike}{O(nq)}\require{enclose}\) 实现
只需要简单加上一个线段树就行了
但是退役在即这棵线段树大概是再没机会打了
\(\tt NOI\) 延期了所以我回来写线段树了 /kk
算法复杂度 \(O((n+q)\log n)\)
点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 100010
#define ll long long
//#define int long long
int n, m;
char str[N], s[N];
const ll mod=998244353, inv2=(mod+1)>>1;
inline ll qpow(ll a, ll b) {ll ans=1; for (; b; a=a*a%mod,b>>=1) if (b&1) ans=ans*a%mod; return ans;}
// namespace test{
// map<ll, ll> mp;
// ll w(int n) {
// for (ll p=0; ; ++p) if ((1<<p)>=n) return (p+1)*(1<<p);
// }
// ll solve(int l, int r) {
// cout<<"solve: ["<<l<<','<<setw(2)<<r<<"] "<<bitset<10>(r-l+1)<<" len="<<setw(2)<<r-l+1<<" w="<<w(r-l+1)<<endl;
// // r-=l; l-=l;
// if (mp.find(r)!=mp.end()) return mp[r];
// if (l==r) return mp[r]=w(1);
// int mid=(l+r)>>1;
// return mp[r]=solve(l, mid)+solve(mid+1, r)+w(r-l+1);
// }
// }
// namespace task1{
// #undef N
// #define N 2005
// ll pw[N];
// bitset<N> s, t, mask;
// unordered_map<size_t, ll> mp;
// ll solve(bitset<N> s) {
// if (!s.any()) return 0;
// size_t h=hash<bitset<N>>()(s);
// if (mp.find(h)!=mp.end()) return mp[h];
// ll p=N-1;
// while (!s[p]) --p;
// if (s.count()!=1) ++p;
// ll ans=(p+1)*pw[p]%mod;
// if (s[0]==0) ans=(ans+2*solve(s>>1))%mod;
// else {
// if (s.count()==1) return ans;
// ans=(ans+solve(s>>1))%mod;
// bitset<N> t=s>>1;
// for (int i=0; ; ++i)
// if (!t[i]) {t[i]=1; break;}
// else t[i]=0;
// ans=(ans+solve(t))%mod;
// }
// return mp[h]=ans;
// }
// void solve() {
// pw[0]=1;
// for (int i=1; i<N; ++i) pw[i]=pw[i-1]*2%mod;
// for (int i=1; i<=n; ++i) s[i]=str[i]=='1';
// mask.set();
// for (int i=1,op,l,r; i<=m; ++i) {
// scanf("%d%d%d", &op, &l, &r);
// if (op==1) {
// for (int j=l; j<=r; ++j) s[j]=~s[j];
// }
// else if (op==2) {
// for (int j=l; j<=r; ++j) s[j]=0;
// }
// else if (op==3) {
// for (int j=l; j<=r; ++j) s[j]=1;
// }
// else {
// t=s; t>>=l;
// t&=mask>>N-1-r+l;
// printf("%lld\n", solve(t));
// }
// }
// }
// }
// namespace task2{
// #undef N
// #define N 2005
// ll pw[N];
// bitset<N> s, t, mask;
// unordered_map<size_t, ll> mp;
// ll solve(bitset<N> s, int high) {
// if (!s.any()) return 0;
// size_t h=hash<bitset<N>>()(s);
// if (mp.find(h)!=mp.end()) return mp[h];
// // ll p=N-1;
// // while (!s[p]) --p;
// ll p=high;
// if (s.count()!=1) ++p;
// ll ans=(p+1)*pw[p]%mod;
// if (s[0]==0) ans=(ans+2*solve(s>>1, high-1))%mod;
// else {
// if (s.count()==1) return ans;
// ans=(ans+solve(s>>1, high-1))%mod;
// bitset<N> t=s>>1;
// for (int i=0; ; ++i)
// if (!t[i]) {t[i]=1; high=max(high-1, i); break;}
// else t[i]=0;
// ans=(ans+solve(t, high))%mod;
// }
// return mp[h]=ans;
// }
// void solve() {
// pw[0]=1;
// for (int i=1; i<N; ++i) pw[i]=pw[i-1]*2%mod;
// for (int i=1; i<=n; ++i) s[i]=str[i]=='1';
// mask.set();
// for (int i=1,op,l,r; i<=m; ++i) {
// scanf("%d%d%d", &op, &l, &r);
// if (op==1) {
// for (int j=l; j<=r; ++j) s[j]=~s[j];
// }
// else if (op==2) {
// for (int j=l; j<=r; ++j) s[j]=0;
// }
// else if (op==3) {
// for (int j=l; j<=r; ++j) s[j]=1;
// }
// else {
// t=s; t>>=l;
// t&=mask>>N-1-r+l;
// int high=0;
// for (int i=r-1; ~i; --i) if (t[i]) {high=i; break;}
// printf("%lld\n", solve(t, high));
// }
// }
// }
// }
// namespace task3{
// ll pw[N];
// void solve() {
// pw[0]=1;
// for (int i=1; i<=n+1; ++i) pw[i]=pw[i-1]*2%mod;
// for (int i=1; i<=n; ++i) s[i]-='0';
// for (int i=1,op,l,r; i<=m; ++i) {
// scanf("%d%d%d", &op, &l, &r);
// if (op==1) {
// for (int j=l; j<=r; ++j) s[j]^=1;
// }
// else if (op==2) {
// for (int j=l; j<=r; ++j) s[j]=0;
// }
// else if (op==3) {
// for (int j=l; j<=r; ++j) s[j]=1;
// }
// else {
// int pos;
// while (!s[r]&&r>=l) --r;
// for (pos=r-1; !s[pos]&&pos>=l; --pos);
// ll ans=0, val=0, cnt1=0, cnt2=0;
// // cout<<"query: "; for (int j=l; j<=r; ++j) cout<<int(s[j]); cout<<endl;
// for (int j=l; j<=r; ++j) {
// // cout<<"j: "<<j<<endl;
// cnt2=val, cnt1=(pw[j-l]-cnt2)%mod;
// // cout<<"cnt: "<<cnt1<<' '<<cnt2<<endl;
// // cout<<"val: "<<(r-j+1)*pw[r-j]%mod<<' '<<(r-j+2)*pw[r-j+1]%mod<<endl;
// ans=(ans+(r-j+2)*pw[r-j+1]%mod*cnt2)%mod;
// // cout<<"add: "<<(r-j+2)*pw[r-j+1]%mod*cnt2<<endl;
// if (j<=pos) ans=(ans+(r-j+2)*pw[r-j+1]%mod*cnt1)%mod; //, cout<<"add: "<<(r-j+2)*pw[r-j+1]%mod*cnt1<<endl;
// else ans=(ans+(r-j+1)*pw[r-j]%mod*cnt1)%mod; //, cout<<"add: "<<(r-j+1)*pw[r-j]%mod*cnt1<<endl;
// if (j!=r) val=(val+s[j]*pw[j-l])%mod;
// }
// ans=(ans+val*2)%mod;
// printf("%lld\n", (ans%mod+mod)%mod);
// }
// }
// }
// }
namespace task{
#define tl(p) tl[p]
#define tr(p) tr[p]
ll pw[N], w[N], val[N<<2][2], mask[N<<2][2];
int tl[N<<2], tr[N<<2], len[N<<2], rev[N<<2], tag[N<<2];
inline void pushup(int p) {
val[p][0]=(val[p<<1][0]+pw[len[p<<1]]*val[p<<1|1][0])%mod;
val[p][1]=(val[p<<1][1]+pw[len[p<<1]]*val[p<<1|1][1])%mod;
}
inline void spread(int p) {
if (rev[p]) {
swap(val[p<<1][0], val[p<<1][1]);
if (~tag[p<<1]) tag[p<<1]^=1; else rev[p<<1]^=1;
swap(val[p<<1|1][0], val[p<<1|1][1]);
if (~tag[p<<1|1]) tag[p<<1|1]^=1; else rev[p<<1|1]^=1;
rev[p]=0;
}
if (~tag[p]) {
val[p<<1][0]=mask[p<<1][tag[p]]; val[p<<1][1]=mask[p<<1][tag[p]^1]; tag[p<<1]=tag[p];
val[p<<1|1][0]=mask[p<<1|1][tag[p]]; val[p<<1|1][1]=mask[p<<1|1][tag[p]^1]; tag[p<<1|1]=tag[p];
tag[p]=-1;
}
}
void build(int p, int l, int r) {
tl(p)=l; tr(p)=r; len[p]=r-l+1; tag[p]=-1;
if (l==r) {val[p][0]=s[l]; val[p][1]=s[l]^1; mask[p][1]=1; return ;}
int mid=(l+r)>>1;
build(p<<1, l, mid);
build(p<<1|1, mid+1, r);
pushup(p);
mask[p][1]=(mask[p<<1][1]+pw[len[p<<1]]*mask[p<<1|1][1])%mod;
}
void reverse(int p, int l, int r) {
if (l<=tl(p)&&r>=tr(p)) {
swap(val[p][0], val[p][1]);
if (~tag[p]) tag[p]^=1;
else rev[p]^=1;
return ;
}
spread(p);
int mid=(tl(p)+tr(p))>>1;
if (l<=mid) reverse(p<<1, l, r);
if (r>mid) reverse(p<<1|1, l, r);
pushup(p);
}
void cover(int p, int l, int r, int dat) {
if (l<=tl(p)&&r>=tr(p)) {val[p][0]=mask[p][dat]; val[p][1]=mask[p][dat^1]; tag[p]=dat; return ;}
spread(p);
int mid=(tl(p)+tr(p))>>1;
if (l<=mid) cover(p<<1, l, r, dat);
if (r>mid) cover(p<<1|1, l, r, dat);
pushup(p);
}
int query(int p, int l, int r) {
if (tl(p)==tr(p)) return tl(p)<=r&&val[p][0]?tl(p):-1;
spread(p);
int mid=(tl(p)+tr(p))>>1;
if (r>mid && val[p<<1|1][0]) {
int ans=query(p<<1|1, l, r);
if (ans==-1) return query(p<<1, l, r);
else return ans;
}
else return query(p<<1, l, r);
}
ll qval(int p, int l, int r) {
if (l<=tl(p)&&r>=tr(p)) return val[p][0];
spread(p);
int mid=(tl(p)+tr(p))>>1;
if (l<=mid&&r>mid) return (qval(p<<1, l, r)+pw[mid-max(tl(p), l)+1]*qval(p<<1|1, l, r))%mod;
else if (l<=mid) return qval(p<<1, l, r);
else return qval(p<<1|1, l, r);
}
ll qsum(ll l, ll r) {
if (l>r) return 0;
else return (r*(r+1)%mod*inv2-(l-1)*l%mod*inv2)%mod;
}
void solve() {
pw[0]=1;
for (int i=1; i<=n+1; ++i) pw[i]=pw[i-1]*2%mod;
for (int i=0; i<=n+1; ++i) w[i]=(i+1)*pw[i]%mod;
for (int i=1; i<=n; ++i) w[i]=(w[i]+w[i-1])%mod;
for (int i=1; i<=n; ++i) s[i]-='0';
build(1, 0, n);
for (int i=1,op,l,r; i<=m; ++i) {
// cout<<"i: "<<i<<endl;
// cout<<"s: "; for (int j=1; j<=n; ++j) cout<<(int)s[j]; cout<<endl;
// cout<<"t: "; for (int j=1; j<=n; ++j) cout<<qval(1, j, j); cout<<endl;
scanf("%d%d%d", &op, &l, &r);
if (op==1) {
// for (int j=l; j<=r; ++j) s[j]^=1;
reverse(1, l, r);
}
else if (op==2) {
// for (int j=l; j<=r; ++j) s[j]=0;
cover(1, l, r, 0);
}
else if (op==3) {
// for (int j=l; j<=r; ++j) s[j]=1;
cover(1, l, r, 1);
}
else {
int pos;
// cerr<<"lr: "<<l<<' '<<r<<endl;
// cerr<<"s: "; for (int j=1; j<=n; ++j) cerr<<(int)s[j]; cerr<<endl;
// cerr<<"t: "; for (int j=1; j<=n; ++j) cerr<<qval(1, j, j); cerr<<endl;
// while (!s[r]&&r>=l) --r;
// int t=r; while (!s[t]&&t>=l) --t;
r=max(query(1, l-1, r), l-1);
// cerr<<r<<' '<<t<<endl;
// assert(r==t);
if (r<l) {puts("0"); continue;}
// for (pos=r-1; !s[pos]&&pos>=l; --pos);
pos=max(query(1, l-1, r-1), l-1);
ll ans=(1ll*(pos-l+1)*(r+2)%mod-qsum(l, pos))*pw[r-l+1]%mod, val=0;
// cout<<"query: "; for (int j=l; j<=r; ++j) cout<<int(s[j]); cout<<endl;
// for (int j=l; j<=pos; ++j) val=(val+s[j]*pw[j-l])%mod;
if (l<=pos) val=qval(1, l, pos);
ans=(ans+val*(w[r-pos]-w[0])-val*w[r-pos-1])%mod;
ans=(ans+val*2+pw[r-l]*qsum(1, r-pos))%mod;
printf("%lld\n", (ans%mod+mod)%mod);
}
}
}
}
signed main()
{
freopen("run.in", "r", stdin);
freopen("run.out", "w", stdout);
scanf("%d%d%s", &n, &m, s+1);
// task1::solve();
// task2::solve();
// test::solve(1, 7);
task::solve();
return 0;
}