ZROJ87 树状数组 - 数位dp -
题目链接:http://zhengruioi.com/problem/87
题解:
首先考虑 \(f(l,r)\) 代表什么,官方题解很详细了就不再赘述了:
因此我们要求的就是对于所有 \(l,r \rightarrow l-1\) 和 \(r\) 的最长公共前缀的 1 的个数,记为 S
容易发现答案就是所有 \((l-1,r), 1\leq l \leq r \leq n\) 的二进制下 1 的个数和,再减去 2S
形式化的:$$ans=\sum_{l=1}^n \sum_{r=l}^n (F(l-1)+F(r))-2S$$,\(F\) 代表 popcount
先考虑一下 \(S\) 怎么求,可以用数位 dp 解决
设 \(f(i,0/1,0/1)\) 代表考虑到二进制下第 \(i\) 位,\(l==r\), \(r==n\) 时的最长公共前缀下 1 的个数,\(g(i,0/1,0/1)\) 同理,代表方案数
每次转移的时候就枚举 \(l,r\) 的当前位。如果之前有 \(l==r\) 并且当前位两个都是 1 的话,最长公共前缀 +1,因此当前 dp 值需要加上 \(g(i+1,1,0/1)\) (方案数)
还有一点需要注意,因为实际上要求的区间最接近也是 \([r-1, r]\),因此不存在 \([r,r]\) 的区间,因此 dfs 末尾的时候需要判断如果相等的话方案数应该设为 0
至于前面和式的求法,首先注意到只需要算 \(1..n\) 的 popcount 个数,然后再 \(\times n\) 即可。采用相同的 dp 方式(设一个 \(dp[x][0/1]\) 表示popcount数,dpp表示方案数)
// by SkyRainWind
#include <bits/stdc++.h>
#define mpr make_pair
#define debug() cerr<<"Yoshino\n"
#define pii pair<int,int>
#define pb push_back
using namespace std;
typedef long long ll;
typedef long long LL;
const int inf = 1e9, INF = 0x3f3f3f3f, mod=1e9+7;
ll n;
int bit[65],cnt;
int f[65][2][2], g[65][2][2], dp[65][2], dpp[65][2];
pii dfs0(int x,int isup){
if(x == cnt+1)return mpr(0, 1);
int &dd = dp[x][isup];
if(~dd)return mpr(dd, dpp[x][isup]);
dd = 0;
int &ng = dpp[x][isup];ng = 0;
for(int i=0;i<=(isup ? bit[x] : 1);i++){
pii now = dfs0(x+1, isup && i == bit[x]);
(dd += now.first) %= mod;
if(i == 1)(dd += now.second) %= mod;
(ng += now.second) %= mod;
}
return mpr(dd, ng);
}
pii dfs(int x,int iseq,int isup){
if(x == cnt+1){
if(!iseq)return mpr(0,1);
else return mpr(0,0);
}
int &dd = f[x][iseq][isup];
if(~dd)return mpr(dd, g[x][iseq][isup]);
int up = 1;if(isup)up = bit[x];
int &ng = g[x][iseq][isup];
dd = ng = 0;
for(int r=0;r<=up;r++)
for(int l=0;l<=(iseq?r:1);l++){
pii now = dfs(x+1, iseq && l==r, isup && r == up);
(dd += now.first) %= mod;
if(l==1&&r==1&&iseq)(dd += now.second) %= mod;
(ng += now.second) %= mod;
}
return mpr(dd, ng);
}
void solve(){
memset(f,-1,sizeof f);
memset(g,-1,sizeof g);
memset(dp,-1,sizeof dp);
memset(dpp,-1, sizeof dpp);
cin >> n;
ll t = n;
cnt = 0;
while(t){
bit[++ cnt] = t&1;
t >>= 1;
}
reverse(bit+1,bit+cnt+1);
pii res0 = dfs0(1, 1);
pii res = dfs(1, 1, 1);
cout << (1ll*res0.first*(n%mod)%mod - 2ll*res.first%mod + mod) % mod << '\n';
}
signed main(){
int te;scanf("%d",&te);
while(te--)solve();
return 0;
}