题解 QOJ9300【So Many Possibilities...】/ NOD2210D【nnntxdy】

problem

txdy 的 nnn 跑去打炉石,发现对面很菜,所以想要计算自己一发能干掉几个。

对方有 n 个随从,生命值分别是 a[1],a[2],...,a[n]。txdy 的 nnn 发动了一次技能,会连续攻击 m 次,每次在当前生命值为正的随从中随机选择一个,使得它的生命值减一。当一个随从的生命值降到 0 时它就死掉了。

问期望能干掉几个随从,答案模998244353输出。对于100%的数据,n≤15,m≤200,a[i]≤200,保证不会出现没有可攻击的人的情况。

solution

考虑 DP 刻画一下当前干掉人的状态:\(f(i,S)\) 表示已经攻击 \(i\) 次,已经干掉了 \(S\) 中的人,\(S\) 是集合。向 \(i+1\) 转移时,首先可以转移到 \(f(i+1,S)\),乘上概率 \(\frac1{n-|S|}\),暂时不考虑选了谁。然后假如这一步干掉了一个人 \(j\),满足 \(j\not\in S\),然后转移到 \(f(i+1,S\cup\{j\})\),乘上概率 \(\frac1{n-|S|}\),乘上 \(\binom{i-\sum_{k\in S}a_k}{a_j-1}\) 这里实际上是 \(\binom{tot-1}{a_j-1}\) 的形式,因为我们钦定最后这一步是杀 \(j\) 的,然后选择之前的哪些步用于打 \(j\),就是这样,概率和方案数是分开计算的,而且正确。\(O(2^nnm)\)

然后到了最后 \(f(m,S)\) 要乘上剩下人的安排还有 \(|S|\) 以计入答案,这剩下人的安排比较烦,记 \(\ell=m-\sum_{i\in S}a_i\),那么我们剩下人的个数是一个背包,\(\sum_{\{c|0\leq c_i<a_i\}}\binom{\ell}{c_1,c_2,\cdots}\) 这样的形式,将关于 \(c_i\) 的阶乘倒数拆开变成多项式(背包),然后就是要计算 \(bag(S)\) 表示子集 \(S\) 的所有人用 \((1+x+\frac{x^2}{2}+\frac{x^3}{3!}+\cdots+\frac{x^{a_i-1}}{(a_i-1)!})\) 的生成函数卷起来然后取 \([x^\ell]\) 项系数点乘 \(\ell!\),然后显然可以计算。暴力背包是 \(O(2^nm^2)\),执意用 NTT 优化背包是 \(O(2^nm\log m)\),但不如 meet in the middle,把子集拆成两半,询问只关心一个点值,所以复杂度是 \(O(2^{n/2}m^2+2^nm)\) 就能把复杂度瓶颈变成前面的 DP 部分。

code

