http://acm.hdu.edu.cn/showproblem.php?pid=4610
先求出每个数的得分情况,分数和得分状态,(1<<4)种状态
按分数从大到小排序 然后每种状态取一个数(如果有的话)
然后对 dp[i][j] 进行背包 dp[i][j] 表示的是选了i个数选的总状态为j情况下的最大值
然后根据每个 dp[i][j] 对数组剩余的数进行最优选择(在不破坏 最终状态 j 的情况下,尽量选单位分数高的)
最后求最大的情况
代码:
#include<iostream> #include<cstdio> #include<string> #include<cstring> #include<cmath> #include<set> #include<map> #include<stack> #include<vector> #include<algorithm> #include<queue> #include<stdexcept> #include<bitset> #include<cassert> #include<deque> #include<numeric> using namespace std; typedef long long ll; typedef unsigned int uint; const double eps=1e-12; const int INF=0x3f3f3f3f; const ll MOD=1000000007; const int H=1000005; const int K=10005; const int N=1005; const int M=(1<<4); struct node { int a,b; int point,k; }in[N]; bool prime(int x) { if(x==1) return false; for(int i=2;i*i<=x;++i) if(x%i==0) return false; return true; } void get(node &x) { int a=x.a; int num=0,sum=0; for(int i=1;i*i<=a;++i) if(a%i==0) { ++num; sum+=i; int j=a/i; if(i!=j) {++num;sum+=j;} } x.k=0;x.point=0; if(num==2) {x.point++;x.k|=1;} if(prime(num)) {x.point++;x.k|=2;} if(prime(sum)) {x.point++;x.k|=4;} int h=(int)(sqrt(1.0*a)+0.5); int h1=(int)(sqrt(1.0*h)+0.5); if(a==1||((num&1)==0&&((num>>1)&1)==0)||((num&1)==1&&h1*h1==h)) {x.point++;x.k|=8;} } bool cmp(node x,node y) { return x.point>y.point; } int main() { //freopen("data.in","r",stdin); //freopen("1011.in","r",stdin); //freopen("1011.out","w",stdout); int T; scanf("%d",&T); while(T--) { int n,k; scanf("%d %d",&n,&k); for(int i=0;i<n;++i) { scanf("%d %d",&in[i].a,&in[i].b); get(in[i]); if(i>0) printf(" "); printf("%d",in[i].point); } printf("\n"); int bit[4]; for(int i=0;i<4;++i) scanf("%d",&bit[i]); sort(in,in+n,cmp); int dp[20][20]; memset(dp,-1,sizeof(dp)); dp[0][0]=0; bool had[20]; memset(had,false,sizeof(had)); for(int i=0;i<n;++i) if(had[in[i].k]==false) { had[in[i].k]=true; node &w=in[i]; for(int i=M-1;i>=0;--i) for(int j=0;j<M;++j) if(dp[i][j]!=-1) { int l=i+1; int r=(j|w.k); dp[l][r]=max(dp[l][r],dp[i][j]+w.point); } w.b--; } int ans=-INF; for(int i=0;i<=16;++i) for(int j=0;j<16;++j) if(i<=k&&dp[i][j]!=-1) { int sum=dp[i][j]; for(int l=0;l<4;++l) if((j&(1<<l))==0) sum+=bit[l]; int w=k-i; for(int l=0;w>0&&l<n;++l) if((j|in[l].k)==j) { if(w>=in[l].b) { w-=in[l].b; sum+=(in[l].b*in[l].point); }else { sum+=(w*in[l].point); w=0; } if(w==0) break; } if(w==0) ans=max(ans,sum); } printf("%d\n",ans); } return 0; }