[BZOJ3745][Coci2015]Norma
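试题描述
给定一个长度为 N 的正整数序列 a_1, a_2, …, a_N，求所有连续子序列的“最小值 × 最大值 × 长度”之和，即 $\sum_{i=1}^{N}\sum_{j=i}^{N} \min(a_i,\dots,a_j)\cdot\max(a_i,\dots,a_j)\cdot(j-i+1)$。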
输入
第1行,一个整数N;
第2~N+1行,每行一个整数,依次表示序列a的各项。
输出
输出答案对10^9取模后的结果。
输入示例
4
2
4
1
4
输出示例
109
数据规模及约定
N <= 500000
1 <= a_i <= 10^8
题解
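分治:对区间 [l, r] 取中点 mid,递归求出完全落在左半或右半内的子区间之和,再统计跨过 mid 的子区间。从 mid 向左枚举左端点 i,维护 a[i..mid] 的最大值 mx 与最小值 mn;右端点 j ∈ [mid+1, r] 按“最小值/最大值来自左半还是右半”分成三段:两者都来自左半、恰有一个来自右半、两者都来自右半。随着 i 左移,mn 不增、mx 不降,段界只会右移,可用两个单调指针维护;每段的贡献借助右半的前缀和(Σ(j-mid)、Σ前缀最大值、Σ前缀最小值及它们的乘积等)O(1) 求出。总复杂度 O(N log N)。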
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cctype>
#include <algorithm>

// Computes, for the sequence A[1..n], the sum over all subarrays [i, j] of
// min(A[i..j]) * max(A[i..j]) * (j - i + 1), modulo 10^9, using divide and
// conquer in O(n log n).

typedef long long LL;

const int maxn = 500010;     // array capacity; the problem guarantees N <= 500000
const int MOD = 1000000000;  // answer modulus (10^9)
const int oo = 2147483647;   // initial value for running minima (INT_MAX)

// Reads the next decimal integer from stdin; any '-' seen before the first
// digit makes the result negative. Returns 0 if EOF is reached before a digit
// appears, instead of looping forever on truncated input.
int read() {
    int x = 0, f = 1;
    int c = std::getchar();  // int, not char: keeps EOF distinct and isdigit() defined
    while (c != EOF && !std::isdigit(c)) {
        if (c == '-') f = -1;
        c = std::getchar();
    }
    while (c != EOF && std::isdigit(c)) {
        x = x * 10 + (c - '0');
        c = std::getchar();
    }
    return x * f;
}

// A[1..n] is the input sequence.
// During solve(l, r), for the right half [mid+1, r], let d_k = k - mid,
// rmx_k = max(A[mid+1..k]) and rmn_k = min(A[mid+1..k]). Index j of each array
// holds the prefix sum over k in [mid+1, j], modulo MOD (index mid = empty sum 0):
//   sl   : sum d_k               smx  : sum rmx_k         smn : sum rmn_k
//   slmx : sum d_k * rmx_k       slmn : sum d_k * rmn_k
//   sm   : sum rmx_k * rmn_k     slm  : sum d_k * rmx_k * rmn_k
// Both recursive calls finish before the parent overwrites [mid, r], so the
// arrays can be shared by every recursion level.
int n, A[maxn], sl[maxn], smx[maxn], smn[maxn], slmx[maxn], slmn[maxn], sm[maxn], slm[maxn];

// (a + b) mod MOD for a, b in [0, MOD); the sum is < 2 * 10^9 and fits in int.
static inline int addMod(int a, int b) {
    a += b;
    return a >= MOD ? a - MOD : a;
}

// (a - b) mod MOD for a, b in [0, MOD).
static inline int subMod(int a, int b) {
    a -= b;
    return a < 0 ? a + MOD : a;
}

// (a * b) mod MOD for non-negative a, b below 2^31 (the product fits in long long).
static inline int mulMod(LL a, LL b) {
    return static_cast<int>(a * b % MOD);
}

// Returns the answer restricted to subarrays lying entirely inside [l, r],
// for 1 <= l <= r <= n.
int solve(int l, int r) {
    if (l == r) return mulMod(A[l], A[l]);

    const int mid = (l + r) >> 1;
    int ans = addMod(solve(l, mid), solve(mid + 1, r));

    // Build the right-half prefix sums described above.
    sl[mid] = smx[mid] = smn[mid] = slmx[mid] = slmn[mid] = sm[mid] = slm[mid] = 0;
    int mx = 0, mn = oo;
    for (int j = mid + 1; j <= r; j++) {
        const int d = j - mid;
        mx = std::max(mx, A[j]);
        mn = std::min(mn, A[j]);
        sl[j] = addMod(sl[j - 1], d);
        smx[j] = addMod(smx[j - 1], mx);
        smn[j] = addMod(smn[j - 1], mn);
        slmx[j] = addMod(slmx[j - 1], mulMod(d, mx));
        slmn[j] = addMod(slmn[j - 1], mulMod(d, mn));
        sm[j] = addMod(sm[j - 1], mulMod(mx, mn));
        slm[j] = addMod(slm[j - 1], mulMod(mulMod(d, mx), mn));
    }

    // Sweep the left endpoint i from mid down to l, tracking mx/mn of A[i..mid].
    // minPtr = first j > mid with A[j] <= mn: from there on the right half supplies
    //          the minimum of A[i..j].
    // maxPtr = first j > mid with A[j] >= mx: from there on it supplies the maximum.
    // mn never increases and mx never decreases as i moves left, so both pointers
    // only move right (amortised O(r - l) per call).
    mx = 0;
    mn = oo;
    int minPtr = mid + 1, maxPtr = mid + 1;
    for (int i = mid; i >= l; i--) {
        mx = std::max(mx, A[i]);
        mn = std::min(mn, A[i]);
        while (minPtr <= r && A[minPtr] > mn) minPtr++;
        while (maxPtr <= r && A[maxPtr] < mx) maxPtr++;

        // Per-i coefficients; len = mid - i + 1 is the left part's length, and the
        // subarray [i, j] has length len + (j - mid).
        const int len = mid - i + 1;
        const int lenMx = mulMod(len, mx);
        const int lenMn = mulMod(len, mn);
        const int mxMn = mulMod(mx, mn);
        const int lenMxMn = mulMod(lenMx, mn);

        const int lo = std::min(minPtr, maxPtr);  // first j where the right half supplies min or max
        const int hi = std::max(minPtr, maxPtr);  // first j where it supplies both

        // j in [mid+1, lo-1]: min and max both from the left part.
        ans = addMod(ans, addMod(mulMod(lenMxMn, lo - mid - 1),
                                 mulMod(mxMn, subMod(sl[lo - 1], sl[mid]))));

        // j in [hi, r]: min and max both from the right part.
        ans = addMod(ans, addMod(mulMod(len, subMod(sm[r], sm[hi - 1])),
                                 subMod(slm[r], slm[hi - 1])));

        if (minPtr < maxPtr) {
            // j in [minPtr, maxPtr-1]: min from the right part, max from the left.
            ans = addMod(ans, addMod(mulMod(lenMx, subMod(smn[maxPtr - 1], smn[minPtr - 1])),
                                     mulMod(mx, subMod(slmn[maxPtr - 1], slmn[minPtr - 1]))));
        } else {
            // j in [maxPtr, minPtr-1] (empty when equal): max from the right, min from the left.
            ans = addMod(ans, addMod(mulMod(lenMn, subMod(smx[minPtr - 1], smx[maxPtr - 1])),
                                     mulMod(mn, subMod(slmx[minPtr - 1], slmx[maxPtr - 1]))));
        }
    }
    return ans;
}

int main() {
    n = read();
    if (n <= 0) {  // empty sequence: the sum is empty (and solve(1, 0) would never terminate)
        std::printf("0\n");
        return 0;
    }
    if (n >= maxn) {  // would overflow the fixed-size arrays
        std::fprintf(stderr, "N = %d exceeds the supported maximum %d\n", n, maxn - 1);
        return 1;
    }
    for (int i = 1; i <= n; i++) A[i] = read();
    std::printf("%d\n", solve(1, n));
    return 0;
}