数据增强代码

前面在介绍 U-Net 时提到过，通过数据增强可以更高效地利用有限的训练数据来训练网络。在 U-Net 的代码里找到了有关数据增强的部分，这里贴出来。

下面的 ImgToh5py.py 用于将相对应的图片数据集转为 HDF5 格式文件。

# -*- coding: utf-8 -*-
import os
import h5py
import numpy as np
from PIL import Image
import cv2

def write_hdf5(arr, outfile):
    """Save ``arr`` to a new HDF5 file at ``outfile`` as the dataset "image".

    The file is opened in "w" mode, so any existing file at that path is
    truncated. The dataset is stored with the array's own dtype.
    """
    element_type = arr.dtype
    with h5py.File(outfile, "w") as h5_file:
        h5_file.create_dataset("image", data=arr, dtype=element_type)



#------------- Input locations (original images + ground-truth labels) -------------
_DATA_ROOT = "/home/chendali1/Gsj/DataEnhancement/DATA/"

original_imgs_train = _DATA_ROOT + "train/"
groundTruth_imgs_train = _DATA_ROOT + "label_train/"

original_imgs_test = _DATA_ROOT + "test/"
groundTruth_imgs_test = _DATA_ROOT + "label_test/"

# Image counts and the fixed image geometry that get_datasets allocates for.
Nimgs_train = 3  # number of training images
Nimgs_test = 2   # number of test images
channels = 1
height, width = 584, 565

# Folder where the generated .hdf5 files are written.
dataset_path = _DATA_ROOT + "datasets_training_testing/"
def get_datasets(imgs_dir, groundTruth_dir, train_test="null", Nimgs=None):
    """Load every image in ``imgs_dir`` and its label image into two arrays.

    Args:
        imgs_dir: directory with the original images. Only regular files
            directly inside it are read (subdirectories are ignored), in
            sorted file-name order so image i and label i pair up
            deterministically.
        groundTruth_dir: directory with the label images.
        train_test: informational tag ("train"/"test"); not used here.
        Nimgs: number of images expected in ``imgs_dir``. ``None`` means
            "use however many files are found". (It needed a default: a
            non-default parameter after ``train_test="null"`` is a
            SyntaxError.)

    Returns:
        ``(imgs, groundTruth)``: float64 arrays of shape
        ``(Nimgs, height, width)``, using the module-level ``height``/``width``.

    Raises:
        ValueError: if the number of files found differs from ``Nimgs``.
        AssertionError: if the label values do not have min 0 and max 255.
    """
    files = sorted(
        name for name in os.listdir(imgs_dir)
        if os.path.isfile(os.path.join(imgs_dir, name))
    )
    if Nimgs is None:
        Nimgs = len(files)
    elif len(files) != Nimgs:
        # Too few files would leave np.empty rows uninitialised; too many
        # would raise IndexError halfway through the loop.
        raise ValueError(
            "expected {} images in {!r}, found {}".format(Nimgs, imgs_dir, len(files))
        )

    imgs = np.empty((Nimgs, height, width))
    groundTruth = np.empty((Nimgs, height, width))
    for i, name in enumerate(files):
        print("original image: " + name)
        # `with` closes the file handle once the pixels are copied out.
        with Image.open(os.path.join(imgs_dir, name)) as img:
            print(img)
            imgs[i] = np.asarray(img)

        # Label name = first two characters of the image name + 'png'.
        # E.g. "1.jpg" -> "1.png" (the dot comes from the image name).
        # NOTE(review): two-digit names such as "01.jpg" would map to
        # "01png"; confirm the dataset's naming scheme.
        groundTruth_name = name[0:2] + 'png'
        print("groundTruth name: " + groundTruth_name)
        with Image.open(os.path.join(groundTruth_dir, groundTruth_name)) as g_truth:
            print(g_truth)
            groundTruth[i] = np.asarray(g_truth)
    print("imgs max: " + str(np.max(imgs)))
    print("imgs min: " + str(np.min(imgs)))
    # Labels are stored as 0..255 images and must span exactly that range.
    assert np.max(groundTruth) == 255.0
    assert np.min(groundTruth) == 0.0
    print("ground truth are correctly within pixel value range 0-255")
    print('00000000000000000000000000')
    print(imgs[0, :, :])
    print('11111111111111111111111111')
    print(groundTruth[0, :, :])
    return imgs, groundTruth


# Create the output folder for the .hdf5 files on first run.
if not os.path.exists(dataset_path):
    os.makedirs(dataset_path)

# One entry per split:
# (tag, image dir, label dir, image count, progress message,
#  output file for the images, output file for the labels)
_splits = (
    ("train", original_imgs_train, groundTruth_imgs_train, Nimgs_train,
     "saving train datasets", "data_train.hdf5", "data_groundTruth_train.hdf5"),
    ("test", original_imgs_test, groundTruth_imgs_test, Nimgs_test,
     "saving test datasets", "data_test.hdf5", "data_groundTruth_test.hdf5"),
)
for _tag, _img_dir, _label_dir, _count, _message, _img_file, _label_file in _splits:
    # Read the (original image, label) pairs, then store each array in its own file.
    _imgs, _labels = get_datasets(_img_dir, _label_dir, _tag, _count)
    print(_message)
    write_hdf5(_imgs, dataset_path + _img_file)
    write_hdf5(_labels, dataset_path + _label_file)

下面为 DR.py，是数据增强的主代码。

from help_functions import load_hdf5
from pre_processing import my_PreProc
import numpy as np
import random
import os
import cv2
from PIL import *

def data_consistency_check(imgs,masks):
    """Assert that images and masks agree on their first three dimensions.

    Axis 0 (sample count), then axis 2, then axis 1 are compared in that
    order. An AssertionError is raised at the first mismatch. Any further
    trailing dimensions are not compared.
    """
    for axis in (0, 2, 1):
        assert imgs.shape[axis] == masks.shape[axis]
