题解 润

传送门

  • hash<bitset<N>> 虽然复杂度正确,但冲突率十分感人,不建议使用

暴力可以 bitset + 记忆化

然后这个东西看起来就是要能想办法加一个或者合并两段
考虑区间 \([l, r]\) 的贡献
发现在(靠右)第一个 1 之后的部分是无用的
在第一个 1 和第二个 1 之间最低位有没有 +1 会影响 \(w(i)\)
在第二个 1 之前有没有 +1 不影响 \(w(i)\)
那么第二个 1 之前的贡献为

\[(r-i+2)2^{r-i+1}\times2^{i-l} \]

前面部分是 \(w(i)\),后面是这一长度的小段数量
在第一个 1 和第二个 1 之间的贡献:
长度为 \(\lfloor\ \rfloor\) 的个数为 \(cnt_1\),长为 \(\lceil\ \rceil\) 的个数为 \(cnt_2\)
那么两种贡献分别为

\[(r-i+1)2^{r-i}\times cnt_1 \]

\[(r-i+2)2^{r-i+1}\times cnt_2 \]

这个 \(cnt\) 怎么求呢?
考虑较长的那种小段有多少个:

\[len-\lfloor\frac{len}{cnt}\rfloor*cnt=len\bmod{cnt} \]

所以这个东西实际上就是 \([l, i-1]\) 中的 01 串构成的数字
然后给贡献化化式子发现是等比数列求和
发现最后只需要找到前两个 1 的位置,还要支持查询一段 01 串构成的数字
容易使用线段树实现
这里给出一份涵盖了除线段树外核心代码的 \(\require{cancel}\enclose{horizontalstrike}{O(nq)}\require{enclose}\) 实现
只需要简单加上一个线段树就行了
但是退役在即这棵线段树大概是再没机会打了
\(\tt NOI\) 延期了所以我回来写线段树了 /kk
算法复杂度 \(O((n+q)\log n)\)

点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 100010
#define ll long long
//#define int long long

int n, m;
char str[N], s[N];
const ll mod=998244353, inv2=(mod+1)>>1;
inline ll qpow(ll a, ll b) {ll ans=1; for (; b; a=a*a%mod,b>>=1) if (b&1) ans=ans*a%mod; return ans;}

// namespace test{
// 	map<ll, ll> mp;
// 	ll w(int n) {
// 		for (ll p=0; ; ++p) if ((1<<p)>=n) return (p+1)*(1<<p);
// 	}
// 	ll solve(int l, int r) {
// 		cout<<"solve: ["<<l<<','<<setw(2)<<r<<"] "<<bitset<10>(r-l+1)<<" len="<<setw(2)<<r-l+1<<" w="<<w(r-l+1)<<endl;
// 		// r-=l; l-=l;
// 		if (mp.find(r)!=mp.end()) return mp[r];
// 		if (l==r) return mp[r]=w(1);
// 		int mid=(l+r)>>1;
// 		return mp[r]=solve(l, mid)+solve(mid+1, r)+w(r-l+1);
// 	}
// }

// namespace task1{
// 	#undef N
// 	#define N 2005
// 	ll pw[N];
// 	bitset<N> s, t, mask;
// 	unordered_map<size_t, ll> mp;
// 	ll solve(bitset<N> s) {
// 		if (!s.any()) return 0;
// 		size_t h=hash<bitset<N>>()(s);
// 		if (mp.find(h)!=mp.end()) return mp[h];
// 		ll p=N-1;
// 		while (!s[p]) --p;
// 		if (s.count()!=1) ++p;
// 		ll ans=(p+1)*pw[p]%mod;
// 		if (s[0]==0) ans=(ans+2*solve(s>>1))%mod;
// 		else {
// 			if (s.count()==1) return ans;
// 			ans=(ans+solve(s>>1))%mod;
// 			bitset<N> t=s>>1;
// 			for (int i=0; ; ++i)
// 				if (!t[i]) {t[i]=1; break;}
// 				else t[i]=0;
// 			ans=(ans+solve(t))%mod;
// 		}
// 		return mp[h]=ans;
// 	}
// 	void solve() {
// 		pw[0]=1;
// 		for (int i=1; i<N; ++i) pw[i]=pw[i-1]*2%mod;
// 		for (int i=1; i<=n; ++i) s[i]=str[i]=='1';
// 		mask.set();
// 		for (int i=1,op,l,r; i<=m; ++i) {
// 			scanf("%d%d%d", &op, &l, &r);
// 			if (op==1) {
// 				for (int j=l; j<=r; ++j) s[j]=~s[j];
// 			}
// 			else if (op==2) {
// 				for (int j=l; j<=r; ++j) s[j]=0;
// 			}
// 			else if (op==3) {
// 				for (int j=l; j<=r; ++j) s[j]=1;
// 			}
// 			else {
// 				t=s; t>>=l;
// 				t&=mask>>N-1-r+l;
// 				printf("%lld\n", solve(t));
// 			}
// 		}
// 	}
// }

