// C++ implementation of U-Net (c++实现unet)
#include <torch/torch.h>

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <stdlib.h>
#include <string>
#include <unordered_map>
#include <vector>

// ---------------------------------------------------------------------------
// Network building blocks
// ---------------------------------------------------------------------------

/// Two consecutive (3x3 convolution -> BatchNorm -> ReLU) stages.
/// Padding 1 on each 3x3 convolution keeps H and W unchanged.
class double_conv : public torch::nn::Module {
 public:
  torch::nn::Conv2d conv1, conv2;
  torch::nn::BatchNorm bn1, bn2;
  int in_ch, out_ch;

 public:
  /// @param in_ch  channel count of the input tensor
  /// @param out_ch channel count produced by both convolutions
  double_conv(int in_ch, int out_ch)
      // Listed in member-declaration order, which is the order C++ uses.
      : conv1(torch::nn::Conv2dOptions(in_ch, out_ch, 3).padding(1)),
        conv2(torch::nn::Conv2dOptions(out_ch, out_ch, 3).padding(1)),
        bn1(out_ch),
        bn2(out_ch),
        in_ch(in_ch),
        out_ch(out_ch) {
    register_module("conv1", conv1);
    register_module("conv2", conv2);
    register_module("bn1", bn1);
    register_module("bn2", bn2);
  }

  torch::Tensor forward(torch::Tensor x) {
    x = torch::relu(bn1->forward(conv1->forward(x)));
    x = torch::relu(bn2->forward(conv2->forward(x)));
    return x;
  }
};

/// Input stage: one double_conv at full resolution.
class inconv : public torch::nn::Module {
 public:
  int in_ch, out_ch;
  // Built once and registered, so its weights persist across forward() calls
  // and appear in parameters(). Constructing it inside forward() would
  // re-initialise it on every call and hide it from the optimizer.
  std::shared_ptr<double_conv> conv;

 public:
  inconv(int in_ch, int out_ch)
      : in_ch(in_ch),
        out_ch(out_ch),
        conv(register_module("conv", std::make_shared<double_conv>(in_ch, out_ch))) {}

  torch::Tensor forward(torch::Tensor x) { return conv->forward(x); }
};

/// Encoder stage: 2x2 max-pool (halves H and W), then a registered double_conv.
class down : public torch::nn::Module {
 public:
  int in_ch, out_ch;
  std::shared_ptr<double_conv> conv;

 public:
  down(int in_ch, int out_ch)
      : in_ch(in_ch),
        out_ch(out_ch),
        conv(register_module("conv", std::make_shared<double_conv>(in_ch, out_ch))) {}

  torch::Tensor forward(torch::Tensor x) {
    return conv->forward(torch::max_pool2d(x, 2));
  }
};

/// Decoder stage: a 4x4 stride-2 transposed convolution (in_ch -> out_ch, which
/// doubles H and W), concatenation with the skip tensor along dim 1, then a
/// double_conv back down to out_ch channels.
class up : public torch::nn::Module {
 public:
  int in_ch, out_ch;
  torch::nn::Conv2d upconv;
  std::shared_ptr<double_conv> conv;

 public:
  /// @param in_ch   channels of the tensor being upsampled (x1 in forward)
  /// @param out_ch  channels after upsampling and after the double_conv
  /// @param skip_ch channels of the skip tensor (x2 in forward). A negative
  ///                value (the default) means "same as out_ch". It must be
  ///                known here so the double_conv can be built and registered
  ///                up front.
  up(int in_ch, int out_ch, int skip_ch = -1)
      : in_ch(in_ch),
        out_ch(out_ch),
        upconv(torch::nn::Conv2dOptions(in_ch, out_ch, 4)
                   .padding(1)
                   .stride(2)
                   .transposed(true)) {
    const int skip = skip_ch < 0 ? out_ch : skip_ch;
    register_module("upconv", upconv);
    conv = register_module("conv",
                           std::make_shared<double_conv>(out_ch + skip, out_ch));
  }

  torch::Tensor forward(torch::Tensor x1, torch::Tensor x2) {
    torch::Tensor x = upconv->forward(x1);
    x = torch::cat({x, x2}, 1);
    return conv->forward(x);
  }
};

/// 1x1 convolution that maps features to per-class scores.
class outconv : public torch::nn::Module {
 public:
  int in_ch, out_ch;
  torch::nn::Conv2d conv;