def extract_random(full_imgs,full_masks,patch_h,patch_w,N_patches,inside=True):
    """Cut ``N_patches`` random, aligned (image, mask) patches from full images.

    Each full image contributes ``N_patches // full_imgs.shape[0]`` patches.
    Image patch k and mask patch k are taken from the same window.

    Args:
        full_imgs: array of shape (N, H, W) or (N, H, W, ...).
        full_masks: array whose first three dimensions match ``full_imgs``.
        patch_h, patch_w: patch height and width in pixels (odd sizes work).
        N_patches: total number of patches; must be a multiple of N.
        inside: accepted for API compatibility; this version does not
            restrict patches to a field of view.

    Returns:
        ``(patches, patches_masks)``: float64 arrays of shape
        (N_patches, patch_h, patch_w, <trailing dims of the input>).

    Raises:
        ValueError: no input images, ``N_patches`` not a multiple of the
            image count, or a patch larger than the image.
    """
    n_imgs = full_imgs.shape[0]
    if n_imgs == 0:
        raise ValueError("extract_random: full_imgs contains no images")
    if N_patches % n_imgs != 0:
        raise ValueError("N_patches: please enter a multiple of the number of images ("
                         + str(n_imgs) + "), got " + str(N_patches))
    img_h = full_imgs.shape[1]
    img_w = full_imgs.shape[2]
    if patch_h > img_h or patch_w > img_w:
        raise ValueError("patch " + str((patch_h, patch_w)) + " does not fit in image "
                         + str((img_h, img_w)))

    # Trailing dims (e.g. a channel axis) are carried over unchanged.
    patches = np.empty((N_patches, patch_h, patch_w) + full_imgs.shape[3:])
    patches_masks = np.empty((N_patches, patch_h, patch_w) + full_masks.shape[3:])
    patch_per_img = N_patches // n_imgs  # patches extracted from each full image
    print("patches per full image: " + str(patch_per_img))
    iter_tot = 0
    for i in range(n_imgs):
        for _ in range(patch_per_img):
            # Draw the top-left corner so the full window always fits. The
            # previous centre +/- patch//2 slice was one pixel short for odd
            # sizes. For even sizes this is the same draw (same range width,
            # x first then y), i.e. x0 == old_centre - patch_w // 2.
            x0 = random.randint(0, img_w - patch_w)
            y0 = random.randint(0, img_h - patch_h)
            patches[iter_tot] = full_imgs[i, y0:y0 + patch_h, x0:x0 + patch_w]
            patches_masks[iter_tot] = full_masks[i, y0:y0 + patch_h, x0:x0 + patch_w]
            iter_tot += 1
    return patches, patches_masks
def get_data(data_train_imgs_orig,
             data_train_groundTruth,
             patch_height,
             patch_width,
             N_subimgs,
             inside_FOV):
    """Load an image/label HDF5 pair, preprocess the images and cut random patches.

    Args:
        data_train_imgs_orig: path of the HDF5 file with the original images.
        data_train_groundTruth: path of the HDF5 file with the matching labels.
        patch_height, patch_width: height and width of each generated patch.
        N_subimgs: total number of patches to generate.
        inside_FOV: forwarded to ``extract_random`` as its ``inside`` argument
            (the field-of-view restriction used for DRIVE; normally False).

    Returns:
        ``(patches_imgs_train, patches_masks_train)`` as returned by
        ``extract_random``.
    """
    train_imgs_orig = load_hdf5(data_train_imgs_orig)
    # Labels are used as loaded (no /255 rescaling here).
    train_masks = load_hdf5(data_train_groundTruth)
    print(train_imgs_orig.shape)
    train_imgs = my_PreProc(train_imgs_orig)  # image preprocessing before patching
    data_consistency_check(train_imgs, train_masks)  # images and labels must line up

    print("\ntrain images/mask shape:")
    print(train_imgs.shape)

    # The core of the augmentation: random patches from every full image.
    patches_imgs_train, patches_masks_train = extract_random(
        train_imgs, train_masks, patch_height, patch_width, N_subimgs, inside_FOV)
    print("\ntrain PATCHES images/masks shape:")
    print(patches_imgs_train.shape)
    print("train PATCHES images range(min-max): " + str(np.min(patches_imgs_train)) + '-' + str(np.max(patches_imgs_train)))
    return patches_imgs_train, patches_masks_train
# Location of the HDF5 files holding the full images and labels.
_HDF5_DIR = '/home/chendali1/Gsj/DataEnhancement/DATA/datasets_training_testing/'
data_train_imgs_orig = _HDF5_DIR + 'data_train.hdf5'
data_train_groundTruth = _HDF5_DIR + 'data_groundTruth_train.hdf5'
data_test_imgs_orig = _HDF5_DIR + 'data_test.hdf5'
data_test_groundTruth = _HDF5_DIR + 'data_groundTruth_test.hdf5'

# Output folders for the generated patches (images and their labels).
_PATCH_DIR = '/home/chendali1/Gsj/DataEnhancement/DATA/patches/'
patches_path_train = _PATCH_DIR + 'patches_train/'
patches_path_test = _PATCH_DIR + 'patches_test/'
patches_path_label_train = _PATCH_DIR + 'patches_label_train/'
patches_path_label_test = _PATCH_DIR + 'patches_label_test/'

# Make sure every patch output folder exists before anything is written.
for _folder in (patches_path_train, patches_path_test,
                patches_path_label_train, patches_path_label_test):
    if not os.path.exists(_folder):
        os.makedirs(_folder)

