bzoj4584 [Apio2016]赛艇
传送门:http://www.lydsy.com/JudgeOnline/problem.php?id=4584
【题解】
令f[i,j,k]表示前i个学校,赛艇最远放在j区间,且j这个区间放了k个赛艇。
那么显然区间可以离散(这里用左闭右开方便),那么就是一个大概O(n^3)的做法了。
好像就行了?据说还要常数优化qwq
丢个方程
令s[j]表示Σf[i-1,j',*],其中j'<=j。
令len[j]表示第j个区间的实际长度。
f[i,j,1] = f[i-1,j,k] + s[j]*len[j]
(前i-1个就放到这么多了,前i-1个还没放到这,现在放在这里,有len[j]种方案,因为放哪都行)
f[i,j,k] = f[i-1,j,k] + f[i-1,j,k-1]*(len[j]-k+1)/k
(前i-1个就放到这么多了,前i-1个这个区间还差一个,把他塞进来)
后面那坨是怎么回事呢,因为原来是C(len[j],k-1),现在变成C(len[j],k),改变的量。
# include <vector> # include <stdio.h> # include <string.h> # include <algorithm> // # include <bits/stdc++.h> using namespace std; typedef long long ll; typedef long double ld; typedef unsigned long long ull; const int M = 5e2 + 10, N = 1000 + 10; const int mod = 1e9+7; # define RG register # define ST static int n, m; struct intervals { // [l, r) int l, r; intervals() {} intervals(int l, int r) : l(l), r(r) {} }p[M]; vector<int> ps; int len[N]; int f[N][M], s[N]; int cnt[M][N], inv[N]; inline int pwr(int a, int b) { int ret = 1; while(b) { if(b&1) ret = 1ll * ret * a % mod; a = 1ll * a * a % mod; b >>= 1; } return ret; } int main() { scanf("%d", &n); for (int i=1; i<=n; ++i) { scanf("%d%d", &p[i].l, &p[i].r); ++p[i].r; ps.push_back(p[i].l), ps.push_back(p[i].r); } sort(ps.begin(), ps.end()); ps.erase(unique(ps.begin(), ps.end()), ps.end()); m = ps.size(); for (int i=1; i<=n; ++i) { int L = lower_bound(ps.begin(), ps.end(), p[i].l)-ps.begin()+1; int R = lower_bound(ps.begin(), ps.end(), p[i].r)-ps.begin()+1; p[i] = intervals(L, R); } for (int i=1; i<m; ++i) len[i] = ps[i] - ps[i-1]; for (int i=0; i<=1000; ++i) inv[i] = pwr(i, mod-2); for (int i=0; i<m; ++i) s[i] = 1; for (int i=1; i<=n; ++i) { for (int j=1; j<m; ++j) cnt[i][j] = cnt[i-1][j]; for (int j=p[i].l; j<p[i].r; ++j) ++cnt[i][j]; } for (int i=1; i<=n; ++i) { for (int j=p[i].l; j<p[i].r; ++j) { for (int k=cnt[i][j]; k>=2; --k) { f[j][k] = f[j][k] + 1ll * f[j][k-1] * (len[j]-k+1) % mod * inv[k] % mod; if(f[j][k] >= mod) f[j][k] -= mod; } f[j][1] = f[j][1] + 1ll * s[j-1] * len[j] % mod; if(f[j][1] >= mod) f[j][1] -= mod; } s[0] = 1; for (int j=1; j<m; ++j) { s[j] = s[j-1]; for (int k=1; k<=cnt[i][j]; ++k) { s[j] = s[j] + f[j][k]; if(s[j] >= mod) s[j] -= mod; } } } int ans = 0; for (int i=1; i<m; ++i) for (int j=1; j<=n; ++j) { ans = ans + f[i][j]; if(ans >= mod) ans -= mod; } printf("%d\n", ans); return 0; }