// namespace task2{
// 	#undef N
// 	#define N 2005
// 	ll pw[N];
// 	bitset<N> s, t, mask;
// 	unordered_map<size_t, ll> mp;
// 	ll solve(bitset<N> s, int high) {
// 		if (!s.any()) return 0;
// 		size_t h=hash<bitset<N>>()(s);
// 		if (mp.find(h)!=mp.end()) return mp[h];
// 		// ll p=N-1;
// 		// while (!s[p]) --p;
// 		ll p=high;
// 		if (s.count()!=1) ++p;
// 		ll ans=(p+1)*pw[p]%mod;
// 		if (s[0]==0) ans=(ans+2*solve(s>>1, high-1))%mod;
// 		else {
// 			if (s.count()==1) return ans;
// 			ans=(ans+solve(s>>1, high-1))%mod;
// 			bitset<N> t=s>>1;
// 			for (int i=0; ; ++i)
// 				if (!t[i]) {t[i]=1; high=max(high-1, i); break;}
// 				else t[i]=0;
// 			ans=(ans+solve(t, high))%mod;
// 		}
// 		return mp[h]=ans;
// 	}
// 	void solve() {
// 		pw[0]=1;
// 		for (int i=1; i<N; ++i) pw[i]=pw[i-1]*2%mod;
// 		for (int i=1; i<=n; ++i) s[i]=str[i]=='1';
// 		mask.set();
// 		for (int i=1,op,l,r; i<=m; ++i) {
// 			scanf("%d%d%d", &op, &l, &r);
// 			if (op==1) {
// 				for (int j=l; j<=r; ++j) s[j]=~s[j];
// 			}
// 			else if (op==2) {
// 				for (int j=l; j<=r; ++j) s[j]=0;
// 			}
// 			else if (op==3) {
// 				for (int j=l; j<=r; ++j) s[j]=1;
// 			}
// 			else {
// 				t=s; t>>=l;
// 				t&=mask>>N-1-r+l;
// 				int high=0;
// 				for (int i=r-1; ~i; --i) if (t[i]) {high=i; break;}
// 				printf("%lld\n", solve(t, high));
// 			}
// 		}
// 	}
// }

