[PKUSC2018]最大前缀和

题目

只会\(O(2^nn^2)\)的暴力子集卷积啊

首先第一反应是算贡献,我们先求出每一个子集的子集和\(sum_i\),之后考虑\(i\)这个子集在多少种排列中成为了最大前缀和

由于一个排列的最大前缀和可能有好几个,于是我们强行规定最大前缀和为最大且出现位置最靠前的前缀和

如果我们能求出一个\(dp_i\)表示\(i\)这个集合有多少种排列使得\(i\)就是最大前缀和,我们只需要让剩下的数组成的排列在任何时候前缀和都不大于0就好了

\(f_i\)表示\(i\)这个集合有多少种排列在任何时刻前缀和都不大于\(0\),这样的话答案就是\(\sum_{i\subset S }sum_i\times dp_i\times f_{S\bigoplus i}\)

这个\(f\)\(O(n2^n)\) 的时间内就很容易求出来,现在的问题就是求出\(dp\)

随便胡了一个子集卷积的做法发现会算重,于是我们考虑正难则反,我们算一下\(i\)有多少个排列使得最大前缀和不是\(i\),之后那\(|i|!\)一减就好了

显然我们可以枚举一个子集\(t\)成为最大前缀和,之后让后面不能有更大的前缀和就好了,于是让剩下的数排列在任何时候前缀和不大于\(0\)就好了

于是\(dp_i=|i|!-\sum_{t\subset i}dp_t\times f_{t\bigoplus i}\)

显然这个是一个子集卷积的形式我们可以强行上\(fwt\)优化成\(O(2^nn^2)\),由于我们中间过程还要卷回来,所以常数很大,卡卡常就好了

代码

#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define re register
const int mod=998244353;
const int maxn=(1<<20)+5;
int sum[maxn];int a[maxn];
int g[21][maxn],dp[21][maxn],cnt[maxn];
int n,len,fac[21],st[21][maxn>>1],tp[21];
inline int qm(int a) {return a>=mod?a-mod:a;}
inline int sqm(int a) {return a<0?a+mod:a;}
inline void Fwt(int *f) {
    for(re int ln=1,i=2;i<=len;i<<=1,ln=i>>1)
        for(re int l=0;l<len;l+=i)
            for(re int x=l;x<l+ln;++x)
                f[x+ln]=qm(f[x+ln]+f[x]); 
}
inline void Ifwt(int *f) {
	for(re int ln=1,i=2;i<=len;i<<=1,ln=i>>1)
		for(re int l=0;l<len;l+=i)
			for(re int x=l;x<l+ln;++x)
				f[x+ln]=sqm(f[x+ln]-f[x]);
}
int main() {
    scanf("%d",&n);len=(1<<n);fac[0]=1;
    for(re int i=1;i<=n;i++) fac[i]=1ll*fac[i-1]*i%mod;
    for(re int i=0;i<n;i++) scanf("%d",&a[i]);
    for(re int i=1;i<len;i++) 
        for(re int j=0;j<n;j++)
            if((1<<j)&i) sum[i]=(a[j]+sum[i])%mod;
    for(re int i=1;i<len;i++) cnt[i]=cnt[i>>1]+(i&1);
    for(re int i=1;i<len;i++) st[cnt[i]][++tp[cnt[i]]]=i;
    g[0][0]=1;
    for(re int i=0;i<len;i++) {
        if(sum[i]>0) continue;
        for(re int j=0;j<n;j++) {
            if(i&(1<<j)) continue;
            if(sum[i|(1<<j)]<=0) 
                g[cnt[i]+1][i|(1<<j)]=qm(g[cnt[i]+1][i|(1<<j)]+g[cnt[i]][i]);
        }
    }
    for(re int i=0;i<=n;i++) Fwt(g[i]);
    for(re int i=0;i<n;i++) dp[1][1<<i]=1;
    Fwt(dp[1]);
    for(re int i=2;i<=n;i++) {
        for(re int j=1;j<i;j++)
            for(re int k=0;k<len;k++)
                dp[i][k]=qm(dp[i][k]+1ll*g[j][k]*dp[i-j][k]%mod);
        Ifwt(dp[i]);
        for(re int j=1;j<=tp[i];j++) dp[i][st[i][j]]=sqm(fac[i]-dp[i][st[i][j]]);
        Fwt(dp[i]);
    }
    for(re int i=1;i<=n;i++) Ifwt(dp[i]);
    for(re int i=1;i<len;i++) if(sum[i]<0) sum[i]=(sum[i]+mod)%mod;
    int ans=0;len--;
    for(re int i=1;i<=len;i++)
        ans=qm(ans+1ll*sum[i]*dp[cnt[i]][i]%mod*g[n-cnt[i]][len^i]%mod);
    printf("%d\n",ans);
    return 0;
}
posted @ 2019-06-03 21:03  asuldb  阅读(232)  评论(0编辑  收藏  举报