CF1402A 漂亮的篱笆

1 CF1402A 漂亮的篱笆

2 题目描述

大家都知道,\(Balázs\) 有全城最漂亮的篱笆。它是由地上的 \(𝑁\) 个漂亮的长方形紧挨在一起组成的。第 \(i\) 段的高度是 \(ℎ_𝑖\),宽度是 \(𝑤_𝑖\) (长宽都是整数)。我们要在这个漂亮的篱笆上找出漂亮的长方形。漂亮的长方形定义如下:

  • 长方形的侧面要么是水平的,要么是垂直的,长度是整数
  • 长方形与地面之间的距离是整数
  • 长方形与左侧第一段之间的距离是整数
  • 长方形完全位于区块内部

请问漂亮长方形的个数是多少? 这个数字可能很大,所以结果需要对 \(10^9 + 7\) 取模。

3 题解

这里把输入的 \(n\) 个长方形叫做区块。

我们发现:长方形分两种类型,一种是完全在一个区块内的,一种是跨越至少两个区块的。对于在同一区块内的长方形,我们可以通过两个在区块底面的点与两个在区块左面的点来确定。这两个底面的点之间的线段就是长方形的底,两个左面的点之间的线段就是长方形的高,由此,对于一个长为 \(n\),宽为 \(m\) 的区块,其内部的长方形的个数就是 \(C_{n+1}^2 + C_{m+1}^2\)。这里我们定义函数 \(f(x) = C_{x+1} ^ 2\)

对于跨区块的长方形,我们可以枚举出它的起点区块和终点区块以及其中间经过的最矮的区块的高度。即:

\[\sum _{1 \le l < r \le n} w_l * w_r * f(min_{i=l}^{r} h_i) \]

但是计算这个式子需要 \(O(n^2)\) 的时间复杂度,于是我们考虑优化。

我们可以考虑以每一个区块 \(i\) 的高度 \(h_i\) 为最小高度时左边和右边最多能达到的区块 \(l_i\)\(r_i\)(即如果左边或右边再延伸就会导致 \(h_i\) 不是最小高度)。这时的答案为:

\[\sum_{i = 1}^n (\sum_{j = l_i} ^ {i-1} w_j * \sum_{j = i+1}^{r_i} w_j * f(h_i) + w_i * \sum_{j = i+1}^{r_i} w_j * f(h_i) + \sum_{j = l_i} ^ {i-1} w_j * w_i * f(h_i)) \]

而求出左边和右边最多能达到的区块编号,我们可以利用单调栈进行加速。

4 代码(空格警告):

#include <iostream>
#include <queue>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 2e6 + 10;
const int mod = 1e9+7;
#define int long long
int T, n, m, u, v, tot, root, ans, cnt;
int head[N], ver[N], last[N], size[N];
int p[N], w[N];
priority_queue<int> q;
bool cmp(int x, int y)
{
    return x > y;
}
void add(int x, int y)
{
    ver[++tot] = y;
    last[tot] = head[x];
    head[x] = tot;
}
void dfs(int x, int fa)
{
    size[x] = 1;
    for (int i = head[x]; i; i = last[i])
    {
        int y = ver[i];
        if (y == fa) continue;
        dfs(y, x);
        size[x] += size[y];
    }
}
void dfs1(int x, int fa)
{
    for (int i = head[x]; i; i = last[i])
    {
        int y = ver[i];
        if (y == fa) continue;
        q.push(size[y] * (n - size[y]));
        dfs1(y, x);
    }
}
signed main()
{
    cin >> T;
    while (T--)
    {
        ans = 0;
        cnt = 0;
        tot = 0;
        memset(head, 0, sizeof(head));
        memset(ver, 0, sizeof(ver));
        memset(last, 0, sizeof(last));
        cin >> n;
        for (int i = 1; i < n; i++)
        {
            cin >> u >> v;
            add(u, v);
            add(v, u);
        }
        cin >> m;
        for (int i = 1; i <= m; i++) cin >> p[i];
        sort(p+1, p+m+1, cmp);
        root = 1;
        dfs(root, 0);
        dfs1(root, 0);
        if (m < n-1)
        {
            for (int i = 1; i <= m; i++) w[i] = p[i];
            for (int i = m+1; i < n; i++) w[i] = 1;
        }
        if (m > n-1)
        {
            w[1] = 1;
            for (int i = 1; i <= m-n+2; i++) w[1] = (w[1] * p[i]) % mod;
            for (int i = 2; i <= n-1; i++) w[i] = p[m-n+i+1];
        }
        if (m == n-1) for (int i = 1; i < n; i++) w[i] = p[i];
        while (q.size())
        {
            ans = (ans + (q.top() * w[++cnt]) % mod) % mod;
            q.pop();
        }
        cout << ans << endl;
    }
    return 0;
}

欢迎关注我的公众号:智子笔记

posted @ 2021-02-05 15:40  David24  阅读(74)  评论(0编辑  收藏  举报