// namespace task3{
// 	ll pw[N];
// 	void solve() {
// 		pw[0]=1;
// 		for (int i=1; i<=n+1; ++i) pw[i]=pw[i-1]*2%mod;
// 		for (int i=1; i<=n; ++i) s[i]-='0';
// 		for (int i=1,op,l,r; i<=m; ++i) {
// 			scanf("%d%d%d", &op, &l, &r);
// 			if (op==1) {
// 				for (int j=l; j<=r; ++j) s[j]^=1;
// 			}
// 			else if (op==2) {
// 				for (int j=l; j<=r; ++j) s[j]=0;
// 			}
// 			else if (op==3) {
// 				for (int j=l; j<=r; ++j) s[j]=1;
// 			}
// 			else {
// 				int pos;
// 				while (!s[r]&&r>=l) --r;
// 				for (pos=r-1; !s[pos]&&pos>=l; --pos);
// 				ll ans=0, val=0, cnt1=0, cnt2=0;
// 				// cout<<"query: "; for (int j=l; j<=r; ++j) cout<<int(s[j]); cout<<endl;
// 				for (int j=l; j<=r; ++j) {
// 					// cout<<"j: "<<j<<endl;
// 					cnt2=val, cnt1=(pw[j-l]-cnt2)%mod;
// 					// cout<<"cnt: "<<cnt1<<' '<<cnt2<<endl;
// 					// cout<<"val: "<<(r-j+1)*pw[r-j]%mod<<' '<<(r-j+2)*pw[r-j+1]%mod<<endl;
// 					ans=(ans+(r-j+2)*pw[r-j+1]%mod*cnt2)%mod;
// 					// cout<<"add: "<<(r-j+2)*pw[r-j+1]%mod*cnt2<<endl;
// 					if (j<=pos) ans=(ans+(r-j+2)*pw[r-j+1]%mod*cnt1)%mod; //, cout<<"add: "<<(r-j+2)*pw[r-j+1]%mod*cnt1<<endl;
// 					else ans=(ans+(r-j+1)*pw[r-j]%mod*cnt1)%mod; //, cout<<"add: "<<(r-j+1)*pw[r-j]%mod*cnt1<<endl;
// 					if (j!=r) val=(val+s[j]*pw[j-l])%mod;
// 				}
// 				ans=(ans+val*2)%mod;
// 				printf("%lld\n", (ans%mod+mod)%mod);
// 			}
// 		}
// 	}
// }

