数据增强代码
前面在Unet中提到过,通过数据增强可以高效地利用网络。我在Unet的代码里找到了有关数据增强的代码,这里贴出来。
下面的为ImgToh5py.py,作用是将相对应的图片数据集转为hdf5格式文件。
# -*- coding: utf-8 -*-
"""Convert paired image / ground-truth folders into HDF5 dataset files."""
import os
import h5py
import numpy as np
from PIL import Image
import cv2


def write_hdf5(arr, outfile):
    """Write ``arr`` into ``outfile`` as an HDF5 dataset named "image"."""
    with h5py.File(outfile, "w") as f:
        f.create_dataset("image", data=arr, dtype=arr.dtype)


# -------------Path of the images----------------
original_imgs_train = "/home/chendali1/Gsj/DataEnhancement/DATA/train/"
groundTruth_imgs_train = "/home/chendali1/Gsj/DataEnhancement/DATA/label_train/"
original_imgs_test = "/home/chendali1/Gsj/DataEnhancement/DATA/test/"
groundTruth_imgs_test = "/home/chendali1/Gsj/DataEnhancement/DATA/label_test/"
Nimgs_train = 3  # number of training images
Nimgs_test = 2  # number of test images
channels = 1
height = 584
width = 565
# Folder where the generated .hdf5 files are stored.
dataset_path = "/home/chendali1/Gsj/DataEnhancement/DATA/datasets_training_testing/"


def get_datasets(imgs_dir, groundTruth_dir, train_test="null", Nimgs=None):
    """Load every image of ``imgs_dir`` together with its ground truth.

    The ground-truth file of ``<stem>.<ext>`` is looked up as ``<stem>.png``
    inside ``groundTruth_dir``.

    Args:
        imgs_dir: Folder holding the original images.
        groundTruth_dir: Folder holding the label images.
        train_test: Unused; kept so existing positional calls keep working.
        Nimgs: Expected number of images. ``None`` means "use however many
            files are found in ``imgs_dir``".

    Returns:
        ``(imgs, groundTruth)``: two float arrays of shape
        ``(Nimgs, height, width)`` (module-level ``height`` / ``width``).

    Raises:
        ValueError: If the number of files differs from ``Nimgs``.
        AssertionError: If the labels do not span exactly 0..255.
    """
    # Only the top-level files are used, in sorted order, so that the
    # index -> file mapping is deterministic (os.walk gives no ordering and
    # would also descend into sub-folders).
    names = sorted(
        name for name in os.listdir(imgs_dir)
        if os.path.isfile(os.path.join(imgs_dir, name))
    )
    if Nimgs is None:
        Nimgs = len(names)
    if len(names) != Nimgs:
        raise ValueError(
            "expected %d images in %s but found %d" % (Nimgs, imgs_dir, len(names))
        )

    # imgs = np.empty((Nimgs, height, width, channels)) would be the
    # multi-channel layout; this script stores single-channel images.
    imgs = np.empty((Nimgs, height, width))
    groundTruth = np.empty((Nimgs, height, width))
    for i, name in enumerate(names):
        print("original image: " + name)
        with Image.open(os.path.join(imgs_dir, name)) as img:
            print(img)
            imgs[i] = np.asarray(img)
        # splitext keeps working for stems longer than one character,
        # unlike slicing the first two characters of the file name.
        groundTruth_name = os.path.splitext(name)[0] + ".png"
        print("groundTruth name: " + groundTruth_name)
        with Image.open(os.path.join(groundTruth_dir, groundTruth_name)) as g_truth:
            print(g_truth)
            groundTruth[i] = np.asarray(g_truth)

    print("imgs max: " + str(np.max(imgs)))
    print("imgs min: " + str(np.min(imgs)))
    assert np.max(groundTruth) == 255.0
    assert np.min(groundTruth) == 0.0
    print("ground truth are correctly within pixel value range 0-255")

    # Debug dump of the first image and the first label.
    print('00000000000000000000000000')
    print(imgs[0, :, :])
    print('11111111111111111111111111')
    print(groundTruth[0, :, :])
    return imgs, groundTruth


# Create the output folder if it does not exist yet.
os.makedirs(dataset_path, exist_ok=True)
# --- training split: read the (image, label) pairs, then store each array ---
imgs_train, groundTruth_train = get_datasets(
    original_imgs_train,
    groundTruth_imgs_train,
    "train",
    Nimgs_train,
)
print("saving train datasets")
# original training images -> HDF5
write_hdf5(arr=imgs_train, outfile=dataset_path + "data_train.hdf5")
# training ground truth -> HDF5
write_hdf5(arr=groundTruth_train, outfile=dataset_path + "data_groundTruth_train.hdf5")

# --- test split: same procedure ---
imgs_test, groundTruth_test = get_datasets(
    original_imgs_test,
    groundTruth_imgs_test,
    "test",
    Nimgs_test,
)
print("saving test datasets")
# original test images -> HDF5
write_hdf5(arr=imgs_test, outfile=dataset_path + "data_test.hdf5")
# test ground truth -> HDF5
write_hdf5(arr=groundTruth_test, outfile=dataset_path + "data_groundTruth_test.hdf5")
下面为DR.py 数据增强的主代码
from help_functions import load_hdf5
from pre_processing import my_PreProc
import numpy as np
import random
import os
import cv2
from PIL import *


def data_consistency_check(imgs, masks):
    """Assert that ``imgs`` and ``masks`` agree on their first three axes.

    Both arrays are indexed here as (count, height, width).
    """
    assert imgs.shape[0] == masks.shape[0]
    assert imgs.shape[1] == masks.shape[1]
    assert imgs.shape[2] == masks.shape[2]


