(多校)按位或(or)

首次考场推容斥 AC
先转换题意
给定 \(m\) 个位置,限制那 \(n\) 个数的二进制位只能在这几位上
使得每一二进制位均被覆盖过至少一次
显然的容斥题
我们令 \(S(k)\) 表示至少有 \(k\) 个二进制位没有覆盖选数的总方案数
那答案为

\[\sum_{k=0}^{m}(-1)^kS(k) \]

考虑若何求 \(S(k)\)
规定:

1类:二进制位 %3=1
2类:二进制位 %3=2
1类共有 a个
2类共有 b个

\(dp_{i,j,0/1/2}\) 表示选了 \(i\) 个 1类,\(j\) 个 2类,组合选数的模数为 \(0/1/2\) 的方案数
直接转移即可

\[S(k)=\sum_{i=0}^{a}\sum_{j=1}^{b}[i+j=m-k]\binom{a}{i}\binom{b}{j}dp_{i,j,0}^n \]

Code
#include <bits/stdc++.h>
#define re register
#define int long long
#define ll long long
#define pir make_pair
#define fr first 
#define sc second
#define db double
using namespace std;
const int mol=998244353;
const int maxn=2e7+10;
const int INF=1e9+10;
inline int qpow(int a,int b) { int ans=1; while(b) { if(b&1) ans=(1ll*ans*a)%mol; a=(1ll*a*a)%mol; b>>=1; } return ans; }
inline int read() {
    int s=0,w=1; char ch=getchar();
    while(ch<'0'||ch>'9') { if(ch=='-') w=-1; ch=getchar(); }
    while(ch>='0'&&ch<='9') { s=s*10+ch-'0'; ch=getchar(); }
    return s*w;
}

int n,t,num[3],s[66],bin[66],dp[66][66][3],ks[2][3],fac[66],inv[66];
inline void ad(int &x) { x= x>=mol? x-mol:x; }
inline int C(int n,int m) { return fac[n]*inv[m]%mol*inv[n-m]%mol; }
signed main(void) {
    freopen("erp.in","r",stdin); freopen("erp.out","w",stdout);
    n=read(); t=read(); int m=0;
    bin[0]=1; for(re int i=1;i<=63;i++) bin[i]=bin[i-1]*2;
    for(re int i=0;i<=63;i++) if(t&bin[i]) { ++num[bin[i]%3]; ++m; }
    fac[0]=1; for(re int i=1;i<=65;i++) fac[i]=fac[i-1]*i%mol;
    inv[65]=qpow(fac[65],mol-2); for(re int i=65;i>=1;i--) inv[i-1]=inv[i]*i%mol;
    dp[0][0][0]=1;
    for(re int b=0;b<=num[1];b++) for(re int c=0;c<=num[2];c++) {
        if(b) {
            int a=b-1; for(re int k=0;k<3;k++) dp[b][c][k]=dp[a][c][k];
            for(re int k=0;k<3;k++) ad(dp[b][c][(k+1)%3]+=dp[a][c][k]);
        }
        else if(c) {
            int a=c-1; for(re int k=0;k<3;k++) dp[b][c][k]=dp[b][a][k];
            for(re int k=0;k<3;k++) ad(dp[b][c][(k+2)%3]+=dp[b][a][k]);
        }
        ad(s[m-b-c]+=C(num[1],b)*C(num[2],c)%mol*qpow(dp[b][c][0],n)%mol);
    }
    int ans=0;
    for(re int k=0;k<=m;k++) { (ans+=(int)pow(-1,k)*s[k]%mol)%=mol; } 
    printf("%lld\n",(ans+mol)%mol);
}
posted @ 2021-11-03 10:12  zJx-Lm  阅读(31)  评论(0编辑  收藏  举报