def _export_patches(patch_imgs, patch_masks, img_dir, mask_dir):
    """Write patch k to ``<img_dir>k.jpg`` and its label to ``<mask_dir>k.png``.

    Image patches are multiplied by 255 before writing (my_PreProc divided
    them by 255); label patches are written as they are.
    NOTE(review): cv2 receives float64 arrays here; confirm the intended
    uint8 conversion.
    """
    print(patch_imgs.shape)
    print(patch_masks.shape)
    print('111111111111111111111111')
    print(patch_imgs.shape)
    for idx in range(patch_imgs.shape[0]):
        cv2.imwrite(img_dir + str(idx) + '.jpg', patch_imgs[idx, :, :] * 255)
        cv2.imwrite(mask_dir + str(idx) + '.png', patch_masks[idx, :, :])


# 480 training and 240 test patches of 224x224, without FOV restriction.
for _hdf5_imgs, _hdf5_labels, _n_patches, _img_dir, _mask_dir in (
        (data_train_imgs_orig, data_train_groundTruth, 480,
         patches_path_train, patches_path_label_train),
        (data_test_imgs_orig, data_test_groundTruth, 240,
         patches_path_test, patches_path_label_test),
):
    patches, patches_masks = get_data(_hdf5_imgs, _hdf5_labels, 224, 224, _n_patches, False)
    _export_patches(patches, patches_masks, _img_dir, _mask_dir)
print("DataEnhancement Successfull!")

下面为 help_functions.py，主要包含一些辅助函数的定义。

import h5py
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt

def load_hdf5(infile):
    """Read the whole "image" dataset from the HDF5 file ``infile``.

    The file is opened read-only and closed again by the ``with`` block
    before the data is returned.
    """
    with h5py.File(infile, "r") as h5_file:
        data = h5_file["image"][()]
    return data

def write_hdf5(arr, outfile):
    """Create the HDF5 file ``outfile`` holding ``arr`` as dataset "image".

    Mode "w" overwrites an existing file; the dataset keeps ``arr.dtype``.
    """
    with h5py.File(outfile, "w") as handle:
        handle.create_dataset("image", data=arr, dtype=arr.dtype)

#convert RGB image in black and white
def rgb2gray(rgb):
    """Convert a batch of channel-first RGB images to one luminance channel.

    Args:
        rgb: 4-D array of shape (N, 3, H, W).

    Returns:
        Array of shape (N, 1, H, W) holding 0.299*R + 0.587*G + 0.114*B.
    """
    assert len(rgb.shape) == 4  # a batch of images
    assert rgb.shape[1] == 3    # exactly three colour planes
    red, green, blue = rgb[:, 0, :, :], rgb[:, 1, :, :], rgb[:, 2, :, :]
    luminance = red * 0.299 + green * 0.587 + blue * 0.114
    return luminance[:, np.newaxis, :, :]

#group a set of images row per columns
def group_images(data,per_row):
    """Tile a batch of channel-first images into one channel-last mosaic.

    Images are placed left to right, ``per_row`` per row, and rows are
    stacked top to bottom.

    Args:
        data: array (N, C, H, W) with C == 1 or C == 3; N must be a multiple
            of ``per_row``.
        per_row: number of images in each mosaic row.

    Returns:
        Array of shape ((N // per_row) * H, per_row * W, C).
    """
    assert data.shape[0] % per_row == 0
    assert data.shape[1] == 1 or data.shape[1] == 3
    hwc = np.transpose(data, (0, 2, 3, 1))  # channel-last, the layout imshow expects
    n_rows = int(hwc.shape[0] / per_row)
    rows = [
        np.concatenate([hwc[k] for k in range(r * per_row, (r + 1) * per_row)], axis=1)
        for r in range(n_rows)
    ]
    return np.concatenate(rows, axis=0) if len(rows) > 1 else rows[0]


#visualize image (as PIL image, NOT as matplotlib!)
def visualize(data,filename):
    """Save an (H, W, C) image array as ``<filename>.png`` and return the PIL image.

    A single-channel input (C == 1) is squeezed to (H, W) first. If the
    maximum value is above 1, the data is taken to be 0-255 already;
    otherwise it is treated as 0-1 and scaled by 255. Values are cast to
    uint8 before saving.
    """
    assert len(data.shape) == 3  # height*width*channels
    if data.shape[2] == 1:  # greyscale: drop the channel axis for PIL
        data = np.reshape(data, (data.shape[0], data.shape[1]))
    pixels = data if np.max(data) > 1 else data * 255
    img = Image.fromarray(pixels.astype(np.uint8))
    img.save(filename + '.png')
    return img


#prepare the mask in the right shape for the Unet
def masks_Unet(masks):
    """One-hot encode masks into the (N, H*W, 2) layout used as the U-Net target.

    A pixel equal to 0 becomes [1, 0]; any other value (including NaN)
    becomes [0, 1].

    Args:
        masks: array of shape (N, 1, H, W).

    Returns:
        float64 array of shape (N, H*W, 2).
    """
    assert len(masks.shape) == 4  # 4D arrays
    assert masks.shape[1] == 1    # check the channel is 1
    im_h = masks.shape[2]
    im_w = masks.shape[3]
    # Vectorised: the former per-pixel Python double loop did the same
    # comparison one element at a time.
    is_zero = np.reshape(masks, (masks.shape[0], im_h * im_w)) == 0
    new_masks = np.empty((masks.shape[0], im_h * im_w, 2))
    new_masks[:, :, 0] = is_zero
    new_masks[:, :, 1] = ~is_zero
    return new_masks