def extract_random(full_imgs, full_masks, patch_h, patch_w, N_patches, inside=True):
    """Cut random patches, and the matching mask patches, out of full images.

    Args:
        full_imgs: Array indexed as (count, height, width).
        full_masks: Array indexed the same way as ``full_imgs``.
        patch_h: Patch height in pixels.
        patch_w: Patch width in pixels.
        N_patches: Total number of patches; must be a multiple of the
            number of full images so every image contributes equally.
        inside: Unused in this variant; kept for signature compatibility.

    Returns:
        ``(patches, patches_masks)``, both of shape
        ``(N_patches, patch_h, patch_w)``.
    """
    n_imgs = full_imgs.shape[0]
    if N_patches % n_imgs != 0:
        print("N_patches: please enter a multiple of " + str(n_imgs))
        exit()

    patches = np.empty((N_patches, patch_h, patch_w))
    patches_masks = np.empty((N_patches, patch_h, patch_w))
    img_h = full_imgs.shape[1]
    img_w = full_imgs.shape[2]
    patch_per_img = N_patches // n_imgs  # patches taken from each image
    print("patches per full image: " + str(patch_per_img))

    half_h = patch_h // 2
    half_w = patch_w // 2
    iter_tot = 0
    for i in range(n_imgs):
        for _ in range(patch_per_img):
            # Pick the patch centre so that the whole patch fits in the
            # image. For even sizes this is the original range; for odd
            # sizes the extra pixel goes to the bottom/right side instead
            # of producing a patch that is one pixel too small.
            x_center = random.randint(half_w, img_w - (patch_w - half_w))
            y_center = random.randint(half_h, img_h - (patch_h - half_h))
            top = y_center - half_h
            left = x_center - half_w
            patches[iter_tot] = full_imgs[i, top:top + patch_h, left:left + patch_w]
            patches_masks[iter_tot] = full_masks[i, top:top + patch_h, left:left + patch_w]
            iter_tot += 1
    return patches, patches_masks


def get_data(data_train_imgs_orig,
             data_train_groundTruth,
             patch_height,
             patch_width,
             N_subimgs,
             inside_FOV):
    """Load an image/label HDF5 pair and return random patches of both.

    Args:
        data_train_imgs_orig: Path of the HDF5 file with the original images.
        data_train_groundTruth: Path of the HDF5 file with the labels.
        patch_height: Height of the patches to generate.
        patch_width: Width of the patches to generate.
        N_subimgs: Number of patches to generate in total.
        inside_FOV: Forwarded to ``extract_random`` as ``inside``. The
            original note says it is meant for the DRIVE database and is
            normally set to False.

    Returns:
        ``(patches_imgs_train, patches_masks_train)``.
    """
    train_imgs_orig = load_hdf5(data_train_imgs_orig)
    train_masks = load_hdf5(data_train_groundTruth)  # the ground truth
    print(train_imgs_orig.shape)
    train_imgs = my_PreProc(train_imgs_orig)  # image pre-processing

    # Masks are kept in their stored value range (not rescaled here).
    data_consistency_check(train_imgs, train_masks)
    print("\ntrain images/mask shape:")
    print(train_imgs.shape)

    # Main augmentation step: generate the patches.
    patches_imgs_train, patches_masks_train = extract_random(
        train_imgs, train_masks, patch_height, patch_width, N_subimgs, inside_FOV
    )
    print("\ntrain PATCHES images/masks shape:")
    print(patches_imgs_train.shape)
    print("train PATCHES images range(min-max): "
          + str(np.min(patches_imgs_train)) + '-' + str(np.max(patches_imgs_train)))
    return patches_imgs_train, patches_masks_train


# Locations of the HDF5 datasets built from the original images.
data_train_imgs_orig = '/home/chendali1/Gsj/DataEnhancement/DATA/datasets_training_testing/data_train.hdf5'
data_train_groundTruth = '/home/chendali1/Gsj/DataEnhancement/DATA/datasets_training_testing/data_groundTruth_train.hdf5'
data_test_imgs_orig = '/home/chendali1/Gsj/DataEnhancement/DATA/datasets_training_testing/data_test.hdf5'
data_test_groundTruth = '/home/chendali1/Gsj/DataEnhancement/DATA/datasets_training_testing/data_groundTruth_test.hdf5'
# Folders that receive the generated patches.
patches_path_train = '/home/chendali1/Gsj/DataEnhancement/DATA/patches/patches_train/'
patches_path_test = '/home/chendali1/Gsj/DataEnhancement/DATA/patches/patches_test/'
patches_path_label_train = '/home/chendali1/Gsj/DataEnhancement/DATA/patches/patches_label_train/'
patches_path_label_test = '/home/chendali1/Gsj/DataEnhancement/DATA/patches/patches_label_test/'

# Create every output folder; exist_ok avoids the check-then-create race.
for _out_dir in (patches_path_train,
                 patches_path_test,
                 patches_path_label_train,
                 patches_path_label_test):
    os.makedirs(_out_dir, exist_ok=True)


def _save_patches(images, masks, image_dir, mask_dir):
    """Write patch ``i`` as ``<i>.jpg`` and its mask as ``<i>.png``.

    Image patches are multiplied by 255 before being written; masks are
    written unchanged.
    """
    for i in range(images.shape[0]):
        cv2.imwrite(image_dir + str(i) + '.jpg', images[i, :, :] * 255)
        cv2.imwrite(mask_dir + str(i) + '.png', masks[i, :, :])


# ---- training set: 480 patches of 224x224 ----
patches, patches_masks = get_data(data_train_imgs_orig, data_train_groundTruth,
                                  224, 224, 480, False)
print(patches.shape)
print(patches_masks.shape)
patches_train = patches
patches_masks_train = patches_masks
print('111111111111111111111111')
print(patches_train.shape)
_save_patches(patches_train, patches_masks_train,
              patches_path_train, patches_path_label_train)

# ---- test set: 240 patches of 224x224 ----
patches, patches_masks = get_data(data_test_imgs_orig, data_test_groundTruth,
                                  224, 224, 240, False)