 public:
  outconv(int in_ch, int out_ch)
      : in_ch(in_ch),
        out_ch(out_ch),
        conv(torch::nn::Conv2dOptions(in_ch, out_ch, 1).padding(0)) {
    register_module("conv", conv);
  }

  torch::Tensor forward(torch::Tensor x) { return conv->forward(x); }
};

/// U-Net with four pooling levels. forward() returns raw (un-normalised)
/// per-class scores of shape (N, n_class, H, W).
/// NOTE(review): H and W presumably need to be divisible by 16 so each
/// upsampled tensor matches its skip tensor in torch::cat. Confirm for other
/// input sizes.
class unet : public torch::nn::Module {
 public:
  int n_ch, n_class;
  // Every submodule is registered so model.parameters() (used by the
  // optimizer and saveModel) sees all trainable weights.
  std::shared_ptr<inconv> iconv;
  std::shared_ptr<down> down1, down2, down3, down4;
  std::shared_ptr<up> up1, up2, up3, up4;
  std::shared_ptr<outconv> oconv;

 public:
  unet(int n_ch, int n_class) : n_ch(n_ch), n_class(n_class) {
    iconv = register_module("iconv", std::make_shared<inconv>(n_ch, 64));
    down1 = register_module("down1", std::make_shared<down>(64, 256));
    down2 = register_module("down2", std::make_shared<down>(256, 512));
    down3 = register_module("down3", std::make_shared<down>(512, 512));
    down4 = register_module("down4", std::make_shared<down>(512, 512));
    // Third argument = channels of the encoder output concatenated in forward().
    up1 = register_module("up1", std::make_shared<up>(512, 256, 512));  // skip: down3
    up2 = register_module("up2", std::make_shared<up>(256, 128, 512));  // skip: down2
    up3 = register_module("up3", std::make_shared<up>(128, 64, 256));   // skip: down1
    up4 = register_module("up4", std::make_shared<up>(64, 64, 64));     // skip: iconv
    oconv = register_module("oconv", std::make_shared<outconv>(64, n_class));
  }

  torch::Tensor forward(torch::Tensor x) {
    // Encoder. Intermediates are locals so they are not kept alive between calls.
    torch::Tensor e1 = iconv->forward(x);   // 64 ch
    torch::Tensor e2 = down1->forward(e1);  // 256 ch, 1/2 size
    torch::Tensor e3 = down2->forward(e2);  // 512 ch, 1/4 size
    torch::Tensor e4 = down3->forward(e3);  // 512 ch, 1/8 size
    torch::Tensor e5 = down4->forward(e4);  // 512 ch, 1/16 size
    // Decoder with skip connections.
    torch::Tensor d = up1->forward(e5, e4);
    d = up2->forward(d, e3);
    d = up3->forward(d, e2);
    d = up4->forward(d, e1);
    return oconv->forward(d);
  }
};

// ---------------------------------------------------------------------------
// Text data loading
// ---------------------------------------------------------------------------

/// Splits `str` on any character in `delimiters` and converts each token with
/// std::atof (a non-numeric token becomes 0). Runs of delimiters are skipped.
std::vector<float> Tokenize(const std::string& str, const std::string& delimiters) {
  std::vector<float> tokens;
  std::string::size_type lastPos = str.find_first_not_of(delimiters, 0);
  std::string::size_type pos = str.find_first_of(delimiters, lastPos);
  while (std::string::npos != pos || std::string::npos != lastPos) {
    tokens.push_back(
        static_cast<float>(std::atof(str.substr(lastPos, pos - lastPos).c_str())));
    lastPos = str.find_first_not_of(delimiters, pos);
    pos = str.find_first_of(delimiters, lastPos);
  }
  return tokens;
}

/// Reads a text file in which each line is whitespace-separated numbers and
/// returns one float vector per line.
/// @throws std::runtime_error if the file cannot be opened. (assert() would be
///         compiled out under NDEBUG and leave callers with empty data.)
std::vector<std::vector<float>> readTxt(std::string file) {
  std::ifstream infile(file);
  if (!infile.is_open()) {
    throw std::runtime_error("readTxt: cannot open " + file);
  }
  std::vector<std::vector<float>> res;
  std::string line;
  while (std::getline(infile, line)) {
    // '\t' and '\r' are also delimiters, so tab-separated or CRLF files parse cleanly.
    res.push_back(Tokenize(line, " \t\r"));
  }
  std::cout << "gdood" << std::endl;
  return res;
}

