HDU 6052 To my boyfriend(容斥+单调栈)
题意:对于一个n*m的方格,每个格子中都包含一种颜色,求出任意一个矩形包含不同颜色的期望。
思路:
啊啊啊啊啊,补了两天,总算A了这道题了,简直石乐志,前面的容斥还比较好写,后面的那个>13那个最开始思路错了,然后
竟然只有一组样例没有过???? 然后以为是哪里写挂爆long long了。后来想了好久,明明思路完全就是错的! 最开始想的
是直接找那个值的外围的就好了, 忽略了里面的,然后其实问题是转化成在01矩阵中找全1矩阵的个数,本来兴冲冲的写了一发,
发现和正方形DP不是一个东西。。。。 感觉和求最大1矩阵类似,然后看解法,发现网上的都是n^3的?不过好像这两个本来就
不是一个东西,标程上面的写法看不懂= = ,百度也一堆什么单调栈,暴力排序什么的,感觉和题解的不一样,然后就石乐志。
今天又来老老实实的模拟他的过程,这TM还不是维护一个单调栈????石乐志 石乐志。
感觉单调栈真的厉害呀,栈维护的是一个递增的高度,然后通过b数组来维护当前高度的宽度,然后就可以求得以(i,j)结尾的
全1子矩阵的个数。
官方题解:
每种数可以单独算出其期望然后相加 对于数量小于13的数,可以用容斥的方式来做 对于大于13的数,可以求出全不含的矩阵个数,然后用全部矩阵减去这部分 复杂度 o(n^4/13*T)
代码:
/** @xigua */ #include <stdio.h> #include <cmath> #include <iostream> #include <algorithm> #include <vector> #include <stack> #include <cstring> #include <queue> #include <set> #include <string> #include <map> #include <climits> #define PI acos(-1) using namespace std; typedef long long ll; typedef double db; const int maxn = 1e4 + 5; const ll mod = 1ll<<32; const int INF = 1e8 + 5; const ll inf = 1e15 + 5; const db eps = 1e-6; int mapp[105][105]; struct pos { ll x, y; }; ll get(ll x1, ll x2, ll y1, ll y2) { ll x = x2 - x1 + 1, y = y2 - y1 + 1; return x * (x + 1) / 2 * y * (y + 1) / 2; } ll gao(int val, int n, int m) { ll dp[105][105] = {0}; ll ans = 0; for (int i = 1; i <= n; i++) { ll sum = 0, a[105], b[105], cnt = 0; for (int j = 1; j <= m; j++) { if (mapp[i][j] == val) dp[i][j] = 0; else dp[i][j] = dp[i-1][j] + 1; int tmp = 1; while (cnt && a[cnt] > dp[i][j]) { sum -= a[cnt] * b[cnt]; tmp += b[cnt]; // a维护高度,b维护宽度 cnt--; } cnt++; a[cnt] = dp[i][j]; b[cnt] = tmp; sum += a[cnt] * b[cnt]; ans += sum; } } return ans; } void solve() { ll n, m; cin >> n >> m; vector<pos> g[maxn]; for (int i = 1; i <= n; i++) { for (int j = 1; j <= m; j++) { scanf("%d", mapp[i] + j); g[mapp[i][j]].push_back((pos){i, j}); } } ll tot = 0, all = n * (n + 1) / 2 * m * (m + 1) / 2; //所有矩阵的总数 for (int i = 0; i < n * m; i++) { if (g[i].size() <= 13) { for (int st = 1; st < (1<<g[i].size()); st++) { ll xl = n + 1, xr = 0, yl = m + 1, yr = 0; int num = 0; for (int j = 0; j < g[i].size(); j++) { if ((1<<j) & st) { pos tmp = g[i][j]; num++; xl = min(xl, tmp.x); xr = max(xr, tmp.x); yl = min(yl, tmp.y); yr = max(yr, tmp.y); } } //容斥 if (num & 1) tot += (ll) xl * yl * (n - xr + 1) * (m - yr + 1); else tot -= (ll) xl * yl * (n - xr + 1) * (m - yr + 1); } } else { tot += all - gao(i, n, m); } } printf("%.9f\n", (db)tot / (db)(all)); } int main() { int t = 1, cas = 1; //freopen("in.txt", "r", stdin); // freopen("out.txt", "w", stdout); scanf("%d", &t); while(t--) { // printf("Case %d: ", cas++); solve(); } return 0; }