print(patches.shape)
print(patches_masks.shape)
patches_test = patches
patches_masks_test = patches_masks
print('111111111111111111111111')
print(patches_test.shape)
_save_patches(patches_test, patches_masks_test,
              patches_path_test, patches_path_label_test)
print("DataEnhancement Successfull!")
# To preview a patch interactively, show it with cv2.imshow(...) followed by
# cv2.waitKey(0) and cv2.destroyAllWindows().
下面为help_functions.py 主要包含一些辅助函数的定义
import h5py
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt


def load_hdf5(infile):
    """Return the full content of the "image" dataset stored in ``infile``."""
    # The context manager closes the file once the data has been read.
    with h5py.File(infile, "r") as hdf:
        dataset = hdf["image"]
        return dataset[()]


def write_hdf5(arr, outfile):
    """Create ``outfile`` and store ``arr`` in it as the "image" dataset."""
    with h5py.File(outfile, "w") as hdf:
        hdf.create_dataset("image", data=arr, dtype=arr.dtype)


# convert RGB image in black and white
def rgb2gray(rgb):
    """Collapse the 3 channels of (N, 3, H, W) data into one: (N, 1, H, W).

    The channels at index 0, 1 and 2 are weighted 0.299, 0.587 and 0.114.
    """
    assert len(rgb.shape) == 4  # 4D arrays
    assert rgb.shape[1] == 3
    first, second, third = rgb[:, 0, :, :], rgb[:, 1, :, :], rgb[:, 2, :, :]
    bn_imgs = first * 0.299 + second * 0.587 + third * 0.114
    n_imgs, _, img_h, img_w = rgb.shape
    return np.reshape(bn_imgs, (n_imgs, 1, img_h, img_w))


# group a set of images row per columns
def group_images(data, per_row):
    """Tile (N, C, H, W) images into one mosaic with ``per_row`` per row.

    ``N`` must be a multiple of ``per_row`` and ``C`` must be 1 or 3. The
    mosaic is channel-last, i.e. (rows * H, per_row * W, C).
    """
    assert data.shape[0] % per_row == 0
    assert data.shape[1] in (1, 3)
    channel_last = np.transpose(data, (0, 2, 3, 1))  # correct format for imshow
    n_rows = int(channel_last.shape[0] / per_row)

    rows = []
    for row_idx in range(n_rows):
        first = row_idx * per_row
        strip = channel_last[first]
        # Append the remaining images of this row to the right of the strip.
        for offset in range(1, per_row):
            strip = np.concatenate((strip, channel_last[first + offset]), axis=1)
        rows.append(strip)

    mosaic = rows[0]
    # Stack the remaining rows underneath the first one.
    for strip in rows[1:]:
        mosaic = np.concatenate((mosaic, strip), axis=0)
    return mosaic


# visualize image (as PIL image, NOT as matplotlib!)
def visualize(data, filename):
    """Save an (H, W, C) array as ``<filename>.png`` and return the PIL image.

    A single-channel array is squeezed to (H, W) first. Data whose maximum
    exceeds 1 is treated as already being in 0-255; otherwise it is scaled
    by 255 before the conversion to uint8.
    """
    assert len(data.shape) == 3  # height*width*channels
    if data.shape[2] == 1:  # in case it is black and white
        data = np.reshape(data, (data.shape[0], data.shape[1]))
    if np.max(data) > 1:
        img = Image.fromarray(data.astype(np.uint8))  # already 0-255
    else:
        img = Image.fromarray((data * 255).astype(np.uint8))  # between 0-1
    img.save(filename + '.png')
    return img


# prepare the mask in the right shape for the Unet
def masks_Unet(masks):
    """Turn (N, 1, H, W) masks into two-class one-hot arrays (N, H*W, 2).

    Pixels equal to 0 become ``[1, 0]``; every other pixel becomes
    ``[0, 1]``. The result is a float array.
    """
    assert len(masks.shape) == 4  # 4D arrays
    assert masks.shape[1] == 1  # check the channel is 1
    im_h = masks.shape[2]
    im_w = masks.shape[3]
    flat = np.reshape(masks, (masks.shape[0], im_h * im_w))
    # Vectorised replacement of the per-pixel Python loop.
    is_zero = flat == 0
    new_masks = np.empty((flat.shape[0], im_h * im_w, 2))
    new_masks[:, :, 0] = is_zero
    new_masks[:, :, 1] = ~is_zero
    return new_masks


def pred_to_imgs(pred, patch_height, patch_width, mode="original"):
    """Convert (Npatches, H*W, 2) predictions to (Npatches, 1, H, W) images.

    Args:
        pred: Array of shape (Npatches, patch_height * patch_width, 2).
        patch_height: Height of each output patch.
        patch_width: Width of each output patch.
        mode: ``"original"`` keeps the value stored at class index 1;
            ``"threshold"`` outputs 1 where that value is >= 0.5, else 0.

    Returns:
        Float array of shape (Npatches, 1, patch_height, patch_width).
        For an unknown ``mode`` a message is printed and ``exit()`` is called.
    """
    assert len(pred.shape) == 3  # 3D array: (Npatches,height*width,2)
    assert pred.shape[2] == 2  # check the classes are 2
    second_class = pred[:, :, 1]
    if mode == "original":
        pred_images = second_class.astype(np.float64)
    elif mode == "threshold":
        pred_images = (second_class >= 0.5).astype(np.float64)
    else:
        print("mode " + str(mode) + " not recognized, it can be 'original' or 'threshold'")
        exit()
    return np.reshape(pred_images, (pred.shape[0], 1, patch_height, patch_width))
pre_processing.py 图像预处理的文件
###################################################
#
#   Script to pre-process the original imgs
#
##################################################
import numpy as np
from PIL import Image
import cv2
from help_functions import *


