题解 赢王

传送门

在保证数据随机的那一档中，合法区间（区间和为 \(k\) 的倍数的区间）非常少。

找性质:
image
于是可以求解一个区间的答案

image
于是只需要算 \(O(n)\) 个区间答案

image
需要将计算式写出，并对其中的 min 大力分类讨论拆开（按前缀余数 \(p\) 与 \(t\) 的差分成四段）。

点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 1000010
#define fir first
#define sec second
#define pb push_back
#define ll long long
//#define int long long

char buf[1<<21], *p1=buf, *p2=buf;
// Buffered getchar(): refills `buf` with fread() when it runs dry and yields EOF once stdin is exhausted.
// Bytes are returned as unsigned char so a 0xFF byte is never confused with EOF and isdigit()
// never receives a negative non-EOF argument (which is undefined behaviour).
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:(unsigned char)*p1++)
// Reads the next integer from stdin, skipping every non-digit character before it;
// each '-' seen while skipping flips the sign. Returns 0 if the input ends before any digit.
inline int read() {
	int ans=0, f=1, c=getchar();   // int, not char: must be able to hold EOF distinctly
	while (c!=EOF && !isdigit(c)) {if (c=='-') f=-f; c=getchar();}
	while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
	return ans*f;
}

int n;                      // number of input elements
int k;                      // every element must be made a multiple of k
int a[N];                   // input values, stored 1-indexed in a[1..n]
const ll mod=998244353;     // the final answer is printed modulo this prime

namespace force{
	ll pre[N], ans;
	map<vector<int>, bool> vis;
	queue<pair<vector<int>, int>> q;
	bool check(vector<int>& tem) {
		for (auto it:tem) if (it%k) return 0;
		return 1;
	}
	int bfs(vector<int> tem) {
		vis.clear();
		while (q.size()) q.pop();
		q.push({tem, 0});
		// cout<<"bfs: "; for (auto it:tem) cout<<it<<' '; cout<<endl;
		while (q.size()) {
			vector<int> u=q.front().fir; int dis=q.front().sec; q.pop();
			if (check(u)) return dis;
			for (int i=0; i+1<u.size(); ++i) {
				--u[i]; ++u[i+1];
				if (vis.find(u)==vis.end()) q.push({u, dis+1}), vis[u]=1;
				++u[i]; --u[i+1];
				++u[i]; --u[i+1];
				if (vis.find(u)==vis.end()) q.push({u, dis+1}), vis[u]=1;
				--u[i]; ++u[i+1];
			}
		}
		assert(0);
		return INF;
	}
	int calc(int l, int r) {
		if ((pre[r]-pre[l-1])%k) return -1;
		vector<int> tem;
		for (int i=l; i<=r; ++i) tem.pb(a[i]);
		return bfs(tem);
	}
	void solve() {
		for (int i=1; i<=n; ++i) pre[i]=pre[i-1]+a[i];
		for (int i=1; i<=n; ++i)
			for (int j=i; j<=n; ++j)
				ans+=calc(i, j);
		cout<<(ans%mod+mod)%mod<<endl;
	}
}

