题解 fft

传送门

策略是每次取最高位
(赛时推出的策略是:最高两位为 11 时拆最高位,最高两位为 10 时拆最高位的下一位;但仔细思考会发现,这其实等价于每次取最高位)

那么尝试写出贡献
最低位需要特判,不好维护;考虑先把它当成普通位来算,最后再把多算的部分减掉(设最低的 1 在第 \(c\) 位,则多算的部分为 \((c+1)2^{c+1}\))
那么若第 \(i\) 位为 1,其贡献为

\[(i+1)2^{i+1}+\frac{i(i+1)}{2}2^i+2^i \]

化简得

\[(i^2+5i+6)2^{i-1} \]

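然后 \(\sum i^2 2^{i-1}\) 和 \(\sum i2^{i-1}\) 都可以用错位相减求出封闭形式:\(\sum_{i=1}^{r} 2^{i-1}=2^r-1\),\(\sum_{i=1}^{r} i2^{i-1}=(r-1)2^r+1\),\(\sum_{i=1}^{r} i^2 2^{i-1}=(r^2-2r+3)2^r-3\)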
那么这个东西的前缀和可以 \(O(\log n)\) 求出
那么区间和也可以求了
那么珂朵莉树维护全 1 区间即可
复杂度应该是 \(O(n\log n)\)(每次操作均摊只增删 \(O(1)\) 个区间,每次 set 操作和快速幂各带一个 \(\log\))

点击查看代码
// ubsan: undefined
// accoders
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 150010
#define fir first
#define sec second
#define ll long long
//#define int long long

// Input buffer: `buf` holds the most recently read chunk of stdin and
// [p1, p2) is the part of it that has not been consumed yet.
char buf[1<<21], *p1=buf, *p2=buf;
// Replacement for getchar(): refills `buf` with fread once the unread window is
// empty and evaluates to EOF when stdin has nothing left.
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
/**
 * Reads the next decimal integer from the buffered input.
 *
 * Every non-digit character before the number is skipped; each '-' among the
 * skipped characters flips the sign. The character that terminates the digit
 * run is consumed as well.
 *
 * @return the parsed value, or 0 if the input ends before any digit is found.
 */
inline int read() {
	int ans=0, f=1;
	// Keep the character as an int so that EOF is not squeezed into a char.
	int c=getchar();
	// Stop at EOF: without this check an exhausted input would loop forever,
	// because getchar() keeps answering EOF and EOF is never a digit.
	while (c!=EOF && (c<'0' || c>'9')) {
		if (c=='-') f=-f;
		c=getchar();
	}
	// EOF is outside '0'..'9', so it also ends the digit run.
	while (c>='0' && c<='9') {
		ans=ans*10+(c-'0');
		c=getchar();
	}
	return ans*f;
}

// Input sizes read by main(): the solvers loop k times over (l, r) pairs and
// then m times over (op, b) pairs.
int k, m;
// Modulus for every answer, and (mod+1)/2, which is the inverse of 2 modulo `mod`.
const ll mod=998244353, inv2=(mod+1)>>1;

// Adds b to a in place and subtracts `mod` once if the sum reached it.
// The result is reduced only when both operands were already in [0, mod).
inline void md(ll& a, ll b) {
	a+=b;
	if (a>=mod) a-=mod;
}

// Binary exponentiation: a^b modulo `mod`.
// The base is not reduced before the first multiplication.
inline ll qpow(ll a, ll b) {
	ll result=1;
	while (b) {
		if (b&1) result=result*a%mod;
		a=a*a%mod;
		b>>=1;
	}
	return result;
}

// 1 + 2 + ... + n modulo `mod`, computed as n(n+1)/2 with the modular inverse of 2.
inline ll qsum(ll n) {
	const ll x=n%mod;
	return x*(x+1)%mod*inv2%mod;
}

// Brute-force reference solver. It keeps the whole number in one 64-bit
// integer and recomputes the answer from scratch after every query.
// main() does not call it (the call there is commented out).
// NOTE(review): presumably only usable when every bit position is below 63,
// since the value is built with `1ll<<j` -- confirm against the subtask limits.
namespace force{
	// The current number itself (not reduced modulo `mod`).
	ll s;
	// suf[i]: number of set bits among bits 0..i of the value, rebuilt per query.
	int suf[N];
	// Reads the k initial intervals and the m queries through read() and prints
	// one answer per query with printf.
	void solve() {
		// Build the initial number: add 2^j for every j in each interval [l, r].
		for (int i=1,l,r; i<=k; ++i) {
			l=read(); r=read();
			for (int j=l; j<=r; ++j) s+=1ll<<j;
		}
		for (int i=1,op,b; i<=m; ++i) {
			op=read(); b=read();
			// Odd op subtracts 2^b, even op adds 2^b.
			if (op&1) s-=1ll<<b;
			else s+=1ll<<b;
			// cout<<"n: "<<s<<endl;
			// Debug dump of the low 30 bits. It is written to stdout, i.e. into
			// the same stream as the answers printed below.
			cout<<"s: "<<bitset<30>(s)<<endl;
			// t is a working copy whose bits get filled in during the scan;
			// the answer starts from the number itself.
			ll t=s, ans=s;
			// `fir` is a macro for `first` (see the #define near the top of the
			// file), so this declares a flag named `first`. It is true until the
			// highest set bit has been handled.
			bool fir=1;
			suf[0]=t&1;
			for (int i=1; i<64; ++i) suf[i]=suf[i-1]+(t&(1ll<<i)?1:0);
			// Scan positions 63..1 and handle every position whose bit is set in t.
			// Note that suf was computed before the scan and is not refreshed when
			// t is modified below.
			for (int i=63; i; --i) if (t&(1ll<<i)) {
				if (fir) {
					// One-off term for the highest set bit: i*2^i when it is the
					// only set bit at or below i, otherwise (i+1)*2^(i+1).
					if (suf[i]==1) ans=(ans+(1ll<<i)%mod*i)%mod;
					else ans=(ans+(1ll<<i+1)%mod*(i+1))%mod;
					fir=0;
				}
				if (t&(1ll<<i-1)) {
					// Bit i-1 is set as well: add 2^i * (1+...+i), then
					// (i-1)*2^(i-1) if bit i-1 is the only set bit at or below
					// it, otherwise i*2^i.
					ans=(ans+(1ll<<i)%mod*qsum(i))%mod;
					if (suf[i-1]==1) ans=(ans+(1ll<<i-1)%mod*(i-1))%mod; //, cout<<"add2: "<<(1ll<<i-1)<<endl;
					else ans=(ans+(1ll<<i)%mod*i)%mod; //, cout<<"add2: "<<(1ll<<i)<<endl;
					// cout<<"add: "<<(1ll<<i)<<' '<<(1ll<<i)%mod*qsum(i)<<endl;
				}
				else {
					// Bit i-1 is clear: set it in t so that the next iteration
					// processes position i-1, and add 2^(i-1) * (1+...+(i-1)).
					t|=1ll<<i-1;
					ans=(ans+(1ll<<i-1)%mod*qsum(i-1))%mod;
					// No originally set bit at or below i-1: add (i-1)*2^(i-1),
					// otherwise i*2^i.
					if (suf[i-1]==0) ans=(ans+(1ll<<i-1)%mod*(i-1))%mod; //, cout<<"add2: "<<(1ll<<i-1)<<endl;
					else ans=(ans+(1ll<<i)%mod*i)%mod; //, cout<<"add2: "<<(1ll<<i)<<endl;
					// cout<<"add: "<<(1ll<<i-1)<<' '<<(1ll<<i-1)%mod*qsum(i-1)<<endl;
				}
			}
			// Normalise into [0, mod) before printing.
			printf("%lld\n", (ans%mod+mod)%mod);
		}
	}
}

// Partial solver that stores the number bit by bit and rescans every bit
// position after each query. main() does not call it (the call there is
// commented out).
namespace task1{
	// s[i] tells whether bit i of the current number is set.
	bool s[N];
	// Positions of all set bits; used to find the lowest and highest one.
	set<int> st;
	// pw[i] = 2^i, dat[i] = i*2^i, dat2[i] = 2^i*(1+...+i), all modulo `mod`;
	// val is the current number modulo `mod`.
	ll pw[N], dat[N], dat2[N], val;
	// Reads the k initial intervals and the m queries through read() and prints
	// one answer per query.
	void solve() {
		// The tables are only filled for indices 1..10005; dat[0] and dat2[0]
		// keep their zero initial value.
		pw[0]=1;
		for (int e=1; e<=10005; ++e) {
			pw[e]=(pw[e-1]<<1)%mod;
			dat[e]=pw[e]*e%mod;
			dat2[e]=pw[e]*qsum(e)%mod;
		}
		// Initial number: set every bit of each interval [l, r].
		for (int idx=1; idx<=k; ++idx) {
			const int l=read();
			const int r=read();
			for (int bit=l; bit<=r; ++bit) {
				s[bit]=1;
				val=(val+pw[bit])%mod;
				st.insert(bit);
			}
		}
		for (int q=1; q<=m; ++q) {
			const int op=read();
			const int b=read();
			int pos=b;
			if (op&1) {
				// Subtract 2^b: turn the clear bits from b upwards into ones
				// until a set bit is met, then clear that bit.
				val=((val-pw[b])%mod+mod)%mod;
				while (!s[pos]) {
					s[pos]=1;
					st.insert(pos);
					++pos;
				}
				s[pos]=0;
				st.erase(pos);
			}
			else {
				// Add 2^b: clear the set bits from b upwards until a clear bit
				// is met, then set that bit.
				val=(val+pw[b])%mod;
				while (s[pos]) {
					s[pos]=0;
					st.erase(pos);
					++pos;
				}
				s[pos]=1;
				st.insert(pos);
			}
			// The answer starts from the number itself.
			ll ans=val;
			const int high=*st.rbegin();
			const int low=*st.begin();
			// One-off term for the highest set bit.
			md(ans, high==low ? dat[high] : dat[high+1]);
			// Walk down from the highest bit. A position is handled when its
			// bit is set or when the position above it found a clear bit below
			// (`carried`).
			bool carried=false;
			for (int bit=high; bit; --bit) {
				if (!s[bit] && !carried) continue;
				const bool below_set=s[bit-1];
				carried=!below_set;
				if (below_set) {
					md(ans, dat2[bit]);
					md(ans, bit-1==low ? dat[bit-1] : dat[bit]);
				}
				else {
					md(ans, dat2[bit-1]);
					md(ans, bit-1<low ? dat[bit-1] : dat[bit]);
				}
			}
			printf("%lld\n", (ans%mod+mod)%mod);
		}
	}
}

namespace task{
	ll ans, now;
	struct node{mutable ll l, r, val;};
	inline bool operator < (node a, node b) {return a.l<b.l;}
	set<node> s;
	inline ll sum1(ll r) {return (now-1)%mod;}
	inline ll force_sum1(ll r) {ll ans=0; for (int i=1; i<=r; ++i) ans=(ans+qpow(2, i-1))%mod; return ans;}
	inline ll sum2(ll r) {return (r*now-sum1(r))%mod;}
	inline ll force_sum2(ll r) {ll ans=0; for (int i=1; i<=r; ++i) ans=(ans+i*qpow(2, i-1))%mod; return ans;}
	inline ll sum3(ll r) {return (r*r%mod*now+sum1(r)-2*sum2(r))%mod;}
	inline ll force_sum3(ll r) {ll ans=0; for (int i=1; i<=r; ++i) ans=(ans+i*i%mod*qpow(2, i-1))%mod; return ans;}
	inline ll qsum(ll r) {if (r<=0) return 0; now=qpow(2, r); return (sum3(r)+5*sum2(r)+6*sum1(r))%mod;}
	inline ll qval(int l, int r) {return (qsum(r)-qsum(l-1))%mod;}
	int higher_1(int x) {
		// cout<<"x: "<<x<<endl;
		auto it=s.upper_bound({x, 0, 0});
		if (it==s.begin()) return it->l;
		if ((--it)->r<x) return (++it)->l;
		return x;
	}
	int higher_0(int x) {
		auto it=s.upper_bound({x, 0, 0});
		if (it==s.begin()) return x;
		if ((--it)->r<x) return x;
		for (auto t=it; (++t)!=s.end()&&t->l==it->r+1; it=t);
		return it->r+1;
	}
	auto split(int x) {
		auto it=s.upper_bound({x, 0, 0});
		if (it==s.begin()) return it;
		if ((--it)->r<x) return ++it;
		if (it->l==x) return it;
		int l=it->l, r=it->r;
		s.erase(it);
		s.insert({l, x-1, qval(l, x-1)});
		return s.insert({x, r, qval(x, r)}).fir;
	}
	void erase(int l, int r) {
		if (l>r) return ;
		auto it2=split(r+1), it1=split(l);
		// s.erase(it1, it2);
		for (; it1!=it2; ans=(ans-it1->val)%mod,it1=s.erase(it1));
	}
	void assign(int l, int r) {
		if (l>r) return ;
		auto it2=split(r+1), it1=split(l);
		// s.erase(it1, it2);
		for (; it1!=it2; ans=(ans-it1->val)%mod,it1=s.erase(it1));
		s.insert({l, r, qval(l, r)});
		ans=(ans+qval(l, r))%mod;
	}
	void solve() {
		for (int i=1,l,r; i<=k; ++i) {
			l=read(); r=read();
			ans=(ans+qval(l, r))%mod;
			s.insert({l, r, qval(l, r)});
		}
		for (int i=1,op,b; i<=m; ++i) {
			// cout<<"i: "<<i<<endl;
			op=read(); b=read();
			if (op&1) {
				int pos=higher_1(b);
				// cout<<"pos: "<<pos<<endl;
				erase(pos, pos), assign(b, pos-1);
			}
			else {
				int pos=higher_0(b);
				assign(pos, pos), erase(b, pos-1);
			}
			// cout<<"s: "; for (auto it:s) cout<<"("<<it.l<<','<<it.r<<") "; cout<<endl;
			ll ctz=s.begin()->l;
			if (ctz) printf("%lld\n", ((ans-(ctz+1)*qpow(2, ctz+1))%mod+mod)%mod);
			else printf("%lld\n", ((ans+1)%mod+mod)%mod);
		}
	}
}

// Entry point: binds stdin/stdout to the judge's files, reads the two counts
// and runs the full solver. `signed` is kept as the return type so the
// commented-out `#define int long long` at the top of the file stays usable.
signed main() {
	// All input comes from fft.in and all output goes to fft.out.
	freopen("fft.in", "r", stdin);
	freopen("fft.out", "w", stdout);

	// k first, then m.
	k=read();
	m=read();

	// Alternative solvers kept for debugging: force::solve() and task1::solve().
	task::solve();
	return 0;
}
posted @ 2022-07-23 21:20  Administrator-09  阅读(1)  评论(0)  编辑  收藏  举报