def pred_to_imgs(pred, patch_height, patch_width, mode="original"):
    """Turn per-pixel 2-class predictions back into single-channel patch images.

    Args:
        pred: array of shape (Npatches, patch_height*patch_width, 2). Column 1
            is the value that is kept.
        patch_height, patch_width: patch size used to reshape the flat pixels.
        mode: "original" keeps ``pred[..., 1]`` as is; "threshold" maps it
            to 1.0 where it is >= 0.5 and to 0.0 otherwise.

    Returns:
        float64 array of shape (Npatches, 1, patch_height, patch_width). It is
        always a new array, never a view of ``pred``.

    Raises:
        ValueError: for an unrecognised ``mode``. The previous
            ``print`` + ``exit()`` killed the interpreter through the
            interactive-only ``site`` helper.
    """
    assert len(pred.shape) == 3  # 3D array: (Npatches, height*width, 2)
    assert pred.shape[2] == 2    # check the classes are 2
    scores = pred[:, :, 1]
    # Vectorised replacements for the former per-pixel Python loops.
    if mode == "original":
        pred_images = scores.astype(np.float64)
    elif mode == "threshold":
        pred_images = (scores >= 0.5).astype(np.float64)
    else:
        raise ValueError("mode " + str(mode) + " not recognized, it can be 'original' or 'threshold'")
    return np.reshape(pred_images, (pred_images.shape[0], 1, patch_height, patch_width))

下面为 pre_processing.py，是图像预处理的文件。

###################################################
#
#   Script to pre-process the original imgs
#
##################################################


import numpy as np
from PIL import Image
import cv2

from help_functions import *


#My pre processing (use for both training and testing!)
def my_PreProc(data):
    """Preprocess images for both training and testing by rescaling 0-255 to 0-1.

    The optional steps rgb2gray, dataset_normalized, clahe_equalized and
    adjust_gamma(..., 1.2) are currently disabled; only the division by 255
    is applied. A new array is returned and ``data`` is left unchanged.
    """
    rescaled = data / 255.
    return rescaled


#============================================================
#========= PRE PROCESSING FUNCTIONS ========================#
#============================================================

#==== histogram equalization
def histo_equalized(imgs):
    """Apply OpenCV global histogram equalisation to each image's single channel.

    Args:
        imgs: array of shape (N, 1, H, W); each image is cast to uint8
            before equalising.

    Returns:
        float64 array of the same shape with the equalised values.
    """
    assert len(imgs.shape) == 4  # 4D arrays
    assert imgs.shape[1] == 1    # check the channel is 1
    equalized = np.empty(imgs.shape)
    for idx, image in enumerate(imgs):
        equalized[idx, 0] = cv2.equalizeHist(np.array(image[0], dtype=np.uint8))
    return equalized


def clahe_equalized(imgs):
    """Apply CLAHE (Contrast Limited Adaptive Histogram Equalization) to every channel.

    Each image is split into tiles (``tileGridSize=(8, 8)`` here) and each
    tile's histogram is equalised. Histogram bins above ``clipLimit``
    (2.0 here) are clipped and the excess redistributed, which limits noise
    amplification. Tile borders are then blended with bilinear
    interpolation.

    Args:
        imgs: array of shape (N, C, H, W) with C == 1 or C == 3; values are
            cast to uint8 before equalising.

    Returns:
        float64 array of the same shape; each channel is equalised on its own.
    """
    assert len(imgs.shape) == 4         # 4D arrays
    assert imgs.shape[1] in (1, 3)      # grayscale or 3-channel input
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    imgs_equalized = np.empty(imgs.shape)
    for i in range(imgs.shape[0]):
        # Fill every channel: writing only channel 0 (as before) left the
        # other channels of a 3-channel input as uninitialised np.empty memory.
        for c in range(imgs.shape[1]):
            imgs_equalized[i, c] = clahe.apply(np.array(imgs[i, c], dtype=np.uint8))
    return imgs_equalized


def dataset_normalized(imgs):
    """Standardise over the whole dataset, then min-max rescale each image to 0-255.

    The mean and standard deviation are taken over all images together.
    Each standardised image is then stretched so that its own minimum maps
    to 0 and its maximum to 255.

    Args:
        imgs: array of shape (N, C, H, W) with C == 1 or C == 3. The old
            assert accepted only C == 3, contradicting its "channel is 1"
            comment and the grayscale pipeline.

    Returns:
        Float array of the same shape. A constant image divides 0/0 and
        yields NaN.
    """
    assert len(imgs.shape) == 4      # 4D arrays
    assert imgs.shape[1] in (1, 3)   # grayscale or 3-channel input
    imgs_normalized = (imgs - np.mean(imgs)) / np.std(imgs)
    # Per-image extrema, kept broadcastable against (N, C, H, W).
    per_img_min = np.min(imgs_normalized, axis=(1, 2, 3), keepdims=True)
    per_img_max = np.max(imgs_normalized, axis=(1, 2, 3), keepdims=True)
    return ((imgs_normalized - per_img_min) / (per_img_max - per_img_min)) * 255


def adjust_gamma(imgs, gamma=1.0):
    """Gamma-correct every channel of every image through a 256-entry lookup table.

    Each pixel value v, after casting to uint8, is mapped to
    ``uint8(((v / 255) ** (1 / gamma)) * 255)``.

    Args:
        imgs: array of shape (N, C, H, W) with C == 1 or C == 3.
        gamma: gamma value; values above 1 brighten mid-tones.

    Returns:
        float64 array of the same shape.
    """
    assert len(imgs.shape) == 4      # 4D arrays
    assert imgs.shape[1] in (1, 3)   # grayscale or 3-channel input
    # build a lookup table mapping the pixel values [0, 255] to
    # their adjusted gamma values
    invGamma = 1.0 / gamma
    table = np.array([((i / 255.0) ** invGamma) * 255 for i in np.arange(0, 256)]).astype("uint8")
    # Indexing the table with uint8 pixels is what cv2.LUT does for 8-bit
    # data. Applying it to the whole array fills every channel; the previous
    # loop wrote only channel 0 and left the rest as np.empty garbage.
    return table[np.array(imgs, dtype=np.uint8)].astype(np.float64)
#extract_patches.py 主要参考的原代码
# Patch extraction / recomposition helpers for the DRIVE retina dataset.
# (This paste had lost all its line breaks and could not be parsed.)
import numpy as np
import random
import configparser as ConfigParser

