CodeChef BB Billboards 题解(LGV 引理应用)

https://www.cnblogs.com/ruierqwq/p/young-tableau.html 看到了这题,于是去做了一下。

题意:有一条路径上有连续排列的 \(N\) 个位置,每个位置可以放或者不放广告牌。要求每连续的 \(M\) 个位置至少有\(K\) 个放了广告牌。求满足要求,且放最少的广告牌的方案数。\(N \le 10^9, K \le M \le 50\)

然而这题跟杨表关系不大。不过也还有意思。

首先构造出最小的一个方案,(使用贪心)不难发现如果按 \(M\) 为一行排成网格,最优方案一定是在每行的最后视需要放一些(也就是,除了最后一行外放 \(K\) 个;最后一行在横坐标 \(> M - \text{rem}\) 的位置放最多 \(K\) 个广告牌)(其中 \(\text{rem}\) 表示最后一行多余的几个空,即 \((m - (n \bmod m)) \bmod m\))。

接下来考虑什么样的方案是最优方案。不妨把最后一行补满,此时最后一行可能有不超过 \(K\) 个固定的必须放广告牌的位置。补满后最优方案每一行都有一定有 \(K\) 个广告牌,记总共有 \(R\) 行,\(i\) 行第 \(j\) 个位置的横坐标为 \(a_{i, j}\)。则可以发现是合法方案的条件可以转写为:

\[\begin{align*} &a_{i, j} < a_{i, j + 1} & (j < k)\\ &a_{i, j} \ge a_{i + 1, j} & (i < r) \\ &a_{r, j} = M - \text{rem} + j & (\mathop{\rm if} \text{rem} \ge K) \\ &a_{r, j} = M - K + j & (j > K - \text{rem}, \mathop{\rm if} \text{rem} < K) \end{align*} \]

这个有点像杨表,但不完全是杨表。考虑怎么计数。发现把这 \(NK\) 个数写下来,用一些平移操作(\(a_{i, j} \gets a_{i, j} - j\))把限制写为全 \(\ge\),就可以变为网格图不相交路径计数,可以使用 LGV 引理解决(我是不是把最细节的部分省略掉了 QwQ)。这部分其实不是特别干净,因为最后一条限制写成对应的端点可能不是那么显然,所以并不是非常好写下来;不过这类题型本身可以去找找 LGV 的博客看看,例如 Monotonic Matrix 这种题。

时间复杂度 \(O(m^3)\)

代码实现(PyPy3 可通过):

from typing import List


MOD = 10**9 + 7


def inv(x):
    return pow(x, MOD - 2, MOD)


def comb(n, m):
    if m > n or m < 0:
        return 0
    ans = 1
    for i in range(1, m + 1):
        ans = ans * (n - i + 1) * inv(i) % MOD
    return ans


def determinant(mat: List[List[int]]):
    n = len(mat)
    ans = 1
    for i in range(n):
        row_pivot = -1
        for j in range(i, n):
            if mat[j][i] != 0:
                row_pivot = j
        if row_pivot == -1:
            return 0
        if i != row_pivot:
            ans = -ans
            mat[i], mat[row_pivot] = mat[row_pivot], mat[i]
        ans = ans * mat[i][i] % MOD
        for j in range(i + 1, n):
            factor = -mat[j][i] * inv(mat[i][i]) % MOD
            if mat[j][i] != 0:
                for k in range(i, n):
                    mat[j][k] = (mat[j][k] + factor * mat[i][k]) % MOD
    return ans


# print(determinant([[2, -1, 0], [-1, 2, -1], [0, -1, 2]]))

for _ in range(int(input())):
    n, m, k = map(int, input().split())
    if m == k:
        print(1)
        continue
    rem = (m - n % m) % m
    last_row = [-1] * k
    if rem >= k:
        last_row = [m + i - rem for i in range(k)]
    else:
        last_row = [i for i in range(k - rem)] + [m + i - rem for i in range(rem)]
    last_row = [x - i for i, x in enumerate(last_row)]
    # print(last_row)
    max_num = m - k
    row = (n + m - 1) // m
    end_points = [(-1, -1)] * (max_num)
    pnt = 0
    for i in range(max_num):
        while pnt < k and last_row[pnt] <= i:
            pnt += 1
        end_points[i] = row - (pnt > k - min(rem, k)), pnt
    # print(end_points)
    # 为了满足不交限制,平移第 i 个起点到 (max_num - 1 - i, i)
    start_points = [(max_num - 1 - i, i) for i in range(max_num)]
    end_points = [(x + max_num - 1 - i, y + i) for i, (x, y) in enumerate(end_points)]
    mat = [[0] * (max_num) for _ in range(max_num)]
    for i in range(max_num):
        for j in range(max_num):
            x0, y0 = start_points[i]
            x1, y1 = end_points[j]
            mat[i][j] = comb(x1 - x0 + y1 - y0, y1 - y0)

    # print(start_points)
    # print(end_points)
    # print(mat)

    print(determinant(mat))

posted @ 2024-12-16 00:49  cccpchenpi  阅读(16)  评论(0编辑  收藏  举报