ZROJ87 树状数组 - 数位dp -

题目链接:http://zhengruioi.com/problem/87

题解:
首先考虑 \(f(l,r)\) 代表什么,官方题解很详细了就不再赘述了:

因此我们要求的就是对于所有 \(l,r \rightarrow l-1\)\(r\) 的最长公共前缀的 1 的个数,记为 S
容易发现答案就是所有 \((l-1,r), 1\leq l \leq r \leq n\) 的二进制下 1 的个数和,再减去 2S
形式化的:$$ans=\sum_{l=1}^n \sum_{r=l}^n (F(l-1)+F(r))-2S$$,\(F\) 代表 popcount

先考虑一下 \(S\) 怎么求,可以用数位 dp 解决
\(f(i,0/1,0/1)\) 代表考虑到二进制下第 \(i\) 位,\(l==r\), \(r==n\) 时的最长公共前缀下 1 的个数,\(g(i,0/1,0/1)\) 同理,代表方案数
每次转移的时候就枚举 \(l,r\) 的当前位。如果之前有 \(l==r\) 并且当前位两个都是 1 的话,最长公共前缀 +1,因此当前 dp 值需要加上 \(g(i+1,1,0/1)\) (方案数)
还有一点需要注意,因为实际上要求的区间最接近也是 \([r-1, r]\),因此不存在 \([r,r]\) 的区间,因此 dfs 末尾的时候需要判断如果相等的话方案数应该设为 0
至于前面和式的求法,首先注意到只需要算 \(1..n\) 的 popcount 个数,然后再 \(\times n\) 即可。采用相同的 dp 方式(设一个 \(dp[x][0/1]\) 表示popcount数,dpp表示方案数)

// by SkyRainWind
#include <bits/stdc++.h>
#define mpr make_pair
#define debug() cerr<<"Yoshino\n"
#define pii pair<int,int>
#define pb push_back

using namespace std;

typedef long long ll;
typedef long long LL;

const int inf = 1e9, INF = 0x3f3f3f3f, mod=1e9+7;

ll n;
int bit[65],cnt;
int f[65][2][2], g[65][2][2], dp[65][2], dpp[65][2];

pii dfs0(int x,int isup){
	if(x == cnt+1)return mpr(0, 1);
	int &dd = dp[x][isup];
	if(~dd)return mpr(dd, dpp[x][isup]);
	
	dd = 0;
	int &ng = dpp[x][isup];ng = 0;
	for(int i=0;i<=(isup ? bit[x] : 1);i++){
		pii now = dfs0(x+1, isup && i == bit[x]);
		(dd += now.first) %= mod;
		if(i == 1)(dd += now.second) %= mod;
		(ng += now.second) %= mod;
	}
	return mpr(dd, ng);
}

pii dfs(int x,int iseq,int isup){
	if(x == cnt+1){
		if(!iseq)return mpr(0,1);
		else return mpr(0,0);
	}
	int &dd = f[x][iseq][isup];
	if(~dd)return mpr(dd, g[x][iseq][isup]);
	
	int up = 1;if(isup)up = bit[x];
	int &ng = g[x][iseq][isup];
	dd = ng = 0;
	for(int r=0;r<=up;r++)
		for(int l=0;l<=(iseq?r:1);l++){
			pii now = dfs(x+1, iseq && l==r, isup && r == up);
			(dd += now.first) %= mod;
			if(l==1&&r==1&&iseq)(dd += now.second) %= mod;
			(ng += now.second) %= mod;
		}
	return mpr(dd, ng);
}

void solve(){
	memset(f,-1,sizeof f);
	memset(g,-1,sizeof g);
	memset(dp,-1,sizeof dp);
	memset(dpp,-1, sizeof dpp);
	cin >> n;
	ll t = n;
	cnt = 0;
	while(t){
		bit[++ cnt] = t&1;
		t >>= 1;
	}
	reverse(bit+1,bit+cnt+1);
	pii res0 = dfs0(1, 1);
	pii res = dfs(1, 1, 1);
	cout << (1ll*res0.first*(n%mod)%mod - 2ll*res.first%mod + mod) % mod << '\n';
}

signed main(){
	int te;scanf("%d",&te);
	while(te--)solve();

	return 0;
}

posted @ 2023-02-03 22:58  SkyRainWind  阅读(41)  评论(0编辑  收藏  举报