[hihoCoder#1596]Beautiful Sequence
[hihoCoder#1596]Beautiful Sequence
试题描述
对于一个正整数列 \(a[1], ... , a[n] (n \ge 3)\),如果对于所有 \(2 \le i \le n - 1\),都有 \(a[i-1] + a[i+1] \ge 2 \times a[i]\),则称这个数列是美丽的。
现在有一个正整数列 \(b[1], ..., b[n]\),请计算:将 \(b\) 数列均匀随机打乱之后,得到的数列是美丽的概率 \(P\)。
你只需要输出 \((P \times (n!))\ mod\ 1000000007\) 即可。(显然 \(P \times (n!)\) 一定是个整数)
输入
第一行一个整数 \(n\)。\((3 \le n \le 60)\)
接下来 \(n\) 行,每行一个整数 \(b[i]\)。\((1 \le b[i] \le 1000000000)\)
输出
输出 \((P \times (n!))\ mod\ 1000000007\)。
输入示例
4
1
2
1
3
输出示例
8
数据规模及约定
见“输入”
题解
美丽的序列一定是一个下凸壳,考虑从两边往中间加,记 \(f(l_1, l_2, r_1, r_2)\) 表示靠内侧的左侧两个数分别为 \(a[l_1]\)、\(a[l_2]\),右侧两个数分别为 \(b[r_1]\)、\(b[r_2]\) 的方案数。
我的代码当时瞎写的,状态数看上去 \(O(n^6)\) 但记忆化搜索加剪枝跑得比 \(O(n^4)\) 还快。。。
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cctype>
#include <algorithm>
#include <map>
using namespace std;
int read() {
int x = 0, f = 1; char c = getchar();
while(!isdigit(c)){ if(c == '-') f = -1; c = getchar(); }
while(isdigit(c)){ x = x * 10 + c - '0'; c = getchar(); }
return x * f;
}
#define maxn 61
#define MOD 1000000007
#define LL long long
struct Sta {
int l, r, _1, _2, _3, _4;
Sta() {}
Sta(int _l, int _r, int __1, int __2, int __3, int __4): l(_l), r(_r), _1(__1), _2(__2), _3(__3), _4(__4) {}
bool operator < (const Sta& t) const {
if(l != t.l) return l < t.l;
if(r != t.r) return r < t.r;
if(_1 != t._1) return _1 < t._1;
if(_2 != t._2) return _2 < t._2;
if(_3 != t._3) return _3 < t._3;
if(_4 != t._4) return _4 < t._4;
return 0;
}
} ;
map <Sta, int> f;
int n, num[maxn], fac[maxn];
int dp(Sta st) {
int l = st.l, r = st.r, _1 = st._1, _2 = st._2, _3 = st._3, _4 = st._4;
int i = l + r + 1;
if(num[i] == num[n]) {
if(i - _2 && i - _3 && num[i-_2] + num[i-_3] < (num[i] << 1)) return 0;
if(i - _1 && num[i-_1] + num[i] < (num[i-_2] << 1)) return 0;
if(i - _4 && num[i-_4] + num[i] < (num[i-_3] << 1)) return 0;
return 1;
}
if(f.count(st)) return f[st];
int& ans = f[st];
ans = 0;
int lim;
if(!(i - _1)) lim = 0;
else lim = (num[i-_2] << 1) - num[i-_1];
if(num[i] >= lim && (i == 1 || num[i] != num[i-1])) {
ans += dp(Sta(l + 1, r, _2 + 1, 1, _3 + 1, _4 + 1));
if(ans >= MOD) ans -= MOD;
}
if(!(i - _4)) lim = 0;
else lim = (num[i-_3] << 1) - num[i-_4];
if(num[i] >= lim && num[i] != num[i+1]) {
ans += dp(Sta(l, r + 1, _1 + 1, _2 + 1, 1, _3 + 1));
if(ans >= MOD) ans -= MOD;
}
// printf("f(%d, %d, [%d], %d | %d | %d | %d) = %d\n", l, r, i, num[i-_1], num[i-_2], num[i-_3], num[i-_4], ans);
return ans;
}
int main() {
n = read();
fac[0] = 1;
for(int i = 1; i <= n; i++) num[i] = read(), fac[i] = (LL)fac[i-1] * i % MOD;
sort(num + 1, num + n + 1);
for(int i = 1; i <= (n >> 1); i++) swap(num[i], num[n-i+1]);
// for(int i = 1; i <= n; i++) printf("%d%c", num[i], i < n ? ' ' : '\n');
for(int i = 1; i < n - 1; i++)
if(num[i] == num[i+1] && num[i+1] == num[i+2] && num[i] != num[n]) return puts("0"), 0;
LL ans = dp(Sta(0, 0, 1, 1, 1, 1));
// printf("%lld\n", ans);
num[n+1] = -1; int tmp = 1;
for(int i = 2; i <= n + 1; i++)
if(num[i] != num[i-1]) (ans *= fac[tmp]) %= MOD, tmp = 1;
else tmp++;
printf("%lld\n", ans);
return 0;
}