from help_functions import load_hdf5
from help_functions import visualize
from help_functions import group_images

from pre_processing import my_PreProc


# To select the same images
# random.seed(10)


def get_data_training(DRIVE_train_imgs_original,
                      DRIVE_train_groudTruth,
                      patch_height,
                      patch_width,
                      N_subimgs,
                      inside_FOV):
    """Load the training HDF5 files and return randomly extracted patches.

    Images go through my_PreProc; masks are divided by 255 and must then
    span exactly 0..1. Rows 9:574 are kept ("cut bottom and top so now it
    is 565*565").

    Returns:
        (patches_imgs_train, patches_masks_train), each of shape
        (N_subimgs, C, patch_height, patch_width).
    """
    train_imgs_original = load_hdf5(DRIVE_train_imgs_original)
    train_masks = load_hdf5(DRIVE_train_groudTruth)  # masks always the same
    # visualize(group_images(train_imgs_original[0:20,:,:,:],5),'imgs_train')#.show()  #check original imgs train

    train_imgs = my_PreProc(train_imgs_original)
    train_masks = train_masks / 255.

    train_imgs = train_imgs[:, :, 9:574, :]  # cut bottom and top so now it is 565*565
    train_masks = train_masks[:, :, 9:574, :]  # cut bottom and top so now it is 565*565
    data_consistency_check(train_imgs, train_masks)

    # check masks are within 0-1
    assert (np.min(train_masks) == 0 and np.max(train_masks) == 1)

    print("\ntrain images/masks shape:")
    print(train_imgs.shape)
    print("train images range (min-max): " + str(np.min(train_imgs)) + ' - ' + str(np.max(train_imgs)))
    print("train masks are within 0-1\n")

    # extract the TRAINING patches from the full images
    patches_imgs_train, patches_masks_train = extract_random(train_imgs, train_masks, patch_height, patch_width, N_subimgs, inside_FOV)
    data_consistency_check(patches_imgs_train, patches_masks_train)

    print("\ntrain PATCHES images/masks shape:")
    print(patches_imgs_train.shape)
    print("train PATCHES images range (min-max): " + str(np.min(patches_imgs_train)) + ' - ' + str(np.max(patches_imgs_train)))

    return patches_imgs_train, patches_masks_train  # , patches_imgs_test, patches_masks_test


def get_data_testing(DRIVE_test_imgs_original, DRIVE_test_groudTruth, Imgs_to_test, patch_height, patch_width):
    """Load the first ``Imgs_to_test`` test images/masks and cut them into a non-overlapping patch grid.

    Images and masks are both zero-padded (paint_border) at the bottom and
    right, so that their height and width are multiples of the patch size.
    """
    ### test
    test_imgs_original = load_hdf5(DRIVE_test_imgs_original)
    test_masks = load_hdf5(DRIVE_test_groudTruth)

    test_imgs = my_PreProc(test_imgs_original)
    test_masks = test_masks / 255.

    # extend both images and masks so they can be divided exactly by the patches dimensions
    test_imgs = test_imgs[0:Imgs_to_test, :, :, :]
    test_masks = test_masks[0:Imgs_to_test, :, :, :]
    test_imgs = paint_border(test_imgs, patch_height, patch_width)
    test_masks = paint_border(test_masks, patch_height, patch_width)

    data_consistency_check(test_imgs, test_masks)

    # check masks are within 0-1
    assert (np.max(test_masks) == 1 and np.min(test_masks) == 0)

    print("\ntest images/masks shape:")
    print(test_imgs.shape)
    print("test images range (min-max): " + str(np.min(test_imgs)) + ' - ' + str(np.max(test_imgs)))
    print("test masks are within 0-1\n")

    # extract the TEST patches from the full images
    patches_imgs_test = extract_ordered(test_imgs, patch_height, patch_width)
    patches_masks_test = extract_ordered(test_masks, patch_height, patch_width)
    data_consistency_check(patches_imgs_test, patches_masks_test)

    print("\ntest PATCHES images/masks shape:")
    print(patches_imgs_test.shape)
    print("test PATCHES images range (min-max): " + str(np.min(patches_imgs_test)) + ' - ' + str(np.max(patches_imgs_test)))

    return patches_imgs_test, patches_masks_test


def get_data_testing_overlap(DRIVE_test_imgs_original, DRIVE_test_groudTruth, Imgs_to_test, patch_height, patch_width, stride_height, stride_width):
    """Load test images and cut them into overlapping patches (sliding window with the given strides).

    Only the images are padded (paint_border_overlap); the masks are
    returned unpadded, in their original shape.

    Returns:
        (patches_imgs_test, padded_height, padded_width, test_masks)
    """
    ### test
    test_imgs_original = load_hdf5(DRIVE_test_imgs_original)
    test_masks = load_hdf5(DRIVE_test_groudTruth)

    test_imgs = my_PreProc(test_imgs_original)
    test_masks = test_masks / 255.
    # pad only the images so the sliding window fits exactly; masks keep their size
    test_imgs = test_imgs[0:Imgs_to_test, :, :, :]
    test_masks = test_masks[0:Imgs_to_test, :, :, :]
    test_imgs = paint_border_overlap(test_imgs, patch_height, patch_width, stride_height, stride_width)

    # check masks are within 0-1
    assert (np.max(test_masks) == 1 and np.min(test_masks) == 0)

    print("\ntest images shape:")
    print(test_imgs.shape)
    print("\ntest mask shape:")
    print(test_masks.shape)
    print("test images range (min-max): " + str(np.min(test_imgs)) + ' - ' + str(np.max(test_imgs)))
    print("test masks are within 0-1\n")

    # extract the TEST patches from the full images
    patches_imgs_test = extract_ordered_overlap(test_imgs, patch_height, patch_width, stride_height, stride_width)

    print("\ntest PATCHES images shape:")
    print(patches_imgs_test.shape)
    print("test PATCHES images range (min-max): " + str(np.min(patches_imgs_test)) + ' - ' + str(np.max(patches_imgs_test)))

    return patches_imgs_test, test_imgs.shape[2], test_imgs.shape[3], test_masks


