机器学习工具代码
(持续整理)
数组阈值处理
"""
img 为图像数组,同时也是numpy数组
将img数据小于min的都设为min,同时将大于max的都设为max
"""
img[np.where(img < min)] = min
img[np.where(img > 250)] = max
归一化和中心化
# Normalize and center: linearly map the intensity range [min, max] of `img`
# onto [-1, 1] (the mid-range value maps to 0).
# The extrema are converted to Python floats first: for integer images
# (e.g. uint8) `np.min(img) + np.max(img)` would otherwise be evaluated in
# the image dtype and overflow. The names also avoid shadowing the builtins
# `min` / `max`.
img_min = float(np.min(img))
img_max = float(np.max(img))
center = (img_min + img_max) / 2
img = img - center
# A constant image has zero range: skip the scaling instead of dividing by
# zero (which would turn every value into NaN); `img - center` is already 0.
if img_max != img_min:
    img = img / (img_max - img_min) * 2
最大连通域
from skimage import measure
def max_connected_domain_3D(arr):
    """Return a binary mask of the largest connected region in ``arr``.

    ``measure.label`` labels connected regions of equal value, with 0 as
    background. Regions holding different numbers are therefore counted
    separately. The region with the most elements is kept as 1 and
    everything else is set to 0. On a tie, the region with the smallest
    label wins, because ``np.argmax`` returns the first maximum.

    Args:
        arr: integer label/mask array (2-D or 3-D) accepted by
            ``skimage.measure.label``.

    Returns:
        ``np.uint8`` array with the same shape as ``arr``. It is all zeros
        when ``arr`` contains no non-zero element.
    """
    labels = measure.label(arr)
    # Element count per label; drop index 0 (the background).
    counts = np.bincount(labels.ravel())[1:]
    if counts.size == 0:
        # No foreground at all: np.argmax on an empty array would raise.
        return np.zeros(labels.shape, dtype=np.uint8)
    largest = np.argmax(counts) + 1  # +1 restores the offset removed by [1:]
    return (labels == largest).astype(np.uint8)
# Demo: the 3s form the largest connected region (5 cells), so only they
# survive in the output mask.
arr = np.asarray([
    [1, 1, 0, 3],
    [1, 0, 3, 3],
    [0, 1, 3, 3],
    [0, 0, 0, 0],
])
print(arr)
print(max_connected_domain_3D(arr))
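\[
\begin{bmatrix}
1 & 1 & 0 & 3\\
1 & 0 & 3 & 3\\
0 & 1 & 3 & 3\\
0 & 0 & 0 & 0
\end{bmatrix}
\]
\[
\Downarrow
\]
\[
\begin{bmatrix}
0 & 0 & 0 & 1\\
0 & 0 & 1 & 1\\
0 & 0 & 1 & 1\\
0 & 0 & 0 & 0
\end{bmatrix}
\]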
# Drop every axis of length 1 from the array's shape.
arr = np.squeeze(arr)

# Permute the axes: new axes (0, 1, 2) are taken from old axes (1, 2, 0),
# so a shape (a, b, c) becomes (b, c, a).
# NOTE(review): `y` is not defined in this snippet; presumably it is the
# array being prepared for output (e.g. `arr` above) — confirm.
y = np.transpose(y, axes=(1, 2, 0))

# Context: this came up when writing NRRD files, whose array storage order
# may differ from the usual in-memory axis order.
"""
出处为写nrrd文件的时候,可以考虑nrrd的数组存储形式与正常数组维度不一致
"""
绘制模型
from keras.utils import plot_model

# Render the architecture of `model` to an image file, annotating every
# layer with its input/output shapes.
plot_model(model, to_file="RUnet.png", show_shapes=True)
demo
from keras import models
from keras import layers
from keras import regularizers
from keras.utils import plot_model
def get_model(x, y, z):
    """Build a small 3-D CNN, print its summary and return it (uncompiled).

    Args:
        x, y, z: spatial size of the input volume. A single channel axis is
            appended, so the model expects inputs of shape ``(x, y, z, 1)``.

    Returns:
        A ``keras`` ``Sequential`` model ending in a 2-unit sigmoid layer.
    """
    net = models.Sequential()

    # Convolutional feature extractor. Every conv after the first carries an
    # L2(0.1) kernel penalty.
    net.add(layers.Conv3D(16, (3, 3, 2), activation='relu', input_shape=(x, y, z, 1)))
    net.add(layers.BatchNormalization())
    for kernel_size in ((3, 3, 2), (3, 3, 2)):
        net.add(layers.Conv3D(8, kernel_size, activation='relu',
                              kernel_regularizer=regularizers.l2(0.1)))
        net.add(layers.BatchNormalization())
    net.add(layers.Conv3D(8, (3, 3, 1), activation='relu',
                          kernel_regularizer=regularizers.l2(0.1)))
    net.add(layers.Dropout(rate=0.1))
    net.add(layers.BatchNormalization())

    # Fully connected classification head.
    net.add(layers.Flatten())
    net.add(layers.BatchNormalization())
    for units in (13, 8):
        net.add(layers.Dense(units, activation='relu'))
        net.add(layers.BatchNormalization())
    net.add(layers.Dense(8, activation='relu'))
    net.add(layers.Dense(2, activation='sigmoid'))

    net.summary()
    return net
if __name__ == '__main__':
    # Build the demo network for 125 x 125 x 10 single-channel volumes and
    # save its architecture diagram (with layer shapes) as a PNG.
    model = get_model(125, 125, 10)
    plot_model(model, to_file=r"C:\Users\fan\Desktop\model.png", show_shapes=True)
效果图
注:plot_model 依赖 pydot 和 graphviz,需要先安装这两个包(graphviz 还需要安装对应的系统程序)。
数据打乱(shuffle)
def data_confusion(data, label, rng=None):
    """Shuffle samples and their labels with one shared random permutation.

    Args:
        data: array whose first axis indexes samples. Any number of trailing
            axes is allowed (2-D, 3-D, 4-D, ...).
        label: array of per-sample labels; ``len(label)`` must equal
            ``len(data)``.
        rng: optional ``numpy.random.Generator`` (or ``RandomState``) used to
            draw the permutation, e.g. for reproducible shuffles. Defaults
            to the global ``np.random`` state, as before.

    Returns:
        ``(shuffled_data, shuffled_label)``. These are new arrays (advanced
        indexing copies); the inputs are not modified.

    Raises:
        ValueError: if ``data`` and ``label`` hold different numbers of
            samples.
    """
    if len(data) != len(label):
        # Previously extra samples were silently dropped, or an IndexError
        # was raised deep inside the indexing.
        raise ValueError(
            "data and label must have the same number of samples, "
            f"got {len(data)} and {len(label)}"
        )
    permutation = (np.random if rng is None else rng).permutation(len(label))
    # Index only the first axis so data of any rank works.
    return data[permutation], label[permutation]