I.Tower Defense
给你p个重塔,q个轻塔,把这些塔放在n*m的图中,这些塔会相互攻击同行同列的,轻塔不能受到攻击,重塔能承受一个塔的攻击,
问放的方法数。
先假定n < m。
可以先枚举放轻塔的个数为s,显然,方法数为C(n,s) * m * (m-1) * ... * (m-s+1) ,放完之后我们可以发现图其实缩小成为了一个(n-s)*(m-s)的图。
然后放重塔,由于重塔可以承受一个塔的攻击,dp求一下方案,令dp(i,j,k) 表示i*j的图中放k个重塔的方法,通过在图的第一行进行限定条件枚举。
可分为3个小部分:
1.第一行不放重塔 dp(i,j,k) += dp(i-1,j,k)
2.第一行放一个重塔,又分两种情况:
A:同一列不放重塔 dp(i,j,k) += j*dp(i-1,j-1,k-1)
B:同一列放重塔 dp(i,j,k) += j*(i-1)*dp(i-2,j-1,k-2)
3.第一行放两个重塔
dp(i,j,k) += C(j,2)*dp(i-1,j-2,k-2)
求出dp数组之后即总方法数为segma(0,q,i) segma(0,p,j) C(n,i)*m*...*(m-i+1)*dp(n-i,m-i,j)
由于不能不放,所以需要最后减去1.
时间复杂度为K*200^3,K为一常数。
#include <cmath> #include <cstdio> #include <cstdlib> #include <cassert> #include <cstring> #include <set> #include <map> #include <list> #include <queue> #include <string> #include <iostream> #include <algorithm> #include <functional> #include <stack> #include <bitset> using namespace std; typedef long long ll; #define INF (0x3f3f3f3f) #define maxn (1000005) #define mod 1000000007 #define ull unsigned long long ll C[205][205],dp[205][205][205]; void init(){ for(int i = 0;i <= 200;++i) C[i][0] = 1; for(int i = 1;i <= 200;++i){ for(int j = 1;j <= i;++j){ C[i][j] = C[i-1][j] + C[i-1][j-1]; if(C[i][j] >= mod) C[i][j] %= mod; } } for(int i = 0;i <= 200;++i) for(int j = 0;j <= 200;++j) dp[i][j][0] = 1; for(int i = 1;i <= 200;++i){ for(int j = 1;j <= 200;++j){ for(int k = 1;k <= 200;++k){ //第一行不取 if(i == 2 && j == 2){ int t = 1; } dp[i][j][k] += dp[i-1][j][k]; //第一行取一个 dp[i][j][k] += j * dp[i-1][j-1][k-1]%mod;//对应的列不取 if(dp[i][j][k] >= mod) dp[i][j][k] %= mod; if(i >= 2 && k >= 2) dp[i][j][k] += j * (i-1) * dp[i-2][j-1][k-2]%mod;//对应的列取 if(dp[i][j][k] >= mod) dp[i][j][k] %= mod; //第一行取两个 if(j >= 2 && k >= 2) dp[i][j][k] += C[j][2]*dp[i-1][j-2][k-2]%mod; if(dp[i][j][k] >= mod) dp[i][j][k] %= mod; } } } } ll quickpow(ll x,ll y){ ll ans = 1; while(y){ if(y & 1){ ans = ans * x; if(ans >= mod) ans %= mod; } x *= x; if(x >= mod) x %= mod; y >>= 1; } return ans; } int main() { int T; int n,m,p,q; init(); scanf("%d",&T); while(T--){ scanf("%d%d%d%d",&n,&m,&p,&q); if(n > m) swap(n,m); int li = min(q,n); ll s = 1,ans = 0; for(int i = 0;i <= li;++i){ for(int j = 0;j <= p;++j){ ans = ans + C[n][i] * s % mod * dp[(n-i)][(m-i)][j] % mod; if(ans >= mod) ans %= mod; } s = s * (m - i); if(s >= mod) s %= mod; } --ans; if(ans < 0) ans += mod; printf("%lld\n",ans); } return 0; }