namespace task1{
	// Per-interval DP, O(len * k) each. The table size assumes interval length <= 109 and k <= 109.
	ll pre[N], ans;
	int f[110][220];
	const int dlt=110;   // column offset so carries in [-k, k] map to non-negative indices
	#define f(a, b) f[a][b+dlt]
	// f(i, c): minimum cost over tem[1..i] when c units are carried into element i+1.
	// Element i either ends at 0 (carry out j+tem[i]) or at k (carry out j-k+tem[i]),
	// paying |j| for the incoming carry j. Returns min(f(n, 0), f(n, k)).
	int dp(vector<int> tem) {
		int n=tem.size()-1;
		memset(f, 0x3f, sizeof(f));
		f(0, 0)=0;
		// Only carries in [-k, k] are ever read back, so writes outside that range are skipped;
		// they would otherwise index past the table.
		auto relax=[&](int i, int to, int val) {
			if (-k<=to && to<=k) f(i, to)=min(f(i, to), val);
		};
		for (int i=1; i<=n; ++i)
			for (int j=-k; j<=k; ++j) {
				relax(i, j+tem[i], f(i-1, j)+abs(j));
				relax(i, j-k+tem[i], f(i-1, j)+abs(j));
			}
		return min(f(n, 0), f(n, k));
	}
	// Cost of interval [l,r], or -1 when its sum is not a multiple of k.
	int calc(int l, int r) {
		if ((pre[r]-pre[l-1])%k) return -1;
		vector<int> tem; tem.pb(0);   // dummy slot so the DP is 1-indexed
		for (int i=l; i<=r; ++i) tem.pb(a[i]);
		return dp(tem);
	}
	// Sums calc() over all O(n^2) intervals and prints the total modulo mod.
	void solve() {
		for (int i=1; i<=n; ++i) pre[i]=pre[i-1]+a[i];
		for (int i=1; i<=n; ++i)
			for (int j=i; j<=n; ++j)
				ans+=calc(i, j);
		cout<<(ans%mod+mod)%mod<<endl;
	}
	#undef f   // keep the function-like macro from leaking into later code
}

namespace task2{
	ll ans, cnt;
	vector<int> s[N];
	pair<ll, ll> bit[N];
	struct que{int l, r; ll k, c;};
	vector<que> q[N];
	int pre[N], rk[N], uni[N], usiz;
	inline void add(int i, ll dat) {for (; i<=usiz; i+=i&-i) bit[i].fir+=dat, ++bit[i].sec;}
	inline pair<ll, ll> query(int l, int r) {
		pair<ll, ll> ans(0, 0); --l;
		while (r>l) ans.fir+=bit[r].fir, ans.sec+=bit[r].sec, r-=r&-r;
		while (l>r) ans.fir-=bit[l].fir, ans.sec-=bit[l].sec, l-=l&-l;
		return ans;
	}
	void query(int r, ll t, ll c) {
		// cout<<"query: "<<r<<' '<<t<<' '<<c<<endl;
		ll dlt=0;
		for (int i=1; i<=r; ++i)
			if (pre[i]<=t) dlt=(dlt+min(t-pre[i], k-(t-pre[i])))%mod;
			else dlt=(dlt+min(pre[i]-t, k-(pre[i]-t)))%mod;
		// cout<<"dlt: "<<dlt<<endl;
		ans=(ans+dlt*c)%mod;
	}
	void solve() {
		cnt=1ll*n*(n+1)/2;
		uni[++usiz]=0; s[1].pb(0);
		for (int i=1; i<=n; ++i) uni[++usiz]=pre[i]=(pre[i-1]+a[i])%k;
		// cout<<"pre: "; for (int i=1; i<=n; ++i) cout<<pre[i]<<' '; cout<<endl;
		sort(uni+1, uni+usiz+1);
		usiz=unique(uni+1, uni+usiz+1)-uni-1;
		for (int i=1; i<=n; ++i) s[rk[i]=lower_bound(uni+1, uni+usiz+1, pre[i])-uni].pb(i);
		for (int i=1; i<=usiz; ++i) {
			// cout<<"i: "; cout<<s[i][0]<<' ';
			ll t=s[i].size(); cnt-=t*(t-1)/2;
			for (int j=1; j<s[i].size(); ++j) {
				// cout<<s[i][j]<<' ';
				// cout<<s[i][j-1]+1<<' '<<s[i][j]<<' '<<1ll*j*((ll)s[i].size()-j)%mod<<endl;
				query(s[i][j-1], pre[s[i][j-1]], -1ll*j*((ll)s[i].size()-j)%mod);
				query(s[i][j], pre[s[i][j-1]], 1ll*j*((ll)s[i].size()-j)%mod);
			} //cout<<endl;
		}
		cout<<((ans-cnt)%mod+mod)%mod<<endl;
	}
}