# My pre processing (use for both training and testing!)
def my_PreProc(data):
    """Scale ``data`` to the 0-1 range by dividing it by 255.

    The other steps of the reference pipeline are switched off in this
    variant: the 4-D / 3-channel shape asserts, ``rgb2gray``,
    ``dataset_normalized``, ``clahe_equalized`` and
    ``adjust_gamma(train_imgs, 1.2)``.
    """
    # Use the original images as they are, then reduce to the 0-1 range.
    return data / 255.


# ============================================================
# ========= PRE PROCESSING FUNCTIONS ========================#
# ============================================================

# ==== histogram equalization
def histo_equalized(imgs):
    """Apply ``cv2.equalizeHist`` to channel 0 of each (N, 1, H, W) image.

    Each image is converted to uint8 before equalization; the result is
    returned in a new float array of the same shape.
    """
    assert len(imgs.shape) == 4  # 4D arrays
    assert imgs.shape[1] == 1  # check the channel is 1
    imgs_equalized = np.empty(imgs.shape)
    for idx, img in enumerate(imgs):
        as_bytes = np.array(img[0], dtype=np.uint8)
        imgs_equalized[idx, 0] = cv2.equalizeHist(as_bytes)
    return imgs_equalized


# CLAHE (Contrast Limited Adaptive Histogram Equalization)
# adaptive histogram equalization is used. In this, image is divided into
# small blocks called "tiles" (tileSize is 8x8 by default in OpenCV). Then
# each of these blocks are histogram equalized as usual. So in a small area,
# histogram would confine to a small region (unless there is noise). If noise
# is there, it will be amplified. To avoid this, contrast limiting is applied.
# If any histogram bin is above the specified contrast limit (by default 40 in
# OpenCV), those pixels are clipped and distributed uniformly to other bins
# before applying histogram equalization.
# After equalization, to remove artifacts in tile borders, bilinear
# interpolation is applied
def clahe_equalized(imgs):
    """Apply CLAHE to every channel of every (N, C, H, W) image.

    Each channel is converted to uint8 and equalized independently with
    ``clipLimit=2.0`` and an 8x8 tile grid.

    Returns:
        A new float array with the same shape as ``imgs``.
    """
    assert len(imgs.shape) == 4  # 4D arrays
    assert imgs.shape[1] in (1, 3)  # check the channel is 1 or 3
    # create a CLAHE object (Arguments are optional).
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    imgs_equalized = np.empty(imgs.shape)
    for i in range(imgs.shape[0]):
        # Every channel is filled in; writing channel 0 only would leave
        # the rest of the np.empty buffer uninitialised.
        for c in range(imgs.shape[1]):
            imgs_equalized[i, c] = clahe.apply(np.array(imgs[i, c], dtype=np.uint8))
    return imgs_equalized


# ===== normalize over the dataset
def dataset_normalized(imgs):
    """Standardise the dataset, then stretch each image to 0-255.

    The whole array is centred on its mean and divided by its standard
    deviation; afterwards every image is min-max rescaled on its own.
    A constant image (zero value range) is mapped to all zeros instead of
    dividing by zero.

    Returns:
        A new float array with the same shape as ``imgs``.
    """
    assert len(imgs.shape) == 4  # 4D arrays
    assert imgs.shape[1] in (1, 3)  # check the channel is 1 or 3
    imgs_std = np.std(imgs)
    imgs_mean = np.mean(imgs)
    imgs_normalized = imgs - imgs_mean
    if imgs_std > 0:  # a constant dataset has nothing to scale
        imgs_normalized = imgs_normalized / imgs_std
    for i in range(imgs.shape[0]):
        low = np.min(imgs_normalized[i])
        high = np.max(imgs_normalized[i])
        if high > low:
            imgs_normalized[i] = ((imgs_normalized[i] - low) / (high - low)) * 255
        else:
            imgs_normalized[i] = 0
    return imgs_normalized


def adjust_gamma(imgs, gamma=1.0):
    """Gamma-correct every channel of every (N, C, H, W) image.

    Pixels are converted to uint8 and mapped through a 256-entry lookup
    table computed as ``((v / 255) ** (1 / gamma)) * 255``.

    Returns:
        A new float array with the same shape as ``imgs``.
    """
    assert len(imgs.shape) == 4  # 4D arrays
    assert imgs.shape[1] in (1, 3)  # check the channel is 1 or 3
    # build a lookup table mapping the pixel values [0, 255] to
    # their adjusted gamma values
    invGamma = 1.0 / gamma
    table = np.array(
        [((i / 255.0) ** invGamma) * 255 for i in np.arange(0, 256)]
    ).astype("uint8")
    # apply gamma correction using the lookup table
    new_imgs = np.empty(imgs.shape)
    for i in range(imgs.shape[0]):
        for c in range(imgs.shape[1]):
            new_imgs[i, c] = cv2.LUT(np.array(imgs[i, c], dtype=np.uint8), table)
    return new_imgs
#extract_patches.py 主要参考的原代码
import numpy as np
import random
import configparser as ConfigParser
from help_functions import load_hdf5
from help_functions import visualize
from help_functions import group_images
from pre_processing import my_PreProc

# To select the same images
# random.seed(10)


