codevs1227 方格取数2

题目描述 Description

给出一个n*n的矩阵,每一格有一个非负整数Aij,(Aij <= 1000)现在从(1,1)出发,可以往右或者往下走,最后到达(n,n),每达到一格,把该格子的数取出来,该格子的数就变成0,这样一共走K次,现在要求K次所达到的方格的数的和最大

输入描述 Input Description

第一行两个数n,k(1<=n<=50, 0<=k<=10)

接下来n行,每行n个数,分别表示矩阵的每个格子的数

输出描述 Output Description

一个数,为最大和

样例输入 Sample Input

3 1

1 2 3

0 2 1

1 4 2

样例输出 Sample Output

11

数据范围及提示 Data Size & Hint

1<=n<=50, 0<=k<=10

 

拆点 x与x'之间连两条边,一条费用为格子里的数值,最大流量为1,另一条费用为0,最大流量INF

费用流

//Serene
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdlib>
#include<cstdio>
#include<cmath>
using namespace std;
const int maxn=2*50*50+10,maxm=4*maxn+maxn,INF=0x3f3f3f3f;
int n,k,tu[maxn],S,T;

int aa;char cc;
int read() {
	aa=0;cc=getchar();
	while(cc<'0'||cc>'9') cc=getchar();
	while(cc>='0'&&cc<='9') aa=aa*10+cc-'0',cc=getchar();
	return aa;
}

struct Node{
	int x,y,cap,flow,w;
	Node(){}
	Node(int x,int y,int cap,int w) :x(x),y(y),cap(cap),w(w){}
}node[2*maxm];

int fir[maxn],nxt[2*maxm],e=1;
void add(int x,int y,int z,int w) {
	node[++e]=Node(x,y,z,w); nxt[e]=fir[x];fir[x]=e;
	node[++e]=Node(y,x,0,-w); nxt[e]=fir[y];fir[y]=e;
}

int zz[maxn],from[maxn],dis[maxn];bool vis[maxn];
bool spfa() {
	int s=1,t=0,x,y,z;
	memset(dis,-1,sizeof(dis));
	memset(zz,0,sizeof(zz));
	zz[++t]=S;vis[S]=1;dis[S]=0;
	while(s<=t) {
		x=zz[s%maxn];
		for(y=fir[x];y;y=nxt[y]) {
			z=node[y].y;
			if(dis[z]>=dis[x]+node[y].w||node[y].flow>=node[y].cap) continue;
			if(!vis[z]) {
				t++; zz[t%maxn]=z;
				vis[z]=1;
			}
			from[z]=y;
			dis[z]=dis[x]+node[y].w;
		}
		vis[x]=0;s++;
	}
	return dis[T]!=-1;
}

int MCMF() {
	int rs=0,now;
	while(spfa()) {
		now=INF;
		for(int i=T;i!=S;i=node[from[i]].x) now=min(now,node[from[i]].cap-node[from[i]].flow);
		for(int i=T;i!=S;i=node[from[i]].x) {
			node[from[i]].flow+=now;
			node[from[i]^1].flow-=now;
			rs+=now*node[from[i]].w;
		}
	}
	return rs;
}

int main() {
	n=read();k=read();int x;S=2*n*n+1;T=S+1;
	for(int i=1;i<=n*n;++i) tu[i]=read();
	for(int i=1;i<=n;++i) for(int j=1;j<=n;++j) {
		x=n*(i-1)+j;
		add(x+n*n,x,1,tu[x]);add(x+n*n,x,k,0);
		if(i!=n) add(x,x+n+n*n,k,0);
		if(j!=n) add(x,x+1+n*n,k,0);
	}
	add(S,1+n*n,k,0); add(n*n,T,k,0);
	printf("%d",MCMF());
	return 0;
}

  

posted @ 2017-08-28 19:39  shixinyi  阅读(149)  评论(0编辑  收藏  举报