/// Copies parsed rows into a row-major buffer of max_rows x max_cols floats
/// (row i starts at dst + i * max_cols). Short rows leave the rest of their
/// buffer row untouched.
/// @throws std::runtime_error if there are no rows, or the data does not fit.
static void copyRowsChecked(const std::vector<std::vector<float>>& rows, float* dst,
                            std::size_t max_rows, std::size_t max_cols,
                            const std::string& what) {
  if (rows.empty()) {
    throw std::runtime_error(what + ": no rows read");
  }
  if (rows.size() > max_rows) {
    throw std::runtime_error(what + ": more rows than the fixed buffer holds");
  }
  for (std::size_t i = 0; i < rows.size(); ++i) {
    if (rows[i].size() > max_cols) {
      throw std::runtime_error(what + ": row longer than the fixed buffer holds");
    }
    std::copy(rows[i].begin(), rows[i].end(), dst + i * max_cols);
  }
}

/// Loads the label map as a 2478 x 3125 float tensor.
torch::Tensor float2TensorLabel() {
  constexpr int kRows = 2478;
  constexpr int kCols = 3125;
  // tensorFromBlob wraps this buffer without copying, so the storage must
  // outlive the returned tensor. `static` also keeps ~31 MB off the stack.
  // NOTE: every call refills the same buffer, so earlier results alias it.
  static float tt[kRows][kCols] = {0};
  std::vector<std::vector<float>> vec =
      readTxt("/Users/yanlang/unet/mx-unet/U-Net/LabelData.txt");
  copyRowsChecked(vec, &tt[0][0], kRows, kCols, "LabelData.txt");
  return torch::CPU(torch::kFloat).tensorFromBlob(tt, {kRows, kCols});
}

/// Loads the 7-channel image as a 7 x 2478 x 3125 float tensor. Each text line
/// holds one channel, flattened row-major.
torch::Tensor float2TensorData() {
  constexpr int kChannels = 7;
  constexpr int kRows = 2478;
  constexpr int kCols = 3125;
  // Static for the same lifetime/stack-size reasons as float2TensorLabel (~217 MB).
  static float tt[kChannels][kRows * kCols] = {0};
  std::vector<std::vector<float>> vec =
      readTxt("/Users/yanlang/unet/mx-unet/U-Net/ImageData.txt");
  copyRowsChecked(vec, &tt[0][0], kChannels,
                  static_cast<std::size_t>(kRows) * kCols, "ImageData.txt");
  return torch::CPU(torch::kFloat).tensorFromBlob(tt, {kChannels, kRows, kCols});
}

// ---------------------------------------------------------------------------
// Patch sampling
// ---------------------------------------------------------------------------

int imgH = 256;
int imgW = 256;

/// Returns a copy of the imgH x imgW window of the first 7 channels of `data`
/// (C, H, W) whose top-left corner is (hight, width).
torch::Tensor RandData(torch::Tensor data, int hight, int width) {
  // narrow() gives views; clone() copies the window into its own storage. This
  // replaces ~450k per-element tensor index operations.
  return data.narrow(0, 0, 7).narrow(1, hight, imgH).narrow(2, width, imgW).clone();
}

/// Returns a copy of the imgH x imgW window of `label` (H, W) whose top-left
/// corner is (hight, width).
torch::Tensor RandMask(torch::Tensor label, int hight, int width) {
  return label.narrow(0, hight, imgH).narrow(1, width, imgW).clone();
}

/// Samples `batch_size` random aligned crops from data (C, H, W) and label (H, W).
/// @return {batch_data (B, 7, imgH, imgW), batch_label (B, imgH, imgW)}
/// @throws std::runtime_error if the image is smaller than the crop window.
std::vector<torch::Tensor> DataLoader(torch::Tensor data, torch::Tensor label,
                                      int batch_size) {
  const int64_t imghight = data.size(1);
  const int64_t imgwidth = data.size(2);
  if (imghight < imgH || imgwidth < imgW) {
    throw std::runtime_error("DataLoader: image smaller than the crop window");
  }
  torch::Tensor resdata = torch::zeros({batch_size, 7, imgH, imgW});
  torch::Tensor reslabel = torch::zeros({batch_size, imgH, imgW});
  for (int i = 0; i < batch_size; ++i) {
    // Valid top-left offsets are 0 .. size - crop (inclusive), so the range has
    // size - crop + 1 values. This is also never zero, so no modulo-by-zero.
    const int randhight = rand() % static_cast<int>(imghight - imgH + 1);
    const int randwidth = rand() % static_cast<int>(imgwidth - imgW + 1);
    resdata[i].copy_(RandData(data, randhight, randwidth));
    reslabel[i].copy_(RandMask(label, randhight, randwidth));
  }
  return {resdata, reslabel};
}