# Load the original data and return the extracted patches for training/testing
def get_data_training(DRIVE_train_imgs_original,
                      DRIVE_train_groudTruth,
                      patch_height,
                      patch_width,
                      N_subimgs,
                      inside_FOV):
    """Load the training HDF5 files and return random image/mask patches.

    Images go through ``my_PreProc``; masks are divided by 255. Both are
    cropped to rows 9..573 before the patches are drawn.

    Returns:
        ``(patches_imgs_train, patches_masks_train)``.
    """
    train_imgs_original = load_hdf5(DRIVE_train_imgs_original)
    train_masks = load_hdf5(DRIVE_train_groudTruth)  # masks always the same

    train_imgs = my_PreProc(train_imgs_original)
    train_masks = train_masks / 255.

    train_imgs = train_imgs[:, :, 9:574, :]  # cut bottom and top so now it is 565*565
    train_masks = train_masks[:, :, 9:574, :]  # cut bottom and top so now it is 565*565
    data_consistency_check(train_imgs, train_masks)

    # check masks are within 0-1
    assert (np.min(train_masks) == 0 and np.max(train_masks) == 1)

    print("\ntrain images/masks shape:")
    print(train_imgs.shape)
    print("train images range (min-max): " + str(np.min(train_imgs)) + ' - ' + str(np.max(train_imgs)))
    print("train masks are within 0-1\n")

    # extract the TRAINING patches from the full images
    patches_imgs_train, patches_masks_train = extract_random(
        train_imgs, train_masks, patch_height, patch_width, N_subimgs, inside_FOV
    )
    data_consistency_check(patches_imgs_train, patches_masks_train)

    print("\ntrain PATCHES images/masks shape:")
    print(patches_imgs_train.shape)
    print("train PATCHES images range (min-max): "
          + str(np.min(patches_imgs_train)) + ' - ' + str(np.max(patches_imgs_train)))

    return patches_imgs_train, patches_masks_train


# Load the original data and return the extracted patches for training/testing
def get_data_testing(DRIVE_test_imgs_original,
                     DRIVE_test_groudTruth,
                     Imgs_to_test,
                     patch_height,
                     patch_width):
    """Load the test HDF5 files and return ordered, non-overlapping patches.

    Only the first ``Imgs_to_test`` images are used. Images and masks are
    zero-padded by ``paint_border`` so the patch grid covers them exactly.

    Returns:
        ``(patches_imgs_test, patches_masks_test)``.
    """
    test_imgs_original = load_hdf5(DRIVE_test_imgs_original)
    test_masks = load_hdf5(DRIVE_test_groudTruth)

    test_imgs = my_PreProc(test_imgs_original)
    test_masks = test_masks / 255.

    # extend both images and masks so they can be divided exactly by the
    # patches dimensions
    test_imgs = test_imgs[0:Imgs_to_test, :, :, :]
    test_masks = test_masks[0:Imgs_to_test, :, :, :]
    test_imgs = paint_border(test_imgs, patch_height, patch_width)
    test_masks = paint_border(test_masks, patch_height, patch_width)

    data_consistency_check(test_imgs, test_masks)

    # check masks are within 0-1
    assert (np.max(test_masks) == 1 and np.min(test_masks) == 0)

    print("\ntest images/masks shape:")
    print(test_imgs.shape)
    print("test images range (min-max): " + str(np.min(test_imgs)) + ' - ' + str(np.max(test_imgs)))
    print("test masks are within 0-1\n")

    # extract the TEST patches from the full images
    patches_imgs_test = extract_ordered(test_imgs, patch_height, patch_width)
    patches_masks_test = extract_ordered(test_masks, patch_height, patch_width)
    data_consistency_check(patches_imgs_test, patches_masks_test)

    print("\ntest PATCHES images/masks shape:")
    print(patches_imgs_test.shape)
    print("test PATCHES images range (min-max): "
          + str(np.min(patches_imgs_test)) + ' - ' + str(np.max(patches_imgs_test)))

    return patches_imgs_test, patches_masks_test


# Load the original data and return the extracted patches for testing
# return the ground truth in its original shape
def get_data_testing_overlap(DRIVE_test_imgs_original,
                             DRIVE_test_groudTruth,
                             Imgs_to_test,
                             patch_height,
                             patch_width,
                             stride_height,
                             stride_width):
    """Load the test HDF5 files and return overlapping image patches.

    Only the images are padded and cut into patches; the masks are
    returned as full images (after the division by 255 and the
    ``Imgs_to_test`` selection).

    Returns:
        ``(patches_imgs_test, padded_height, padded_width, test_masks)``.
    """
    test_imgs_original = load_hdf5(DRIVE_test_imgs_original)
    test_masks = load_hdf5(DRIVE_test_groudTruth)

    test_imgs = my_PreProc(test_imgs_original)
    test_masks = test_masks / 255.

    # extend the images so they can be divided exactly by the patches
    # dimensions
    test_imgs = test_imgs[0:Imgs_to_test, :, :, :]
    test_masks = test_masks[0:Imgs_to_test, :, :, :]
    test_imgs = paint_border_overlap(test_imgs, patch_height, patch_width,
                                     stride_height, stride_width)

    # check masks are within 0-1
    assert (np.max(test_masks) == 1 and np.min(test_masks) == 0)

    print("\ntest images shape:")
    print(test_imgs.shape)
    print("\ntest mask shape:")
    print(test_masks.shape)
    print("test images range (min-max): " + str(np.min(test_imgs)) + ' - ' + str(np.max(test_imgs)))
    print("test masks are within 0-1\n")

    # extract the TEST patches from the full images
    patches_imgs_test = extract_ordered_overlap(test_imgs, patch_height, patch_width,
                                                stride_height, stride_width)

    print("\ntest PATCHES images shape:")
    print(patches_imgs_test.shape)
    print("test PATCHES images range (min-max): "
          + str(np.min(patches_imgs_test)) + ' - ' + str(np.max(patches_imgs_test)))

    return patches_imgs_test, test_imgs.shape[2], test_imgs.shape[3], test_masks


# data consistency check
def data_consistency_check(imgs, masks):
    """Assert that (N, C, H, W) images and (N, 1, H, W) masks match."""
    assert (len(imgs.shape) == len(masks.shape))
    assert (imgs.shape[0] == masks.shape[0])
    assert (imgs.shape[2] == masks.shape[2])
    assert (imgs.shape[3] == masks.shape[3])
    assert (masks.shape[1] == 1)
    assert (imgs.shape[1] == 1 or imgs.shape[1] == 3)


