兔子的排列 题解(分治+记忆化搜索)
题目链接
题目大意
题目的意思就是一个序列本来是0,1,2...n-1.然后要你变成p[0],p[1],p[2]...p[n-1],通过交换相邻的两个数每两个相邻的数都可以交换一次,你可以选择交换的顺序.问你有多少种交换顺序?
题目思路
有两种方法,一种复杂度\(O(n^5)\) 一种复杂度\(O(n^3)\)
首先如果一个点换了,那么相当于切断了两边,则可以两边单独考虑
则要想到分治
\(dp[l][r][bg][ed]\) 代表区间l-r中开头为bg,结尾是ed,
最不好理解的是为什么枚举断点时要乘 C(r-l-1,i-l)
因为这个区间总共要交换r-l次,而中间交换了1次
所以左边和右边总共交换次数为r-l-1次
而左边的交换次数为i-l次
则相当于要在r-l-1次中给左边的交换次数安排位置
有点难理解,自己可以思考下
优化到\(O(n^3)\)的方法就是直接把dp数组变成\(dp[l][r]\)直接省去两维
因为你划分两段后你用前缀和算一下就能直接判断是否可以满足划分
代码1
#include<bits/stdc++.h>
#define debug cout<<"I AM HERE"<<endl;
using namespace std;
typedef long long ll;
const int maxn=50+5,inf=0x3f3f3f3f,mod=1e9+7;
const int eps=1e-6;
int n;
int a[maxn];
ll fac[maxn];
ll dp[maxn][maxn][maxn][maxn];
ll qpow(ll a,ll b){
ll ans=1,base=a;
while(b){
if(b&1){
ans=ans*base%mod;
}
base=base*base%mod;
b=b>>1;
}
return ans;
}
void init(){
fac[0]=1;
for(int i=1;i<=50;i++){
fac[i]=fac[i-1]*i%mod;
}
}
ll C(int a,int b){ // 计算 C(a,b)
return fac[a]*qpow(fac[b],mod-2)%mod*qpow(fac[a-b],mod-2)%mod;
}
ll dfs(int l,int r,int bg,int ed){
if(dp[l][r][bg][ed]!=-1){
}else if(l==r){ // 长度为1
dp[l][r][bg][ed]=(a[l]==bg);
}else{
ll ans=0;
for(int i=l;i<r;i++){
int x1,y1,x2,y2;
// 注意bg和ed已经改变了
// [bg+1,ed-1]还没改变
if(r-l==1){ // 长度为2
x1=y1=ed;
x2=y2=bg;
}else if(i==l){
x1=y1=i+1;
x2=bg;
y2=ed;
}else if(i==r-1){
x1=bg;
y1=ed;
x2=y2=i;
}else{
x1=bg;
y1=i+1;
x2=i;
y2=ed;
}
ans=(ans+dfs(l,i,x1,y1)*dfs(i+1,r,x2,y2)%mod*C(r-l-1,i-l))%mod;
}
dp[l][r][bg][ed]=ans;
}
return dp[l][r][bg][ed];
}
int main(){
init();
memset(dp,-1,sizeof(dp));
scanf("%d",&n);
for(int i=1;i<=n;i++){
scanf("%d",&a[i]);
a[i]++;
}
printf("%lld\n",dfs(1,n,1,n));
return 0;
}
代码2
#include<bits/stdc++.h>
#define debug cout<<"I AM HERE"<<endl;
using namespace std;
typedef long long ll;
const int maxn=50+5,inf=0x3f3f3f3f,mod=1e9+7;
const int eps=1e-6;
int n;
int a[maxn],b[maxn];// b->a;
int prea[maxn],preb[maxn];
ll fac[maxn];
ll dp[maxn][maxn];
ll qpow(ll a,ll b){
ll ans=1,base=a;
while(b){
if(b&1){
ans=ans*base%mod;
}
base=base*base%mod;
b=b>>1;
}
return ans;
}
void init(){
fac[0]=1;
for(int i=1;i<=50;i++){
fac[i]=fac[i-1]*i%mod;
}
}
ll C(int a,int b){ // 计算 C(a,b)
return fac[a]*qpow(fac[b],mod-2)%mod*qpow(fac[a-b],mod-2)%mod;
}
ll dfs(int l,int r){
if(dp[l][r]!=-1){
}else if(l==r){ // 长度为1
dp[l][r]=(a[l]==b[l]);
}else{
ll ans=0;
for(int i=l;i<r;i++){
swap(b[i],b[i+1]);
preb[i]+=b[i]-b[i+1];
if(preb[i]-preb[l-1]==prea[i]-prea[l-1]&&preb[r]-preb[i]==prea[r]-prea[i]){
ans=(ans+dfs(l,i)*dfs(i+1,r)%mod*C(r-l-1,i-l))%mod;
}
preb[i]-=b[i]-b[i+1];
swap(b[i],b[i+1]);
}
dp[l][r]=ans;
}
return dp[l][r];
}
int main(){
init();
memset(dp,-1,sizeof(dp));
scanf("%d",&n);
for(int i=1;i<=n;i++){
scanf("%d",&a[i]);
a[i]++;
b[i]=i;
prea[i]=prea[i-1]+a[i];
preb[i]=preb[i-1]+b[i];
}
printf("%lld\n",dfs(1,n));
return 0;
}
不摆烂了,写题