51nod 1215 数组的宽度
获取每个元素g[i]作为最大值和最小值的区间。假设分别为bl[i],br[i],sl[i],sr[i]。(b和s代表big和small,l、r是left和right)
注:func(i, k, j) 返回的是在[i,j]区间的所有子集中,包含第k个元素的子集数量。
那么对于最终的结果ans,g[i]这个元素贡献的值为g[i] * (func[bl[i],i,br[i] - func(sl[i],i,sr[i])),把每个元素的贡献都加起来即为解。
利用单调栈来求这个范围,要注意一点:
目的就是避免统计的区间的重复。
比如有一个数组{1,3, 3},第一个3的作为最大值的区间故意取[1, 2],第二个取[1, 3]才可,即左端点是<=都要往左走,但是右端点是<关系才走。
不然若都是[1,3],那么对于左边的3而言,会考虑{3,3}这个子集,对于右边的3而言,又会考虑一遍{3,3},形成重复。
#include <stdio.h> #define ll long long const int maxN=5e4+5; int N, M, K, T; ll func(int i, int k, int j) { return (k - i + 1ll) * (j - k + 1ll); } int g[maxN], bl[maxN], br[maxN], sl[maxN], sr[maxN]; int p1, p2, bstk[maxN], sstk[maxN]; int main() { #ifndef ONLINE_JUDGE freopen("data.in", "r", stdin); #endif scanf("%d", &N); for (int i = 1; i <= N; ++i) scanf("%d", &g[i]); p1 = p2 = 0; for (int i = 1; i <= N; ++i) {
// 左端点这里取<=号 while (p1 && g[bstk[p1 - 1]] <= g[i]) --p1; if (!p1) bl[i] = 1; else bl[i] = bstk[p1 - 1] + 1; bstk[p1++] = i; while (p2 && g[sstk[p2 - 1]] >= g[i]) --p2; if (!p2) sl[i] = 1; else sl[i] = sstk[p2 - 1] + 1; sstk[p2++] = i; } p1 = p2 = 0; for (int i = N; i >= 1; --i) {
// 右端点这里取<号 while (p1 && g[bstk[p1 - 1]] < g[i]) --p1; if (!p1) br[i] = N; else br[i] = bstk[p1 - 1] - 1; bstk[p1++] = i; while (p2 && g[sstk[p2 - 1]] > g[i]) --p2; if (!p2) sr[i] = N; else sr[i] = sstk[p2 - 1] - 1; sstk[p2++] = i; } /* for (int i = 1; i <= N; ++i) printf("%d %d\n%d %d\n\n", bl[i], br[i], sl[i], sr[i]); */ ll ans = 0; for (int i = 1; i <= N; ++i) { ans += g[i] * (func(bl[i], i, br[i]) - func(sl[i], i, sr[i])); } printf("%lld\n", ans); return 0; }
上面的代码利于理解,但是这里有一个更简洁,更快速的实现求取区间的方法,建议理解下面这个实现,求出来的都是开区间:
#include <bits/stdc++.h> using namespace std; #define ll long long const int maxN=5e4+5; ll N, g[maxN], bl[maxN], br[maxN], sl[maxN], sr[maxN]; int main() { #ifndef ONLINE_JUDGE freopen("data.in", "r", stdin); #endif scanf("%lld", &N); for (int i = 1; i <= N; ++i) scanf("%lld", &g[i]); for (int i = 1; i <= N; ++i) { bl[i] = sl[i] = i - 1; br[i] = sr[i] = i + 1; } for (int i = 1; i <= N; ++i) { while (bl[i] && g[bl[i]] < g[i]) bl[i] = bl[bl[i]]; while (sl[i] && g[sl[i]] > g[i]) sl[i] = sl[sl[i]]; } for (int i = N; i > 0; --i) { while (br[i] <= N && g[br[i]] <= g[i]) br[i] = br[br[i]]; while (sr[i] <= N && g[sr[i]] >= g[i]) sr[i] = sr[sr[i]]; } ll ans = 0; for (ll i = 1; i <= N; ++i) ans += g[i] * ((i - bl[i]) * (br[i] - i) - (i - sl[i]) * (sr[i] - i)); printf("%lld\n", ans); return 0; }