# extract patches randomly in the full training images
#  -- Inside OR in full image
def extract_random(full_imgs, full_masks, patch_h, patch_w, N_patches, inside=True):
    """Cut ``N_patches`` random patches out of (N, C, H, W) images and masks.

    Args:
        full_imgs: 4-D array with 1 or 3 channels.
        full_masks: 4-D array with 1 channel and the same height/width.
        patch_h: Patch height.
        patch_w: Patch width.
        N_patches: Total number of patches; must be a multiple of the
            number of full images.
        inside: When True, a candidate is redrawn until
            ``is_patch_inside_FOV`` accepts its centre.

    Returns:
        ``(patches, patches_masks)`` of shapes ``(N_patches, C, patch_h,
        patch_w)`` and ``(N_patches, 1, patch_h, patch_w)``.
    """
    n_imgs = full_imgs.shape[0]
    if (N_patches % n_imgs != 0):
        print("N_patches: please enter a multiple of " + str(n_imgs))
        exit()
    assert (len(full_imgs.shape) == 4 and len(full_masks.shape) == 4)  # 4D arrays
    assert (full_imgs.shape[1] == 1 or full_imgs.shape[1] == 3)  # check the channel is 1 or 3
    assert (full_masks.shape[1] == 1)  # masks only black and white
    assert (full_imgs.shape[2] == full_masks.shape[2] and full_imgs.shape[3] == full_masks.shape[3])
    patches = np.empty((N_patches, full_imgs.shape[1], patch_h, patch_w))
    patches_masks = np.empty((N_patches, full_masks.shape[1], patch_h, patch_w))
    img_h = full_imgs.shape[2]  # height of the full image
    img_w = full_imgs.shape[3]  # width of the full image
    patch_per_img = N_patches // n_imgs  # N_patches equally divided in the full images
    print("patches per full image: " + str(patch_per_img))

    half_h = patch_h // 2
    half_w = patch_w // 2
    iter_tot = 0  # iter over the total number of patches (N_patches)
    for i in range(n_imgs):  # loop over the full images
        k = 0
        while k < patch_per_img:
            # For even sizes these are the original ranges; for odd sizes
            # the extra pixel goes to the bottom/right so the slice still
            # has exactly patch_h x patch_w pixels.
            x_center = random.randint(half_w, img_w - (patch_w - half_w))
            y_center = random.randint(half_h, img_h - (patch_h - half_h))
            # check whether the patch is fully contained in the FOV
            if inside == True:
                if is_patch_inside_FOV(x_center, y_center, img_w, img_h, patch_h) == False:
                    continue
            top = y_center - half_h
            left = x_center - half_w
            patches[iter_tot] = full_imgs[i, :, top:top + patch_h, left:left + patch_w]
            patches_masks[iter_tot] = full_masks[i, :, top:top + patch_h, left:left + patch_w]
            iter_tot += 1  # total
            k += 1  # per full_img
    return patches, patches_masks


# check if the patch is fully contained in the FOV
def is_patch_inside_FOV(x, y, img_w, img_h, patch_h):
    """Return True when the patch centred on (x, y) lies inside the FOV.

    The FOV is taken as a circle of radius 270 around the image centre
    (radius from the DRIVE db docs, per the original comment), shrunk by
    half the diagonal of a square patch of side ``patch_h``.
    """
    x_ = x - int(img_w / 2)  # origin (0,0) shifted to image center
    y_ = y - int(img_h / 2)  # origin (0,0) shifted to image center
    # this is the limit to contain the full patch in the FOV
    R_inside = 270 - int(patch_h * np.sqrt(2.0) / 2.0)
    radius = np.sqrt((x_ * x_) + (y_ * y_))
    return bool(radius < R_inside)


# Divide all the full_imgs in patches
def extract_ordered(full_imgs, patch_h, patch_w):
    """Cut (N, C, H, W) images into a grid of non-overlapping patches.

    Pixels that do not fill a whole patch are dropped (a warning is
    printed for each affected dimension).

    Returns:
        Array of shape ``(N * rows * cols, C, patch_h, patch_w)``.
    """
    assert (len(full_imgs.shape) == 4)  # 4D arrays
    assert (full_imgs.shape[1] == 1 or full_imgs.shape[1] == 3)  # check the channel is 1 or 3
    img_h = full_imgs.shape[2]  # height of the full image
    img_w = full_imgs.shape[3]  # width of the full image
    N_patches_h = img_h // patch_h  # round to lowest int
    if (img_h % patch_h != 0):
        print("warning: " + str(N_patches_h) + " patches in height, with about "
              + str(img_h % patch_h) + " pixels left over")
    N_patches_w = img_w // patch_w  # round to lowest int
    # The width warning must test the width remainder, not the height one.
    if (img_w % patch_w != 0):
        print("warning: " + str(N_patches_w) + " patches in width, with about "
              + str(img_w % patch_w) + " pixels left over")
    print("number of patches per image: " + str(N_patches_h * N_patches_w))
    N_patches_tot = (N_patches_h * N_patches_w) * full_imgs.shape[0]
    patches = np.empty((N_patches_tot, full_imgs.shape[1], patch_h, patch_w))

    iter_tot = 0  # iter over the total number of patches (N_patches)
    for i in range(full_imgs.shape[0]):  # loop over the full images
        for h in range(N_patches_h):
            for w in range(N_patches_w):
                patch = full_imgs[i, :,
                                  h * patch_h:(h * patch_h) + patch_h,
                                  w * patch_w:(w * patch_w) + patch_w]
                patches[iter_tot] = patch
                iter_tot += 1  # total
    assert (iter_tot == N_patches_tot)
    return patches  # array with all the full_imgs divided in patches


