魔改森林 题解(dp+容斥+组合数学)

题目链接

题目大意

给你一个n*m的方格图,中间有k个障碍,要你求从左下角到右上角有多少种方案mod 998244353

题目思路

这个题目真是很oi......
首先\(max(n,m)<=1e3\) 直接\(O(N^2)dp\)

如果方格数量很多,观察障碍物很少,则可以想到容斥的思维

然后再用组合数学预处理一下

一定要注意要预处理到2e5 wa了我一辈子

代码

#include<bits/stdc++.h>
#define fi first
#define se second
#define debug cout<<"I AM HERE"<<endl;
using namespace std;
typedef long long ll;
const int maxn=2e5+5,inf=0x3f3f3f3f,mod=998244353;
const int eps=1e-6;
int n,m,k;
ll fac[maxn];
ll dp[1000+5][1000+5];
bool mp[1000+5][1000+5];
pair<int,int> pa[maxn];
ll qpow(ll a,ll b){
    ll ans=1,base=a;
    while(b){
        if(b&1){
            ans=ans*base%mod;
        }
        b=b>>1;
        base=base*base%mod;
    }
    return ans;
}
ll C(int a,int b){
    return 1ll*fac[a]*qpow(fac[b],mod-2)%mod*qpow(fac[a-b],mod-2)%mod;
}
void init(){
    fac[0]=1;
    for(int i=1;i<=2e5;i++){
        fac[i]=fac[i-1]*i%mod;
    }
}
bool cmp(pair<int,int> a,pair<int,int> b){
    if(a.fi==b.fi){
        return a.se<b.se;
    }else{
        return a.fi>b.fi;
    }
}
void solve1(){
    for(int i=n+1;i>=1;i--){
        for(int j=1;j<=m+1;j++){
            if(i==n+1&&j==1) dp[i][j]=1;
            else if(mp[i][j]) dp[i][j]=0;
            else    dp[i][j]=(dp[i+1][j]+dp[i][j-1])%mod;
        }
    }
    printf("%lld\n",dp[1][m+1]);
}
void solve2(){
    //(n,0) (0,m)
    sort(pa+1,pa+1+k,cmp);// 排序
    ll ans=C(n+m,n); //总共
    for(int i=1;i<=k;i++){// 走过一个点
        ans-=1ll*C(n-pa[i].fi+pa[i].se,pa[i].se)*C(pa[i].fi+m-pa[i].se,pa[i].fi)%mod;
    }
    bool flag=1;
    for(int i=1;i<k;i++){// 走过两个点
        for(int j=i+1;j<=k;j++){
            if(pa[i].se>pa[j].se){
                flag=0;
                continue;
            }
            ans+=1ll*C(n-pa[i].fi+pa[i].se,pa[i].se)*C(pa[i].fi-pa[j].fi+pa[j].se-pa[i].se,pa[i].fi-pa[j].fi)%mod*C(pa[j].fi+m-pa[j].se,pa[j].fi)%mod;
        }
    }
    if(flag&&k==3){ //走过三个点
        ans-=1ll*C(n-pa[1].fi+pa[1].se,pa[1].se)*C(pa[2].fi-pa[1].fi+pa[1].se-pa[2].se,pa[2].fi-pa[1].fi)%mod*C(pa[3].fi-pa[2].fi+pa[2].se-pa[3].se,pa[3].fi-pa[2].fi)%mod*C(pa[3].fi+m-pa[3].se,pa[3].fi)%mod;
    }
    ans=(ans%mod+mod)%mod;
    printf("%lld\n",ans);
}
int main(){
    init();
    scanf("%d%d%d",&n,&m,&k);
    for(int i=1;i<=k;i++){
        scanf("%d%d",&pa[i].fi,&pa[i].se);
        pa[i].fi--,pa[i].se--;
        // 格子与格点的区别
        if(k>=10)mp[pa[i].fi+1][pa[i].se+1]=1;
    }
    if(k>=10){
        solve1();
    }else{
        solve2();
    }
    return 0;
}

posted @ 2021-01-27 15:20  hunxuewangzi  阅读(84)  评论(0编辑  收藏  举报