点击查看代码
#include <cstdio>
#include <vector>
#include <cstring>
#include <cassert>
#include <algorithm>
using namespace std;
#ifdef LOCAL
#define debug(...) fprintf(stderr,##__VA_ARGS__)
#else
#define debug(...) void(0)
#endif
#define popcount __builtin_popcount
typedef long long LL;
const int P=998244353,G=3;
LL inv[1<<17];
auto __RobinChen__=[](){
    inv[1]=1;
    for(int i=2;i<=1e5;i++) inv[i]=(P-P/i)*inv[P%i]%P;
    return 0;
}();
int getbit(int k){return 1<<k;}
bool conbit(int x,int k){return x&getbit(k);}
int lowbit(int x){return x&-x;}
template<unsigned P> struct modint{
    unsigned v; modint():v(0){}
    template<class T> modint(T x):v((x%int(P)+int(P))%int(P)){}
    modint operator-()const{return modint(P-v);}
    modint inv()const{return assert(v),v<=1e5? ::inv[v]:qpow(*this,LL(P)-2);}
    modint&operator+=(const modint&rhs){if(v+=rhs.v,v>=P) v-=P; return *this;}
    modint&operator-=(const modint&rhs){return *this+=-rhs;}
    modint&operator*=(const modint&rhs){v=1ull*v*rhs.v%P; return *this;}
    modint&operator/=(const modint&rhs){return *this*=rhs.inv();}
    friend int raw(const modint&self){return self.v;}
    friend modint qpow(modint a,LL b){modint r=1;for(;b;b>>=1,a*=a) if(b&1) r*=a; return r;}
    friend modint operator+(modint lhs,const modint&rhs){return lhs+=rhs;}
    friend modint operator-(modint lhs,const modint&rhs){return lhs-=rhs;}
    friend modint operator*(modint lhs,const modint&rhs){return lhs*=rhs;}
    friend modint operator/(modint lhs,const modint&rhs){return lhs/=rhs;}
    friend bool operator==(const modint&lhs,const modint&rhs){return lhs.v==rhs.v;}
    friend bool operator!=(const modint&lhs,const modint&rhs){return lhs.v!=rhs.v;}
};
typedef modint<998244353> mint;
template<int N> struct C_prime{
    mint fac[N+10],ifac[N+10];
    C_prime(){
        for(int i=raw(fac[0]=1);i<=N;i++) fac[i]=fac[i-1]*i;
        ifac[N]=1/fac[N];for(int i=N;i>=1;i--) ifac[i-1]=ifac[i]*i;
    }
    mint operator()(int n,int m){return n>=m?fac[n]*ifac[m]*ifac[n-m]:0;}
};
int n,a[20],U,m,sum[1<<15];
vector<mint> multiple(vector<mint> a,vector<mint> b){
    int n=a.size(),m=b.size();
    vector<mint> c(n+m-1);
    for(int i=0;i<n;i++){
        for(int j=0;j<m;j++) c[i+j]+=a[i]*b[j];
    }
    if(c.size()> ::m+1) c.resize(::m+1);
    return c;
}
mint f[210][1<<15];
C_prime<1<<17> binom;
void dp_f(){
    f[0][0]=1;
    for(int i=0;i<m;i++){
        for(int S=0;S<1<<n;S++) if(f[i][S]!=0){
            f[i+1][S]+=f[i][S]/(n-popcount(S));
            int tot=i+1-sum[S];
            if(tot<=0) continue;
            for(int j=0;j<n;j++) if(!conbit(S,j)&&tot>=a[j]){
                f[i+1][S|getbit(j)]+=f[i][S]/(n-popcount(S))*binom(tot-1,a[j]-1);
            }
        }
    }
}
vector<mint> bag[2][1<<8];
void dp_bag(vector<mint> bag[],int n,int s){
    bag[0]={1};
    for(int i=0;i<n;i++){
        bag[1<<i].resize(a[s+i]);
        for(int j=0;j<a[s+i];j++) bag[1<<i][j]=binom.ifac[j];
    }
    for(int S=1;S<1<<n;S++) if(popcount(S)>1){
        bag[S]=multiple(bag[S^lowbit(S)],bag[lowbit(S)]);
    }
}
void dp_merge(){
    for(int i=m;i<=m;i++){
        for(int S=0;S<1<<n;S++) if(i>=sum[S]){
            mint res=0; int T=U^S;
            //res=bag[U^S][i-sum[S]]
            int s=T&(getbit(n/2)-1),t=T>>(n/2);
            for(int j=0;j<=i-sum[S];j++){
                if(j>=bag[0][s].size()||i-sum[S]-j>=bag[1][t].size()) continue;
                res+=bag[0][s][j]*bag[1][t][i-sum[S]-j];
            }
            f[i][S]*=res*binom.fac[i-sum[S]];
        }
    }
}
int main(){
//  #ifdef LOCAL
//      freopen("input.in","r",stdin);
//  #endif
    scanf("%d%d",&n,&m),U=(1<<n)-1;
    for(int i=0;i<n;i++) scanf("%d",&a[i]),sum[1<<i]=a[i];
    for(int S=1;S<1<<n;S++) if(popcount(S)>1) sum[S]=sum[S^lowbit(S)]+sum[lowbit(S)];
    dp_f();
    dp_bag(bag[0],n/2,0),dp_bag(bag[1],n-n/2,n/2);
    dp_merge();
    mint ans=0;
    for(int S=0;S<1<<n;S++) ans+=f[m][S]*popcount(S);
    printf("%d\n",raw(ans));
    return 0;
}
posted @ 2023-09-13 20:50  caijianhong  阅读(18)  评论(0编辑  收藏  举报