def paint_border_overlap(full_imgs, patch_h, patch_w, stride_h, stride_w):
    """Zero-pad (N, C, H, W) images so strided patches tile them exactly.

    Each dimension is extended at its end until ``(size - patch)`` is a
    multiple of the stride. The padded array is returned.
    """
    assert (len(full_imgs.shape) == 4)  # 4D arrays
    assert (full_imgs.shape[1] == 1 or full_imgs.shape[1] == 3)  # check the channel is 1 or 3
    img_h = full_imgs.shape[2]  # height of the full image
    img_w = full_imgs.shape[3]  # width of the full image
    leftover_h = (img_h - patch_h) % stride_h  # leftover on the h dim
    leftover_w = (img_w - patch_w) % stride_w  # leftover on the w dim
    if (leftover_h != 0):  # change dimension of img_h
        print("\nthe side H is not compatible with the selected stride of " + str(stride_h))
        print("img_h " + str(img_h) + ", patch_h " + str(patch_h) + ", stride_h " + str(stride_h))
        print("(img_h - patch_h) MOD stride_h: " + str(leftover_h))
        print("So the H dim will be padded with additional " + str(stride_h - leftover_h) + " pixels")
        tmp_full_imgs = np.zeros((full_imgs.shape[0], full_imgs.shape[1],
                                  img_h + (stride_h - leftover_h), img_w))
        tmp_full_imgs[0:full_imgs.shape[0], 0:full_imgs.shape[1], 0:img_h, 0:img_w] = full_imgs
        full_imgs = tmp_full_imgs
    if (leftover_w != 0):  # change dimension of img_w
        print("the side W is not compatible with the selected stride of " + str(stride_w))
        print("img_w " + str(img_w) + ", patch_w " + str(patch_w) + ", stride_w " + str(stride_w))
        print("(img_w - patch_w) MOD stride_w: " + str(leftover_w))
        print("So the W dim will be padded with additional " + str(stride_w - leftover_w) + " pixels")
        tmp_full_imgs = np.zeros((full_imgs.shape[0], full_imgs.shape[1],
                                  full_imgs.shape[2], img_w + (stride_w - leftover_w)))
        tmp_full_imgs[0:full_imgs.shape[0], 0:full_imgs.shape[1], 0:full_imgs.shape[2], 0:img_w] = full_imgs
        full_imgs = tmp_full_imgs
    print("new full images shape: \n" + str(full_imgs.shape))
    return full_imgs


# Divide all the full_imgs in patches
def extract_ordered_overlap(full_imgs, patch_h, patch_w, stride_h, stride_w):
    """Cut (N, C, H, W) images into overlapping patches on a strided grid.

    ``(H - patch_h)`` and ``(W - patch_w)`` must be multiples of the
    respective stride (see ``paint_border_overlap``).

    Returns:
        Array of shape ``(N * patches_per_image, C, patch_h, patch_w)``.
    """
    assert (len(full_imgs.shape) == 4)  # 4D arrays
    assert (full_imgs.shape[1] == 1 or full_imgs.shape[1] == 3)  # check the channel is 1 or 3
    img_h = full_imgs.shape[2]  # height of the full image
    img_w = full_imgs.shape[3]  # width of the full image
    assert ((img_h - patch_h) % stride_h == 0 and (img_w - patch_w) % stride_w == 0)
    N_patches_h = (img_h - patch_h) // stride_h + 1
    N_patches_w = (img_w - patch_w) // stride_w + 1
    N_patches_img = N_patches_h * N_patches_w
    N_patches_tot = N_patches_img * full_imgs.shape[0]
    print("Number of patches on h : " + str(N_patches_h))
    print("Number of patches on w : " + str(N_patches_w))
    print("number of patches per image: " + str(N_patches_img)
          + ", totally for this dataset: " + str(N_patches_tot))
    patches = np.empty((N_patches_tot, full_imgs.shape[1], patch_h, patch_w))
    iter_tot = 0  # iter over the total number of patches (N_patches)
    for i in range(full_imgs.shape[0]):  # loop over the full images
        for h in range(N_patches_h):
            for w in range(N_patches_w):
                patch = full_imgs[i, :,
                                  h * stride_h:(h * stride_h) + patch_h,
                                  w * stride_w:(w * stride_w) + patch_w]
                patches[iter_tot] = patch
                iter_tot += 1  # total
    assert (iter_tot == N_patches_tot)
    return patches  # array with all the full_imgs divided in patches


def recompone_overlap(preds, img_h, img_w, stride_h, stride_w):
    """Rebuild full images from overlapping patch predictions by averaging.

    Every pixel of the result is the mean of all patch values covering it.
    The result is asserted to lie within 0..1.

    Returns:
        Array of shape ``(N_full_imgs, C, img_h, img_w)``.
    """
    assert (len(preds.shape) == 4)  # 4D arrays
    assert (preds.shape[1] == 1 or preds.shape[1] == 3)  # check the channel is 1 or 3
    patch_h = preds.shape[2]
    patch_w = preds.shape[3]
    N_patches_h = (img_h - patch_h) // stride_h + 1
    N_patches_w = (img_w - patch_w) // stride_w + 1
    N_patches_img = N_patches_h * N_patches_w
    print("N_patches_h: " + str(N_patches_h))
    print("N_patches_w: " + str(N_patches_w))
    print("N_patches_img: " + str(N_patches_img))
    assert (preds.shape[0] % N_patches_img == 0)
    N_full_imgs = preds.shape[0] // N_patches_img
    print("According to the dimension inserted, there are " + str(N_full_imgs)
          + " full images (of " + str(img_h) + "x" + str(img_w) + " each)")
    # accumulators: sum of the probabilities and number of contributions
    full_prob = np.zeros((N_full_imgs, preds.shape[1], img_h, img_w))
    full_sum = np.zeros((N_full_imgs, preds.shape[1], img_h, img_w))

    k = 0  # iterator over all the patches
    for i in range(N_full_imgs):
        for h in range(N_patches_h):
            for w in range(N_patches_w):
                rows = slice(h * stride_h, (h * stride_h) + patch_h)
                cols = slice(w * stride_w, (w * stride_w) + patch_w)
                full_prob[i, :, rows, cols] += preds[k]
                full_sum[i, :, rows, cols] += 1
                k += 1
    assert (k == preds.shape[0])
    assert (np.min(full_sum) >= 1.0)  # at least one
    final_avg = full_prob / full_sum
    print(final_avg.shape)
    assert (np.max(final_avg) <= 1.0)  # max value for a pixel is 1.0
    assert (np.min(final_avg) >= 0.0)  # min value for a pixel is 0.0
    return final_avg