namespace task{
	ll ans, cnt;
	vector<int> s[N];
	pair<ll, ll> bit[N];
	struct que{ll l, r, k, c;};
	vector<que> q[N];
	int pre[N], rk[N], uni[N], usiz;
	inline void add(int i, ll dat) {for (; i<=usiz; i+=i&-i) bit[i].fir+=dat, ++bit[i].sec;}
	inline pair<ll, ll> query(int l, int r) {
		pair<ll, ll> ans(0, 0); --l;
		while (r>l) ans.fir+=bit[r].fir, ans.sec+=bit[r].sec, r-=r&-r;
		while (l>r) ans.fir-=bit[l].fir, ans.sec-=bit[l].sec, l-=l&-l;
		return ans;
	}
	// void query(int r, ll t, ll c) {
	// 	ll dlt=0;
	// 	for (int i=1; i<=r; ++i) {
	// 		if (t<=pre[i] && pre[i]<=t+k/2) dlt+=pre[i]-t;
	// 		if (t+k/2+1<=pre[i] && pre[i]<=k) dlt+=-pre[i]+t+k;
	// 		if (t-k/2<=pre[i] && pre[i]<=t-1) dlt+=-pre[i]+t;
	// 		if (0<=pre[i] && pre[i]<=t-k/2-1) dlt+=pre[i]+k-t;
	// 	}
	// 	ans=(ans+dlt*c)%mod;
	// }
	void query(int r, ll t, ll c) {
		q[r].pb({t, t+k/2, c, c*-t%mod});
		q[r].pb({t+k/2+1, k, -c, c*(t+k)%mod});
		q[r].pb({t-k/2, t-1, -c, c*t%mod});
		q[r].pb({0, t-k/2-1, c, c*(k-t)%mod});
	}
	void solve() {
		cnt=1ll*n*(n+1)/2;
		uni[++usiz]=0; s[1].pb(0);
		for (int i=1; i<=n; ++i) uni[++usiz]=pre[i]=(pre[i-1]+a[i])%k;
		// cout<<"pre: "; for (int i=1; i<=n; ++i) cout<<pre[i]<<' '; cout<<endl;
		sort(uni+1, uni+usiz+1);
		usiz=unique(uni+1, uni+usiz+1)-uni-1;
		for (int i=1; i<=n; ++i) s[rk[i]=lower_bound(uni+1, uni+usiz+1, pre[i])-uni].pb(i);
		for (int i=1; i<=usiz; ++i) {
			// cout<<"i: "; cout<<s[i][0]<<' ';
			ll t=s[i].size(); cnt-=t*(t-1)/2;
			for (int j=1; j<s[i].size(); ++j) {
				// cout<<s[i][j]<<' ';
				// cout<<s[i][j-1]+1<<' '<<s[i][j]<<' '<<1ll*j*((ll)s[i].size()-j)%mod<<endl;
				query(s[i][j-1], pre[s[i][j-1]], -1ll*j*((ll)s[i].size()-j)%mod);
				query(s[i][j], pre[s[i][j-1]], 1ll*j*((ll)s[i].size()-j)%mod);
			} //cout<<endl;
		}
		for (int i=1; i<=n; ++i) {
			// cout<<"i: "<<i<<endl;
			add(rk[i], pre[i]);
			for (auto it:q[i]) {
				// cout<<"("<<it.l<<','<<it.r<<") ";
				if (it.l>it.r) continue;
				it.l=lower_bound(uni+1, uni+usiz+1, it.l)-uni, it.r=upper_bound(uni+1, uni+usiz+1, it.r)-uni-1;
				if (it.l>it.r) continue;
				pair<ll, ll> t=query(it.l, it.r);
				ans=(ans+t.fir*it.k+t.sec*it.c)%mod;
			} //cout<<endl;
		}
		// cout<<ans<<endl;
		cout<<((ans-cnt)%mod+mod)%mod<<endl;
	}
}

signed main()
{
	// All input and output goes through these fixed file names.
	freopen("win.in", "r", stdin);
	freopen("win.out", "w", stdout);

	// Input format: n, then k, then the n elements.
	n = read();
	k = read();
	for (int idx = 1; idx <= n; ++idx)
		a[idx] = read();

	// Brute-force alternatives for small inputs, kept for local cross-checking:
	// if (n<=10) force::solve();
	// else task1::solve();
	task::solve();

	return 0;
}
posted @ 2022-03-12 21:26  Administrator-09  阅读(1)  评论(0编辑  收藏  举报