LOJ #2462. 「2018 集训队互测 Day 1」完美的集合

题目链接

LOJ #2462. 「2018 集训队互测 Day 1」完美的集合

题目大意

有一棵 \(N\) 个点的带权树,树上每个节点有重量 \(w_i\) 和价值 \(v_i\) ,在满足节点重量之和 \(\leq M\) 的集合中,称那些价值之和最大的集合为完美集合。

现在要从所有完美集合中选出 \(K\) 个,要求这 \(K\) 个集合的并集中存在一个点 \(x\),满足对于这些集合中任意一个点 \(y\),都有 \(dist(x,y)\times v_y\leq Max\)。求满足条件的方案数,答案对 \(11920928955078125\) 取模。

\(N\leq 60\)\(M\leq 10000\)\(K,w_i,v_i\leq 10^9\)\(Max\leq 10^{18}\),树边权 \(\leq 10^4\)

Part 1

首先考虑求出完美集合的价值,直接做树上背包的话是 \(O(nm^2)\) 的,瓶颈在于需要合并两个区域的重量之和,要想得到较优的时间复杂度,可以设计一个状态,使得每次只需要决策单点的重量。于是我们先枚举一个点作为根,强制连通块包含该点,然后在树的 \(dfs\) 序上做 \(dp\),设 \(f_{i,j}\) 表示当前考虑到了 \(dfs\) 序上的第 \(i\) 个位置,已选取点重量总和为 \(j\) 时的最大价值和,记 \(x\) 为第 \(i\) 位对应的节点,则有:

  • \(f_{i+siz_x,j}\leftarrow f_{i,j}\)
  • \(f_{i+1,j+w_x}\leftarrow f_{i,j}+v_j\)

若当前点不选,则子树可不能选,从而直接跳到 \(i+siz_x\) 的位置继续 \(dp\),否则往下决策。

时间复杂度 \(O(N^2M)\)

Part 2

对于 \(K\) 个已经选好的完美集合,可以发现满足条件的点 \(x\) 组成了一个连通块(或空集),考虑容斥。对于每一个点,求出当前点可以作为 \(x\) 的选取 \(K\) 个集合方案数,对于每条边,求出两端点同时可以作为 \(x\) 的方案数,可以发现,每个选取方案都会被合法 \(x\) 的连通块中每个点和边都算一次,而在树上,点的数量 \(=\) 边的数量 \(+1\),从而我们拿点的方案数减去边的方案数,这样每个方案就恰好被计算一次了。

\(n\) 为当前限制下可以选取的完美集合数量,显然方案数即为 \(\binom{n}{K}\),于是考虑计算 \(n\) 。这一点和之前的做法是极其相似的,以钦定的点为根(边就任取一端点),\(f_{i,j}\) 改为存二元组 \((num,val)\),表示最大价值和为 \(val\),达到 \(val\) 的集合数量为 \(num\),然后转移和前面基本相同,不过若 \(dp\) 当前点不满足 \(dist(x,y)\times v_y\leq Max\),就强制不选,考虑边时另外一端点强制要选即可。

Part 3

现在得到了 \(n\),我们要求出 \(\binom{n}{K}\)\(11920928955078125\) 取模的结果,可以发现这个质数 \(=5^{23}\),而 \(n\)\(2^N=2^{60}\) 级别的,\(K\)\(10^9\) 级别的,显然无法通过一般的方式计算。

\(\binom{n}{K}\) 换成 \(n!/((n-K)!K!)\),注意到模数一个小质数的大幂次,于是分开计算 \(n!\),并将 \(n!\) 拆解为 \(5\) 的指数和非 \(5\) 的倍数部分,非倍数部分存在逆元,为 \(x^{\phi(5^{23})}\),指数部分直接相减(结果必然为正)然后乘到答案中即可。

\[n!=(\prod_{i=1}^n[i\bot5]i)\cdot(\lfloor\frac{n}{5}\rfloor!\cdot5^{\lfloor\frac{n}{5}\rfloor}) \]

\(\lfloor\frac{n}{5}\rfloor!\) 可以递归求解,考虑求解 \(\prod_{i=1}^n[i\bot5]i\)