def data_consistency_check(imgs, masks):
    """Assert that images and masks have the same rank (at least 4-D).

    Also asserts that they match in N, H and W, that the masks have one
    channel, and that the images have 1 or 3 channels.
    """
    assert (len(imgs.shape) == len(masks.shape))
    assert (imgs.shape[0] == masks.shape[0])
    assert (imgs.shape[2] == masks.shape[2])
    assert (imgs.shape[3] == masks.shape[3])
    assert (masks.shape[1] == 1)
    assert (imgs.shape[1] == 1 or imgs.shape[1] == 3)


def extract_random(full_imgs, full_masks, patch_h, patch_w, N_patches, inside=True):
    """Extract ``N_patches`` random (image, mask) patches, split evenly over the full images.

    If ``inside`` is true, a drawn centre that fails is_patch_inside_FOV is
    discarded and redrawn.

    Raises:
        ValueError: if ``N_patches`` is not a multiple of the image count.
    """
    if (N_patches % full_imgs.shape[0] != 0):
        raise ValueError("N_patches: please enter a multiple of the number of full images ("
                         + str(full_imgs.shape[0]) + ")")
    assert (len(full_imgs.shape) == 4 and len(full_masks.shape) == 4)  # 4D arrays
    assert (full_imgs.shape[1] == 1 or full_imgs.shape[1] == 3)  # check the channel is 1 or 3
    assert (full_masks.shape[1] == 1)  # masks only black and white
    assert (full_imgs.shape[2] == full_masks.shape[2] and full_imgs.shape[3] == full_masks.shape[3])
    patches = np.empty((N_patches, full_imgs.shape[1], patch_h, patch_w))
    patches_masks = np.empty((N_patches, full_masks.shape[1], patch_h, patch_w))
    img_h = full_imgs.shape[2]  # height of the full image
    img_w = full_imgs.shape[3]  # width of the full image
    half_h = int(patch_h / 2)
    half_w = int(patch_w / 2)
    patch_per_img = int(N_patches / full_imgs.shape[0])  # N_patches equally divided in the full images
    print("patches per full image: " + str(patch_per_img))
    iter_tot = 0  # iter over the total number of patches (N_patches)
    for i in range(full_imgs.shape[0]):  # loop over the full images
        k = 0
        while k < patch_per_img:
            # Upper bound leaves room for the full window, also for odd sizes
            # (identical to the old bound img_w - patch_w/2 for even sizes).
            x_center = random.randint(half_w, img_w - patch_w + half_w)
            y_center = random.randint(half_h, img_h - patch_h + half_h)
            # check whether the patch is fully contained in the FOV
            if inside and not is_patch_inside_FOV(x_center, y_center, img_w, img_h, patch_h):
                continue
            x0 = x_center - half_w
            y0 = y_center - half_h
            patches[iter_tot] = full_imgs[i, :, y0:y0 + patch_h, x0:x0 + patch_w]
            patches_masks[iter_tot] = full_masks[i, :, y0:y0 + patch_h, x0:x0 + patch_w]
            iter_tot += 1  # total
            k += 1  # per full_img
    return patches, patches_masks


def is_patch_inside_FOV(x, y, img_w, img_h, patch_h):
    """Return True if a square patch centred at (x, y) fits inside the circular DRIVE field of view.

    The FOV is modelled as a circle of radius 270 px around the image centre
    (per the original comment, from the DRIVE docs). The radius is shrunk by
    half the patch diagonal so that the whole patch fits.
    """
    x_ = x - int(img_w / 2)  # origin (0,0) shifted to image center
    y_ = y - int(img_h / 2)  # origin (0,0) shifted to image center
    R_inside = 270 - int(patch_h * np.sqrt(2.0) / 2.0)  # radius 270 minus half the patch diagonal (square patch assumed)
    radius = np.sqrt((x_ * x_) + (y_ * y_))
    return bool(radius < R_inside)


def extract_ordered(full_imgs, patch_h, patch_w):
    """Cut each full image into a non-overlapping, row-major grid of patches.

    Border pixels that do not fill a whole patch are dropped, with a warning.
    """
    assert (len(full_imgs.shape) == 4)  # 4D arrays
    assert (full_imgs.shape[1] == 1 or full_imgs.shape[1] == 3)  # check the channel is 1 or 3
    img_h = full_imgs.shape[2]  # height of the full image
    img_w = full_imgs.shape[3]  # width of the full image
    N_patches_h = int(img_h / patch_h)  # round to lowest int
    if (img_h % patch_h != 0):
        print("warning: " + str(N_patches_h) + " patches in height, with about " + str(img_h % patch_h) + " pixels left over")
    N_patches_w = int(img_w / patch_w)  # round to lowest int
    if (img_w % patch_w != 0):  # width leftover (this previously tested the height by mistake)
        print("warning: " + str(N_patches_w) + " patches in width, with about " + str(img_w % patch_w) + " pixels left over")
    print("number of patches per image: " + str(N_patches_h * N_patches_w))
    N_patches_tot = (N_patches_h * N_patches_w) * full_imgs.shape[0]
    patches = np.empty((N_patches_tot, full_imgs.shape[1], patch_h, patch_w))

    iter_tot = 0  # iter over the total number of patches (N_patches)
    for i in range(full_imgs.shape[0]):  # loop over the full images
        for h in range(N_patches_h):
            for w in range(N_patches_w):
                patch = full_imgs[i, :, h * patch_h:(h * patch_h) + patch_h, w * patch_w:(w * patch_w) + patch_w]
                patches[iter_tot] = patch
                iter_tot += 1  # total
    assert (iter_tot == N_patches_tot)
    return patches  # array with all the full_imgs divided in patches