namespace task{
	#define tl(p) tl[p]
	#define tr(p) tr[p]
	ll pw[N], w[N], val[N<<2][2], mask[N<<2][2];
	int tl[N<<2], tr[N<<2], len[N<<2], rev[N<<2], tag[N<<2];
	inline void pushup(int p) {
		val[p][0]=(val[p<<1][0]+pw[len[p<<1]]*val[p<<1|1][0])%mod;
		val[p][1]=(val[p<<1][1]+pw[len[p<<1]]*val[p<<1|1][1])%mod;
	}
	inline void spread(int p) {
		if (rev[p]) {
			swap(val[p<<1][0], val[p<<1][1]);
			if (~tag[p<<1]) tag[p<<1]^=1; else rev[p<<1]^=1;
			swap(val[p<<1|1][0], val[p<<1|1][1]);
			if (~tag[p<<1|1]) tag[p<<1|1]^=1; else rev[p<<1|1]^=1;
			rev[p]=0;
		}
		if (~tag[p]) {
			val[p<<1][0]=mask[p<<1][tag[p]]; val[p<<1][1]=mask[p<<1][tag[p]^1]; tag[p<<1]=tag[p];
			val[p<<1|1][0]=mask[p<<1|1][tag[p]]; val[p<<1|1][1]=mask[p<<1|1][tag[p]^1]; tag[p<<1|1]=tag[p];
			tag[p]=-1;
		}
	}
	void build(int p, int l, int r) {
		tl(p)=l; tr(p)=r; len[p]=r-l+1; tag[p]=-1;
		if (l==r) {val[p][0]=s[l]; val[p][1]=s[l]^1; mask[p][1]=1; return ;}
		int mid=(l+r)>>1;
		build(p<<1, l, mid);
		build(p<<1|1, mid+1, r);
		pushup(p);
		mask[p][1]=(mask[p<<1][1]+pw[len[p<<1]]*mask[p<<1|1][1])%mod;
	}
	void reverse(int p, int l, int r) {
		if (l<=tl(p)&&r>=tr(p)) {
			swap(val[p][0], val[p][1]);
			if (~tag[p]) tag[p]^=1;
			else rev[p]^=1;
			return ;
		}
		spread(p);
		int mid=(tl(p)+tr(p))>>1;
		if (l<=mid) reverse(p<<1, l, r);
		if (r>mid) reverse(p<<1|1, l, r);
		pushup(p);
	}
	void cover(int p, int l, int r, int dat) {
		if (l<=tl(p)&&r>=tr(p)) {val[p][0]=mask[p][dat]; val[p][1]=mask[p][dat^1]; tag[p]=dat; return ;}
		spread(p);
		int mid=(tl(p)+tr(p))>>1;
		if (l<=mid) cover(p<<1, l, r, dat);
		if (r>mid) cover(p<<1|1, l, r, dat);
		pushup(p);
	}
	int query(int p, int l, int r) {
		if (tl(p)==tr(p)) return tl(p)<=r&&val[p][0]?tl(p):-1;
		spread(p);
		int mid=(tl(p)+tr(p))>>1;
		if (r>mid && val[p<<1|1][0]) {
			int ans=query(p<<1|1, l, r);
			if (ans==-1) return query(p<<1, l, r);
			else return ans;
		}
		else return query(p<<1, l, r);
	}
	ll qval(int p, int l, int r) {
		if (l<=tl(p)&&r>=tr(p)) return val[p][0];
		spread(p);
		int mid=(tl(p)+tr(p))>>1;
		if (l<=mid&&r>mid) return (qval(p<<1, l, r)+pw[mid-max(tl(p), l)+1]*qval(p<<1|1, l, r))%mod;
		else if (l<=mid) return qval(p<<1, l, r);
		else return qval(p<<1|1, l, r);
	}
	ll qsum(ll l, ll r) {
		if (l>r) return 0;
		else return (r*(r+1)%mod*inv2-(l-1)*l%mod*inv2)%mod;
	}
	void solve() {
		pw[0]=1;
		for (int i=1; i<=n+1; ++i) pw[i]=pw[i-1]*2%mod;
		for (int i=0; i<=n+1; ++i) w[i]=(i+1)*pw[i]%mod;
		for (int i=1; i<=n; ++i) w[i]=(w[i]+w[i-1])%mod;
		for (int i=1; i<=n; ++i) s[i]-='0';
		build(1, 0, n);
		for (int i=1,op,l,r; i<=m; ++i) {
			// cout<<"i: "<<i<<endl;
			// cout<<"s: "; for (int j=1; j<=n; ++j) cout<<(int)s[j]; cout<<endl;
			// cout<<"t: "; for (int j=1; j<=n; ++j) cout<<qval(1, j, j); cout<<endl;
			scanf("%d%d%d", &op, &l, &r);
			if (op==1) {
				// for (int j=l; j<=r; ++j) s[j]^=1;
				reverse(1, l, r);
			}
			else if (op==2) {
				// for (int j=l; j<=r; ++j) s[j]=0;
				cover(1, l, r, 0);
			}
			else if (op==3) {
				// for (int j=l; j<=r; ++j) s[j]=1;
				cover(1, l, r, 1);
			}
			else {
				int pos;
				// cerr<<"lr: "<<l<<' '<<r<<endl;
				// cerr<<"s: "; for (int j=1; j<=n; ++j) cerr<<(int)s[j]; cerr<<endl;
				// cerr<<"t: "; for (int j=1; j<=n; ++j) cerr<<qval(1, j, j); cerr<<endl;
				// while (!s[r]&&r>=l) --r;
				// int t=r; while (!s[t]&&t>=l) --t;
				r=max(query(1, l-1, r), l-1);
				// cerr<<r<<' '<<t<<endl;
				// assert(r==t);
				if (r<l) {puts("0"); continue;}
				// for (pos=r-1; !s[pos]&&pos>=l; --pos);
				pos=max(query(1, l-1, r-1), l-1);
				ll ans=(1ll*(pos-l+1)*(r+2)%mod-qsum(l, pos))*pw[r-l+1]%mod, val=0;
				// cout<<"query: "; for (int j=l; j<=r; ++j) cout<<int(s[j]); cout<<endl;
				// for (int j=l; j<=pos; ++j) val=(val+s[j]*pw[j-l])%mod;
				if (l<=pos) val=qval(1, l, pos);
				ans=(ans+val*(w[r-pos]-w[0])-val*w[r-pos-1])%mod;
				ans=(ans+val*2+pw[r-l]*qsum(1, r-pos))%mod;
				printf("%lld\n", (ans%mod+mod)%mod);
			}
		}
	}
}

signed main()
{
	freopen("run.in", "r", stdin);
	freopen("run.out", "w", stdout);

	scanf("%d%d%s", &n, &m, s+1);
	// task1::solve();
	// task2::solve();
	// test::solve(1, 7);
	task::solve();

	return 0;
}
posted @ 2022-06-27 20:32  Administrator-09  阅读(3)  评论(0编辑  收藏  举报