设多项式 \(f_n(x)=\prod_{i=1}^n[i\bot 5](x+i)\),我们要求的是 \(f_n(0)\) 。注意到 \(f_{10k}(x)=f_{5k}(x)\cdot f_{5k}(x+5k)\),如果这不明显,让 \(k\)\(5\) 的倍数,就有 \(f_{2k}(x)=f_{k}(x)\cdot f_k(x+k)\),这是一个倍增的关系,容易联想到像快速幂那样,使用二进制拆解求出任意一个 \(5\) 的倍数 \(n\)\(f_n(x)\)

不仅如此,由于我们要求的是 \(f_n(0)\),当 \(x=0\) 时,\(f_k(x+k)\) 的非常数项全是 \(5\) 的倍数,\(x^{23}\) 往后在模意义下便全部为 \(0\),所以计算时只保留 \(f_k(x+k)\) 的前 \(23\) 项便可。 而且 \(f_k(x+k)=\sum a_i(x+k)^i\),我们要求的 \(f_n(0)\) 就是常数项,\(i\geq 23\)\((x+k)^i\) 的常数项变为 \(0\),所以原来的函数 \(f(x)\) 亦可以只保留前 \(23\) 位,这样我们在做多项式乘法时,暴力相乘就是很快的。

具体来说,我们先预处理出 \(f_{5\times 2^i}(x)\) 这些多项式,计算 \(f_n(x)\) 时若 \(n\) 不是 \(5\) 的倍数,就直接算 \(f_{5\lfloor\frac{n}{5}\rfloor}(x)\),然后剩余部分直接乘到算出的 \(f(0)\) 里,然后我们对 \(\frac{n}{5}\) 二进制分解,类似快速幂的步骤,先有一个初始多项式 \(f(x)=1\),遇到 \(\frac{n}{5}\) 中为 \(1\) 的位时将 \(f(x)\) 更新为 \(f_{5\times 2^i}(x)\cdot f(x+5\times 2^i)\),最终 \(f(x)\) 的常数项即为所求答案。

\(P=5^{23}\),预处理是 \(O(\log n\log^2 P)=O(N\log^2P)\) 的,计算一次 \(f_n(0)\) 也是 \(O(N\log^2P)\) 的,求解 \(\lfloor\frac{n}{5}\rfloor!\) 要递归 \(O(\log n)=O(N)\) 层,所以这一部分复杂度是 \(O(N^2\log^2P)\) 的。

 

综上,总时间复杂度为 \(O(N^2M+N^3\log^2P)\)\(1.5e8\) 左右,实际上要快一些。

总结

一道题用到了 \(3\) 个比较重要的技巧:

  • 对于树上连通块的 \(dp\) 可以使用对 \(dfs\)\(dp\) 的方式进行优化。
  • 容斥问题中,若一个方案被一个树形连通块所有位置都计算到,可以使用点边容斥。
  • 计算较大的二项式之类的系数时,可以考虑从多项式和倍增的角度去分析。

Code

LOJ 上的 Hack 数据是完美集合价值为 \(0\) 的情形,此时求完美集合数量时要去掉空集的情况。

#include<iostream>
#include<cstring>
#include<vector>
#define mem(a,b) memset(a, b, sizeof(a))
#define rep(i,a,b) for(int i = (a); i <= (b); i++)
#define per(i,b,a) for(int i = (b); i >= (a); i--)
#define N 62
#define V 10100
#define P 23
#define ll long long
#define lll __int128
#define mod 11920928955078125 // 5^23
#define phi 9536743164062500  // mod - mod/5
#define PLL pair<lll, lll>
#define fr first
#define sc second
using namespace std;

int n, M, K; ll Max;
int w[N], v[N];

int head[N], to[2*N], nxt[2*N];
int siz[N], p[N], dis[N][N];
int cnt, num;

ll Max_V;
PLL f[N][V];
lll C[P+2][P+2];
vector<lll> Pow[N+5];

void init(){ mem(head, -1), cnt = -1; }
void add_e(int a, int b, bool id){
    nxt[++cnt] = head[a], head[a] = cnt, to[cnt] = b;
    if(id) add_e(b, a, 0);
}

void dfs(int x, int fa){
    siz[x] = 1, p[++num] = x;
    for(int i = head[x]; ~i; i = nxt[i]){
        if(to[i] == fa) continue;
        dfs(to[i], x), siz[x] += siz[to[i]];
    }
}