/// Extracts the fixed imgH x imgW evaluation window at (500, 500) from the
/// first 7 channels of `data`, with a leading batch dimension of 1.
torch::autograd::Variable Get_predData(torch::autograd::Variable data) {
  torch::autograd::Variable window =
      data.narrow(0, 0, 7).narrow(1, 500, imgH).narrow(2, 500, imgW).clone();
  return torch::unsqueeze(window, 0);
}

// ---------------------------------------------------------------------------
// Output
// ---------------------------------------------------------------------------

/// Writes each element of the 2-D tensor `data` to tresult.txt, one per line,
/// in row-major order.
void write2Txt(torch::autograd::Variable data) {
  std::ofstream fout("tresult.txt");
  for (int64_t i = 0; i < data.size(0); ++i) {
    for (int64_t j = 0; j < data.size(1); ++j) {
      fout << data[i][j] << '\n';  // '\n' rather than endl: no per-line flush
    }
  }
}

/// Writes (key, tensor) pairs to unet.txt, each key line followed by its tensor.
/// Only the first min(weights.size(), key.size()) pairs are written.
void saveModel(const std::vector<torch::Tensor>& weights,
               const std::vector<std::string>& key) {
  std::ofstream fout("unet.txt");
  const std::size_t n = std::min(weights.size(), key.size());
  for (std::size_t i = 0; i < n; ++i) {
    fout << key[i] << '\n';
    fout << weights[i] << '\n';
  }
}

// ---------------------------------------------------------------------------
// Training driver
// ---------------------------------------------------------------------------

/// Trains `model` for 20 SGD steps on random 256x256 crops, saves its
/// parameters to unet.txt, then writes the argmax prediction for the fixed
/// window at (500, 500) to tresult.txt.
/// Takes the model by reference so the trained weights belong to the caller's object.
void trainConvNet(unet& model) {
  torch::optim::SGD optimizer(model.parameters(), /*lr=*/0.01);
  std::cout << "load data ......" << std::endl;
  torch::autograd::Variable data = torch::autograd::make_variable(float2TensorData());
  torch::autograd::Variable label = torch::autograd::make_variable(float2TensorLabel());
  std::cout << "done!!" << std::endl;

  for (int epoch = 0; epoch < 20; ++epoch) {
    std::vector<torch::Tensor> vecdata = DataLoader(data, label, 2);
    std::cout << "vecdata after done!!" << std::endl;
    torch::Tensor train_data = vecdata[0];
    std::cout << "train_data after done" << std::endl;
    torch::Tensor train_label = vecdata[1];
    std::cout << train_label.size(0) << std::endl;
    std::cout << "train_label after done" << std::endl;

    torch::Tensor pred = model.forward(train_data);
    // nll_loss2d expects log-probabilities, but forward() returns raw scores.
    auto loss = torch::nll_loss2d(torch::log_softmax(pred, 1),
                                  torch::_cast_Long(train_label));
    std::cout << "the loss is" << loss << std::endl;
    optimizer.zero_grad();
    loss.backward();
    optimizer.step();
  }

  std::vector<torch::Tensor> vecValue;
  std::vector<std::string> vecKey;
  torch::nn::ParameterCursor cursor = model.parameters();
  for (auto it = cursor.begin(); it != cursor.end(); ++it) {
    vecValue.push_back((*it).value);
    vecKey.push_back((*it).key);
  }
  saveModel(vecValue, vecKey);

  // Inference: switch BatchNorm from batch statistics to its running statistics.
  model.eval();
  torch::autograd::Variable predData = Get_predData(data);
  torch::autograd::Variable fl = model.forward(predData);
  torch::autograd::Variable result = torch::squeeze(fl);
  torch::autograd::Variable rt = result.argmax(0);  // per-pixel class index
  std::cout << rt.size(0) << std::endl;
  std::cout << rt.size(1) << std::endl;
  write2Txt(rt);
}

int main() {
  unet net(7, 2);
  trainConvNet(net);
  return 0;
}