星空 (Easy version & Hard Version) 题解
星空 (Easy version & Hard Version) 题解
不知道简单版有没有单独的做法,反正我不会
很明显如果 \(a\) 中有大于 \(x\) 的数直接无解,输出 \(0\)。
发现每个 \(a_i\) 都是 \(2\) 的整数次幂,这告诉我们每个 \(a_i\) 在二进制表示下只会有一位上是 \(1\),那么,相邻的两个数相加,最多就是进一个位。
然后我们来考虑 \(x\)。假如 \(x\) 的最高位 \(1\) 和次高位 \(1\) 分别在 \(i\) 位和 \(j\) 位上。由于没有大于 \(x\) 的数,所以现在 \(a\) 中最大的数也不会超过 \(2^i\)。我们来考虑这些数怎么放是合法的:
- 对于等于 \(2^i\) 的数,必须分开;
- 其他数彼此可以相邻,因为即使是 \(2^{i-1} + 2^{i-1} = 2^i\) 也小于等于 \(x\);
- 然后来考虑 \(2^i\) 和其他数的关系,发现只需要 \(2^i\) 与其他数中大于 \(2^j\) 的数不相邻即可。
这样就可以利用插板法来解决问题了。设等于 \(2^i\) 的数有 \(s_1\) 个,小于 \(2^i\),大于 \(2^j\) 的数有 \(s^2\) 个,剩下的数有 \(s_3\) 个。以下分别简称位第一、二、三类数。
首先这些数内部可以随便排列,所以先有 \(s_1 !s_2 ! s_3 !\);
然后考虑把第一类数向第三类数里插入,可以插入的位置有 \(s_3+1\) 个,所以方案数为 \(s_3+1 \choose s_1\);
最后考虑把第二类数往序列里插入。因为第二类数彼此可以相邻,而不可以与第一类数相邻,所以只需要把第一类数占据的位置去掉,考虑往剩下的位置插入。这时问题就变成了向 \(s_3 + 1 - s_1\) 个有编号的桶中放入 \(s_2\) 个无编号的球,桶可以空,求方案数。插板法经典问题,方案为 \(s_2 + s_3 - s_1 \choose s_3 - s_1\)。
所以最后的答案就是 \(s_1 !s_2 ! s_3 ! {s_3+1 \choose s_1} {s_2 + s_3 - s_1 \choose s_3 - s_1}\)。
注意求组合数的时候,即使 \(n\) 和 \(m\) 都小于 \(0\),只要 \(n = m\),结果也要是 \(1\)(因为这个吃了一发罚时 xwx)。
代码:
#include<bits/stdc++.h>
namespace IO2{
int read(){
int x = 0, f = 1; char ch = getchar();
while(ch<'0' || ch>'9') {if(ch == '-') f = -1; ch = getchar();}
while(ch>='0'&&ch<='9') {x = x * 10 + ch - 48, ch = getchar();}
return x * f;
}
}
using IO2::read;
#define ll long long
using namespace std;
const int N = 1e5+100;
int fac[N], inv[N];
ll n, X;
ll a[N];
const int mod = 1e9+7;
int C(int tn, int tm) {
if(tn == tm) return 1;
if(tn < 0 || tm < 0) return 0;
if(tn < tm) return 0;
return 1ll*fac[tn] * inv[tm]%mod * inv[tn-tm]%mod;
}
int fpow(int a, int b) {
a%=mod;
int ret = 1;
while(b) {
if(b & 1) {
ret = 1ll*ret*a%mod;
}
b>>=1;
a = 1ll*a*a%mod;
}
return ret;
}
int fir, sec;
int hi[N];
int main() {
scanf("%lld%lld", &n, &X);
for(int i = 1; i<=n; ++i) {
scanf("%lld", &a[i]);
for(int j = 63; j>=1; --j) {
if((a[i] >> (j-1)) & 1) {
hi[i] = j;
break;
}
}
}
fac[0] = 1;
for(int i = 1; i<=n + 2; ++i) {
fac[i] = 1ll*fac[i-1] * i%mod;
}
inv[n + 2] = fpow(fac[n + 2], mod-2);
for(int i = n+1; i>=0; --i) {
inv[i] = 1ll*inv[i+1]*(i+1)%mod;
}
for(int i = 1; i<=n; ++i) {
if(a[i] > X) {
puts("0");
return 0;
}
}
for(int i = 63; i>=1; --i) {
if((X >> (i-1)) & 1) {
if(!fir) {
fir = i;
} else if(!sec) {
sec = i;
}
}
}
int cnt1 = 0, cnt2 = 0, cnt3 = 0;
for(int i = 1; i<=n; ++i){
if(hi[i] == fir) {
++cnt1;
} else if(hi[i] > sec) {
++cnt2;
} else ++cnt3;
}
int ans = 1ll*fac[cnt1] * fac[cnt2]%mod * fac[cnt3]%mod * C(cnt3+1, cnt1) %mod * C(cnt2 + cnt3 - cnt1, cnt3 - cnt1)%mod;
printf("%d\n", ans);
return 0;
}