void Trans(PLL &a, PLL b){
    if(a.sc < b.sc) a.fr = b.fr;
    else if(a.sc == b.sc) a.fr += b.fr;
    a.sc = max(a.sc, b.sc);
}
void dp(int root, int ex){
    mem(f, 0);
    f[1][0] = {1, 0};
    rep(i,1,n) rep(j,0,M) if(f[i][j].fr){
        int x = p[i];
        if(x != ex) Trans(f[i+siz[x]][j], f[i][j]);
        if(j+w[x] <= M && (ll)dis[root][x] * v[x] <= Max && (ll)dis[ex][x] * v[x] <= Max)
            Trans(f[i+1][j+w[x]], {f[i][j].fr, f[i][j].sc + v[x]});
    }
}

vector<lll> mul(vector<lll> &A, vector<lll> B){
    if(A.empty() || B.empty()) return {};
    int n = A.size(), m = B.size();
    vector<lll> ret(min(n + m - 1, 23));
    rep(i,0,n-1) rep(j,0,m-1) if(i+j < 23)
        (ret[i+j] += A[i] * B[j]) %= mod;
    return ret;
}

vector<lll> shift(vector<lll> &A, lll k){
    int n = A.size();
    vector<lll> ret(n);
    rep(i,0,n-1){
        lll pow = A[i];
        rep(j,0,i) (ret[i-j] += pow * C[i][j]) %= mod, (pow *= k) %= mod;
    }
    return ret;
}
 
lll qpow(lll a, ll b){
    lll ret = 1;
    for(; b; b >>= 1){ if(b&1) (ret *= a) %= mod; (a *= a) %= mod; }
    return ret;
}

void prework(){
    C[0][0] = 1;
    rep(i,1,P){
        C[i][0] = 1;
        rep(j,1,i) C[i][j] = (C[i-1][j] + C[i-1][j-1]) % mod;
    }
    Pow[0] = {1};
    rep(i,1,4) Pow[0] = mul(Pow[0], {i, 1});
    rep(i,1,N) Pow[i] = mul(Pow[i-1], shift(Pow[i-1], 5ll<<(i-1)));
}

PLL fact(lll n){
    if(n == 0) return {1, 0};
    if(n%5){
        PLL p = fact(n-1);
        return {p.fr * n % mod, p.sc};
    }

    lll exp = n/5, prod = 1;
    PLL p = fact(n/5);
    (prod *= p.fr) %= mod, (exp += p.sc) %= mod;

    vector<lll> poly = {1};
    per(i,N,0) if((n/5)>>i&1) poly = mul(Pow[i], shift(poly, 5ll<<i));
    (prod *= poly[0]) %= mod;
    return {prod, exp};
}

lll cal(lll n, lll m){
    if(m < 0 || n < m) return 0;
    PLL p = fact(n);
    lll val = p.fr, exp = p.sc;
    p = fact(n-m), (val *= qpow(p.fr, phi-1)) %= mod, exp -= p.sc;
    p = fact(m), (val *= qpow(p.fr, phi-1)) %= mod, exp -= p.sc;
    return (val * qpow(5, exp)) % mod;
}

int main(){
    cin>>n>>M>>K>>Max;
    rep(i,1,n) cin>>w[i];
    rep(i,1,n) cin>>v[i];

    init(), mem(dis, 0x3f);
    int a, b, c;
    rep(i,1,n-1){
        cin>>a>>b>>c, add_e(a, b, 1);
        dis[a][b] = c, dis[b][a] = c;
    }
    rep(i,1,n) dis[i][i] = 0;
    rep(k,1,n) rep(i,1,n) rep(j,1,n) 
        dis[i][j] = min(dis[i][j], dis[i][k] + dis[k][j]);
    rep(i,1,n) dis[i][0] = dis[0][i] = 0;

    rep(i,1,n){
        num = 0, dfs(i, 0);
        dp(0, 0);
        rep(j,0,M) Max_V = max(Max_V, (ll)f[n+1][j].sc);
    }

    prework();
    lll ans = 0;
    rep(i,1,n+(cnt+1)/2){
        num = 0;
        if(i <= n) dfs(i, 0), dp(i, 0);
        else dfs(to[(i-n-1)*2], 0), dp(to[(i-n-1)*2], to[(i-n-1)*2+1]);
        lll num = 0;
        rep(j,0,M) if(f[n+1][j].sc == Max_V) num += f[n+1][j].fr - (Max_V == 0 && j == 0);
        (ans += i <= n ? cal(num, K) : mod - cal(num, K)) %= mod;
    }
    cout<< (ll)ans <<endl;
    return 0;
}
posted @ 2022-01-02 21:25  Neal_lee  阅读(162)  评论(0编辑  收藏  举报