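思路太繁琐了，实在不想详细解释，只说关键点：设 d 为所选建筑中"种类不在 found 里"的那些建筑的最小深度，则存在合适的挖掘深度当且仅当 found 中每个种类都至少选中了一栋深度严格小于 d 的建筑；据此分别对 found 种类和其余种类做计数 DP，再按 d 合并。
代码：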
#include <iostream>
#include <cstdio>
#include <string>
#include <cstring>
#include <cmath>
#include <set>
#include <map>
#include <stack>
#include <vector>
#include <algorithm>
#include <queue>
#include <cstddef>

// Counts K-element subsets of buildings that are consistent with an
// excavation result.
//
// Building i has kind[i] and depth[i] (`kind` and `depth` are parallel arrays
// of equal length). A subset S is counted when some cut-off depth D exists
// such that the set of kinds of buildings in S with depth <= D equals the set
// of kinds listed in `found` (entries of `found` are assumed distinct).
//
// Why the counting works: let d be the smallest depth of a chosen building
// whose kind is NOT in `found`. A suitable D exists iff every found kind has
// at least one chosen building strictly shallower than d. Hence every counted
// subset splits uniquely into
//   1. a subset of the other-kind buildings (this fixes d, or "none"),
//   2. for each found kind, a non-empty set of its chosen buildings lying
//      strictly shallower than d ("shallow" picks), and
//   3. any set of found-kind buildings at depth >= d ("deep" picks);
// the three parts are counted separately and then combined.
//
// NOTE(review): all counts are long long; they are exact only while the
// intermediate products fit (e.g. n around 50 is fine).
class Excavations {
    using Table = std::vector<std::vector<long long>>;

public:
    long long count(std::vector<int> kind, std::vector<int> depth,
                    std::vector<int> found, int K) {
        // Every counted subset contains at least one shallow pick, so K < 1
        // yields 0 (this also keeps the table sizes below non-negative).
        if (K < 1) {
            return 0;
        }
        const int n = static_cast<int>(kind.size());
        const int none = n + 1;  // depth rank meaning "no other-kind pick"
        const Table choose = binomials(n);
        const std::vector<int> rank = compressDepths(depth, n);

        // A set instead of a bool array indexed by kind, so any int kind is safe.
        const std::set<int> foundKinds(found.begin(), found.end());

        // Sorted depth ranks of the buildings of each found kind.
        std::vector<std::vector<int>> perKind(found.size());
        for (std::size_t i = 0; i < found.size(); ++i) {
            for (int j = 0; j < n; ++j) {
                if (kind[j] == found[i]) {
                    perKind[i].push_back(rank[j]);
                }
            }
            std::sort(perKind[i].begin(), perKind[i].end());
        }

        // Sorted depth ranks of buildings whose kind is not in `found`, and
        // deepFound[r] = number of found-kind buildings with depth rank >= r.
        std::vector<int> otherRanks;
        std::vector<long long> deepFound(none + 1, 0);
        for (int i = 0; i < n; ++i) {
            if (foundKinds.count(kind[i]) == 0) {
                otherRanks.push_back(rank[i]);
            } else {
                for (int r = 1; r <= rank[i]; ++r) {
                    ++deepFound[r];
                }
            }
        }
        std::sort(otherRanks.begin(), otherRanks.end());

        const Table shallow = countShallowPicks(perKind, n, K, choose);
        const Table others = countOtherPicks(otherRanks, none, choose);
        const int otherCount = static_cast<int>(otherRanks.size());

        long long total = 0;
        // d = smallest depth rank among other-kind picks (or `none`). d == 1 is
        // skipped: no shallow pick can be strictly shallower than rank 1.
        for (int d = 2; d <= none; ++d) {
            for (int w = 0; w <= otherCount; ++w) {
                const long long otherWays = others[d][w];
                if (otherWays == 0) {
                    continue;
                }
                // d1 = deepest shallow pick, which must lie strictly above d.
                for (int d1 = 1; d1 < d; ++d1) {
                    for (int w1 = 1; w1 + w <= K; ++w1) {
                        const long long shallowWays = shallow[d1][w1];
                        if (shallowWays == 0) {
                            continue;
                        }
                        // Fill the remaining slots with deep found-kind picks.
                        const int rest = K - w - w1;
                        if (deepFound[d] >= rest) {
                            total += otherWays * shallowWays *
                                     choose[deepFound[d]][rest];
                        }
                    }
                }
            }
        }
        return total;
    }

private:
    // Pascal's triangle: result[i][j] = C(i, j) for 0 <= j <= i <= size, and 0
    // for j > i.
    static Table binomials(int size) {
        Table c(size + 1, std::vector<long long>(size + 1, 0));
        for (int i = 0; i <= size; ++i) {
            c[i][0] = 1;
            for (int j = 1; j <= i; ++j) {
                c[i][j] = c[i - 1][j - 1] + c[i - 1][j];
            }
        }
        return c;
    }

    // Returns the rank (1 = shallowest) of each of the first n depths among
    // those n depths; equal depths share a rank. Ranks are looked up in a
    // separate sorted copy, so zero or negative depths cannot collide with
    // already-assigned ranks. Requires depth.size() >= n.
    static std::vector<int> compressDepths(const std::vector<int>& depth, int n) {
        std::vector<int> rank(depth.begin(), depth.begin() + n);
        std::vector<int> values(rank);
        std::sort(values.begin(), values.end());
        values.erase(std::unique(values.begin(), values.end()), values.end());
        for (int& r : rank) {
            r = static_cast<int>(
                    std::lower_bound(values.begin(), values.end(), r) -
                    values.begin()) + 1;
        }
        return rank;
    }

    // For every found kind, choose a non-empty "shallow" set of its buildings.
    // Result[d][w] = number of such choices (over all kinds together) whose
    // deepest pick has depth rank d and whose total size is w (w <= K).
    // Each kind's set is identified by its last element l in sorted order; the
    // other x picks come from the l entries before it (depth <= ranks[l]).
    static Table countShallowPicks(const std::vector<std::vector<int>>& perKind,
                                   int maxRank, int K, const Table& choose) {
        Table current(maxRank + 1, std::vector<long long>(K + 1, 0));
        current[0][0] = 1;
        for (const std::vector<int>& ranks : perKind) {
            Table updated(maxRank + 1, std::vector<long long>(K + 1, 0));
            const int size = static_cast<int>(ranks.size());
            for (int d = 0; d <= maxRank; ++d) {
                for (int w = 0; w < K; ++w) {
                    const long long ways = current[d][w];
                    if (ways == 0) {
                        continue;
                    }
                    for (int l = 0; l < size; ++l) {
                        for (int x = 0; x <= l && w + x + 1 <= K; ++x) {
                            updated[std::max(d, ranks[l])][w + x + 1] +=
                                ways * choose[l][x];
                        }
                    }
                }
            }
            current.swap(updated);
        }
        return current;
    }

    // Result[d][w] = number of w-element subsets of `sortedRanks` whose
    // shallowest element has rank d; the empty subset is recorded at
    // [none][0]. A subset is identified by its first element j in sorted
    // order; its other w-1 elements come from the entries after j.
    static Table countOtherPicks(const std::vector<int>& sortedRanks, int none,
                                 const Table& choose) {
        const int size = static_cast<int>(sortedRanks.size());
        Table ways(none + 1, std::vector<long long>(size + 1, 0));
        ways[none][0] = 1;
        for (int j = 0; j < size; ++j) {
            const int after = size - 1 - j;
            for (int w = 1; w <= after + 1; ++w) {
                ways[sortedRanks[j]][w] += choose[after][w - 1];
            }
        }
        return ways;
    }
};