def paint_border_overlap(full_imgs, patch_h, patch_w, stride_h, stride_w):
    """Zero-pad bottom/right so that (size - patch) is a multiple of the stride in H and in W."""
    assert (len(full_imgs.shape) == 4)  # 4D arrays
    assert (full_imgs.shape[1] == 1 or full_imgs.shape[1] == 3)  # check the channel is 1 or 3
    img_h = full_imgs.shape[2]  # height of the full image
    img_w = full_imgs.shape[3]  # width of the full image
    leftover_h = (img_h - patch_h) % stride_h  # leftover on the h dim
    leftover_w = (img_w - patch_w) % stride_w  # leftover on the w dim
    if (leftover_h != 0):  # change dimension of img_h
        print("\nthe side H is not compatible with the selected stride of " + str(stride_h))
        print("img_h " + str(img_h) + ", patch_h " + str(patch_h) + ", stride_h " + str(stride_h))
        print("(img_h - patch_h) MOD stride_h: " + str(leftover_h))
        print("So the H dim will be padded with additional " + str(stride_h - leftover_h) + " pixels")
        tmp_full_imgs = np.zeros((full_imgs.shape[0], full_imgs.shape[1], img_h + (stride_h - leftover_h), img_w))
        tmp_full_imgs[0:full_imgs.shape[0], 0:full_imgs.shape[1], 0:img_h, 0:img_w] = full_imgs
        full_imgs = tmp_full_imgs
    if (leftover_w != 0):  # change dimension of img_w
        print("the side W is not compatible with the selected stride of " + str(stride_w))
        print("img_w " + str(img_w) + ", patch_w " + str(patch_w) + ", stride_w " + str(stride_w))
        print("(img_w - patch_w) MOD stride_w: " + str(leftover_w))
        print("So the W dim will be padded with additional " + str(stride_w - leftover_w) + " pixels")
        tmp_full_imgs = np.zeros((full_imgs.shape[0], full_imgs.shape[1], full_imgs.shape[2], img_w + (stride_w - leftover_w)))
        tmp_full_imgs[0:full_imgs.shape[0], 0:full_imgs.shape[1], 0:full_imgs.shape[2], 0:img_w] = full_imgs
        full_imgs = tmp_full_imgs
    print("new full images shape: \n" + str(full_imgs.shape))
    return full_imgs