# Recompone the full images with the patches
def recompone(data, N_h, N_w):
    """Rebuild full images from ordered, non-overlapping patches.

    Args:
        data: Patches of shape (N, C, patch_h, patch_w), ordered image by
            image and row by row.
        N_h: Number of patches per image along the height.
        N_w: Number of patches per image along the width.

    Returns:
        Array of shape ``(N // (N_h * N_w), C, N_h * patch_h, N_w * patch_w)``.
    """
    assert (data.shape[1] == 1 or data.shape[1] == 3)  # check the channel is 1 or 3
    assert (len(data.shape) == 4)
    N_pacth_per_img = N_w * N_h
    assert (data.shape[0] % N_pacth_per_img == 0)
    # Floor division: "/" yields a float in Python 3, which np.empty
    # rejects as an array dimension.
    N_full_imgs = data.shape[0] // N_pacth_per_img
    patch_h = data.shape[2]
    patch_w = data.shape[3]
    # define and start full recompone
    full_recomp = np.empty((N_full_imgs, data.shape[1], N_h * patch_h, N_w * patch_w))
    k = 0  # iter full img
    s = 0  # iter single patch
    while (s < data.shape[0]):
        # recompone one:
        single_recon = np.empty((data.shape[1], N_h * patch_h, N_w * patch_w))
        for h in range(N_h):
            for w in range(N_w):
                single_recon[:,
                             h * patch_h:(h * patch_h) + patch_h,
                             w * patch_w:(w * patch_w) + patch_w] = data[s]
                s += 1
        full_recomp[k] = single_recon
        k += 1
    assert (k == N_full_imgs)
    return full_recomp


# Extend the full images because patch division is not exact
def paint_border(data, patch_h, patch_w):
    """Zero-pad (N, C, H, W) data up to the next multiple of the patch size.

    The original content stays in the top-left corner of the result.
    """
    assert (len(data.shape) == 4)  # 4D arrays
    assert (data.shape[1] == 1 or data.shape[1] == 3)  # check the channel is 1 or 3
    img_h = data.shape[2]
    img_w = data.shape[3]
    # Floor division keeps the sizes integral ("/" gives floats in
    # Python 3, which np.zeros rejects as dimensions).
    if (img_h % patch_h) == 0:
        new_img_h = img_h
    else:
        new_img_h = ((int(img_h) // int(patch_h)) + 1) * patch_h
    if (img_w % patch_w) == 0:
        new_img_w = img_w
    else:
        new_img_w = ((int(img_w) // int(patch_w)) + 1) * patch_w
    new_data = np.zeros((data.shape[0], data.shape[1], new_img_h, new_img_w))
    new_data[:, :, 0:img_h, 0:img_w] = data[:, :, :, :]
    return new_data


# return only the pixels contained in the FOV, for both images and masks
def pred_only_FOV(data_imgs, data_masks, original_imgs_border_masks):
    """Collect the pixels for which ``inside_FOV_DRIVE`` is True.

    Pixels are visited image by image, column by column, so both returned
    arrays list them in that order.

    Returns:
        ``(new_pred_imgs, new_pred_masks)`` as arrays built from the
        selected per-pixel channel vectors.
    """
    assert (len(data_imgs.shape) == 4 and len(data_masks.shape) == 4)  # 4D arrays
    assert (data_imgs.shape[0] == data_masks.shape[0])
    assert (data_imgs.shape[2] == data_masks.shape[2])
    assert (data_imgs.shape[3] == data_masks.shape[3])
    assert (data_imgs.shape[1] == 1 and data_masks.shape[1] == 1)  # check the channel is 1
    height = data_imgs.shape[2]
    width = data_imgs.shape[3]
    new_pred_imgs = []
    new_pred_masks = []
    for i in range(data_imgs.shape[0]):  # loop over the full images
        for x in range(width):
            for y in range(height):
                if inside_FOV_DRIVE(i, x, y, original_imgs_border_masks) == True:
                    new_pred_imgs.append(data_imgs[i, :, y, x])
                    new_pred_masks.append(data_masks[i, :, y, x])
    new_pred_imgs = np.asarray(new_pred_imgs)
    new_pred_masks = np.asarray(new_pred_masks)
    return new_pred_imgs, new_pred_masks


# function to set to black everything outside the FOV, in a full image
def kill_border(data, original_imgs_border_masks):
    """Set to 0.0, in place, every pixel of ``data`` outside the FOV.

    ``data`` is modified directly; nothing is returned.
    """
    assert (len(data.shape) == 4)  # 4D arrays
    assert (data.shape[1] == 1 or data.shape[1] == 3)  # check the channel is 1 or 3
    height = data.shape[2]
    width = data.shape[3]
    for i in range(data.shape[0]):  # loop over the full images
        for x in range(width):
            for y in range(height):
                if inside_FOV_DRIVE(i, x, y, original_imgs_border_masks) == False:
                    data[i, :, y, x] = 0.0


def inside_FOV_DRIVE(i, x, y, DRIVE_masks):
    """Return True when pixel (x, y) of mask ``i`` has a value above 0.

    Coordinates beyond the mask size count as outside.
    """
    assert (len(DRIVE_masks.shape) == 4)  # 4D arrays
    assert (DRIVE_masks.shape[1] == 1)  # DRIVE masks is black and white
    # Do not rescale the masks (e.g. /255.) here: per the original note,
    # working on floats makes the per-pixel loops far slower.
    if (x >= DRIVE_masks.shape[3] or y >= DRIVE_masks.shape[2]):  # my image bigger than the original
        return False
    if (DRIVE_masks[i, 0, y, x] > 0):  # 0==black pixels
        return True
    else:
        return False