def extract_ordered_overlap(full_imgs, patch_h, patch_w, stride_h, stride_w):
    """Slide a patch_h x patch_w window with the given strides over each image, row-major.

    The image size must already be stride-compatible (see paint_border_overlap).
    """
    assert (len(full_imgs.shape) == 4)  # 4D arrays
    assert (full_imgs.shape[1] == 1 or full_imgs.shape[1] == 3)  # check the channel is 1 or 3
    img_h = full_imgs.shape[2]  # height of the full image
    img_w = full_imgs.shape[3]  # width of the full image
    assert ((img_h - patch_h) % stride_h == 0 and (img_w - patch_w) % stride_w == 0)
    N_patches_img = ((img_h - patch_h) // stride_h + 1) * ((img_w - patch_w) // stride_w + 1)  # // --> division between integers
    N_patches_tot = N_patches_img * full_imgs.shape[0]
    print("Number of patches on h : " + str(((img_h - patch_h) // stride_h + 1)))
    print("Number of patches on w : " + str(((img_w - patch_w) // stride_w + 1)))
    print("number of patches per image: " + str(N_patches_img) + ", totally for this dataset: " + str(N_patches_tot))
    patches = np.empty((N_patches_tot, full_imgs.shape[1], patch_h, patch_w))
    iter_tot = 0  # iter over the total number of patches (N_patches)
    for i in range(full_imgs.shape[0]):  # loop over the full images
        for h in range((img_h - patch_h) // stride_h + 1):
            for w in range((img_w - patch_w) // stride_w + 1):
                patch = full_imgs[i, :, h * stride_h:(h * stride_h) + patch_h, w * stride_w:(w * stride_w) + patch_w]
                patches[iter_tot] = patch
                iter_tot += 1  # total
    assert (iter_tot == N_patches_tot)
    return patches  # array with all the full_imgs divided in patches


def recompone_overlap(preds, img_h, img_w, stride_h, stride_w):
    """Rebuild full images from overlapping patch predictions by averaging overlaps.

    Args:
        preds: (N_patches, C, patch_h, patch_w), in the order produced by
            extract_ordered_overlap.

    Returns:
        (N_full_imgs, C, img_h, img_w) averages, asserted to lie in [0, 1].
    """
    assert (len(preds.shape) == 4)  # 4D arrays
    assert (preds.shape[1] == 1 or preds.shape[1] == 3)  # check the channel is 1 or 3
    patch_h = preds.shape[2]
    patch_w = preds.shape[3]
    N_patches_h = (img_h - patch_h) // stride_h + 1
    N_patches_w = (img_w - patch_w) // stride_w + 1
    N_patches_img = N_patches_h * N_patches_w
    print("N_patches_h: " + str(N_patches_h))
    print("N_patches_w: " + str(N_patches_w))
    print("N_patches_img: " + str(N_patches_img))
    assert (preds.shape[0] % N_patches_img == 0)
    N_full_imgs = preds.shape[0] // N_patches_img
    print("According to the dimension inserted, there are " + str(N_full_imgs) + " full images (of " + str(img_h) + "x" + str(img_w) + " each)")
    full_prob = np.zeros((N_full_imgs, preds.shape[1], img_h, img_w))  # running sum of probabilities
    full_sum = np.zeros((N_full_imgs, preds.shape[1], img_h, img_w))  # how many patches covered each pixel

    k = 0  # iterator over all the patches
    for i in range(N_full_imgs):
        for h in range(N_patches_h):
            for w in range(N_patches_w):
                full_prob[i, :, h * stride_h:(h * stride_h) + patch_h, w * stride_w:(w * stride_w) + patch_w] += preds[k]
                full_sum[i, :, h * stride_h:(h * stride_h) + patch_h, w * stride_w:(w * stride_w) + patch_w] += 1
                k += 1
    assert (k == preds.shape[0])
    assert (np.min(full_sum) >= 1.0)  # at least one
    final_avg = full_prob / full_sum
    print(final_avg.shape)
    assert (np.max(final_avg) <= 1.0)  # max value for a pixel is 1.0
    assert (np.min(final_avg) >= 0.0)  # min value for a pixel is 0.0
    return final_avg


def recompone(data, N_h, N_w):
    """Reassemble non-overlapping patches (row-major, N_h x N_w per image) into full images."""
    assert (data.shape[1] == 1 or data.shape[1] == 3)  # check the channel is 1 or 3
    assert (len(data.shape) == 4)
    N_pacth_per_img = N_w * N_h
    assert (data.shape[0] % N_pacth_per_img == 0)
    # Integer division: a float count makes np.empty fail under Python 3.
    N_full_imgs = data.shape[0] // N_pacth_per_img
    patch_h = data.shape[2]
    patch_w = data.shape[3]
    # define and start full recompone
    full_recomp = np.empty((N_full_imgs, data.shape[1], N_h * patch_h, N_w * patch_w))
    k = 0  # iter full img
    s = 0  # iter single patch
    while (s < data.shape[0]):
        # recompone one:
        single_recon = np.empty((data.shape[1], N_h * patch_h, N_w * patch_w))
        for h in range(N_h):
            for w in range(N_w):
                single_recon[:, h * patch_h:(h * patch_h) + patch_h, w * patch_w:(w * patch_w) + patch_w] = data[s]
                s += 1
        full_recomp[k] = single_recon
        k += 1
    assert (k == N_full_imgs)
    return full_recomp


def paint_border(data, patch_h, patch_w):
    """Zero-pad bottom/right so that height and width become multiples of the patch size."""
    assert (len(data.shape) == 4)  # 4D arrays
    assert (data.shape[1] == 1 or data.shape[1] == 3)  # check the channel is 1 or 3
    img_h = data.shape[2]
    img_w = data.shape[3]
    # Round up with integer division; "/" yields a float that np.zeros rejects.
    if (img_h % patch_h) == 0:
        new_img_h = img_h
    else:
        new_img_h = ((int(img_h) // int(patch_h)) + 1) * patch_h
    if (img_w % patch_w) == 0:
        new_img_w = img_w
    else:
        new_img_w = ((int(img_w) // int(patch_w)) + 1) * patch_w
    new_data = np.zeros((data.shape[0], data.shape[1], new_img_h, new_img_w))
    new_data[:, :, 0:img_h, 0:img_w] = data[:, :, :, :]
    return new_data


def pred_only_FOV(data_imgs, data_masks, original_imgs_border_masks):
    """Return flat arrays of the prediction and ground-truth pixels that lie inside the FOV.

    Pixels are visited image by image, then x, then y, and kept when
    inside_FOV_DRIVE is True for them.
    """
    assert (len(data_imgs.shape) == 4 and len(data_masks.shape) == 4)  # 4D arrays
    assert (data_imgs.shape[0] == data_masks.shape[0])
    assert (data_imgs.shape[2] == data_masks.shape[2])
    assert (data_imgs.shape[3] == data_masks.shape[3])
    assert (data_imgs.shape[1] == 1 and data_masks.shape[1] == 1)  # check the channel is 1
    height = data_imgs.shape[2]
    width = data_imgs.shape[3]
    new_pred_imgs = []
    new_pred_masks = []
    for i in range(data_imgs.shape[0]):  # loop over the full images
        for x in range(width):
            for y in range(height):
                if inside_FOV_DRIVE(i, x, y, original_imgs_border_masks):
                    new_pred_imgs.append(data_imgs[i, :, y, x])
                    new_pred_masks.append(data_masks[i, :, y, x])
    new_pred_imgs = np.asarray(new_pred_imgs)
    new_pred_masks = np.asarray(new_pred_masks)
    return new_pred_imgs, new_pred_masks


def kill_border(data, original_imgs_border_masks):
    """Set every pixel outside the FOV to 0.0, modifying ``data`` in place (returns None)."""
    assert (len(data.shape) == 4)  # 4D arrays
    assert (data.shape[1] == 1 or data.shape[1] == 3)  # check the channel is 1 or 3
    height = data.shape[2]
    width = data.shape[3]
    for i in range(data.shape[0]):  # loop over the full images
        for x in range(width):
            for y in range(height):
                if not inside_FOV_DRIVE(i, x, y, original_imgs_border_masks):
                    data[i, :, y, x] = 0.0


def inside_FOV_DRIVE(i, x, y, DRIVE_masks):
    """Return True if pixel (x, y) of image i has a border-mask value > 0.

    Coordinates beyond the mask's extent (e.g. padded regions) count as outside.
    """
    assert (len(DRIVE_masks.shape) == 4)  # 4D arrays
    assert (DRIVE_masks.shape[1] == 1)  # DRIVE masks is black and white
    # DRIVE_masks = DRIVE_masks/255.  #NOOO!! otherwise with float numbers takes forever!!
    if (x >= DRIVE_masks.shape[3] or y >= DRIVE_masks.shape[2]):  # my image bigger than the original
        return False
    return bool(DRIVE_masks[i, 0, y, x] > 0)  # 0 == black pixels

 

posted @ 2018-05-04 16:49  fourmii  阅读(3090)  评论(0)  编辑  收藏  举报