图像处理3 Felzenszwalb算法的Python实现
介绍
算法的详细介绍见上一篇随笔。
实现和效果
# coding:utf8
import cv2
import numpy as np
from skimage import io as sio
from skimage.segmentation import felzenszwalb
import matplotlib.pyplot as plt
from _felzenszwalb_cy import _felzenszwalb_cython


class _Universe(object):
    """Disjoint-set forest (union by rank) with per-component bookkeeping.

    ``s`` and ``t`` are only meaningful at the index of a component's root.
    """

    def __init__(self, n, k):
        self.f = np.arange(n)            # parent pointer of every node
        self.r = np.zeros_like(self.f)   # rank used to keep the trees shallow
        self.s = np.ones(n)              # number of pixels in the component
        self.t = np.ones(n) * k          # current merge threshold
        self.k = k

    def find(self, x):
        """Return the root of the tree containing node ``x``."""
        root = x
        while root != self.f[root]:
            root = self.f[root]
        # Path compression: re-point every visited node at the root.  This
        # never changes which node is the root, nor any rank.
        while x != root:
            parent = self.f[x]
            self.f[x] = root
            x = parent
        return root

    def join(self, a, b):
        """Merge the trees rooted at ``a`` and ``b`` (both must be roots)."""
        if self.r[a] > self.r[b]:
            self.f[b] = a
            self.s[a] += self.s[b]
        else:
            self.f[a] = b
            self.s[b] += self.s[a]
            if self.r[a] == self.r[b]:
                self.r[b] += 1


def felzenszwalb_test(img, sigma, kernel, k, min_size):
    """Segment ``img`` with a 4-neighbourhood Felzenszwalb graph merge.

    Parameters
    ----------
    img : (H, W, C) array
        Colour image; values are divided by 255 before processing.
    sigma : float
        Standard deviation passed to ``cv2.GaussianBlur``.
    kernel : int
        Side length of the Gaussian kernel passed to ``cv2.GaussianBlur``.
    k : float
        Scale constant of the merge threshold (divided by 255 internally).
    min_size : int
        Components smaller than this are merged in a second pass.

    Returns
    -------
    (H, W, C) float array
        The blurred image with every pixel replaced by the mean colour of
        the component it belongs to.
    """
    # Smooth first, then compute pixel distances.  ``np.float`` no longer
    # exists in current NumPy releases, so use the explicit 64-bit type.
    img = np.asanyarray(img, dtype=np.float64) / 255
    # rescale scale to behave like in reference implementation
    k = float(k) / 255.
    img = cv2.GaussianBlur(img, (kernel, kernel), sigma)
    height, width = img.shape[:2]
    num = height * width

    # Euclidean colour distance between 4-neighbours.
    index = np.arange(num).reshape(height, width)
    horizontal = np.sqrt(((img[:, 1:] - img[:, :-1]) ** 2).sum(axis=2)).reshape(-1)
    vertical = np.sqrt(((img[1:] - img[:-1]) ** 2).sum(axis=2)).reshape(-1)

    right_px, left_px = index[:, 1:].reshape(-1), index[:, :-1].reshape(-1)
    lower_px, upper_px = index[1:].reshape(-1), index[:-1].reshape(-1)
    # Every neighbour pair is listed in both directions, in the order
    # left, right, up, down.
    node_a = np.concatenate([right_px, left_px, lower_px, upper_px])
    node_b = np.concatenate([left_px, right_px, upper_px, lower_px])
    weight = np.concatenate([horizontal, horizontal, vertical, vertical])

    # Sort edges by increasing dissimilarity; a stable sort keeps the
    # left/right/up/down order among equal weights.
    order = np.argsort(weight, kind="stable")
    edges = list(zip(node_a[order].tolist(),
                     node_b[order].tolist(),
                     weight[order].tolist()))

    u = _Universe(num, k)
    for a, b, w in edges:
        a, b = u.find(a), u.find(b)
        if a != b and w <= min(u.t[a], u.t[b]):
            # Merge the two components and refresh the threshold of the
            # merged component: edge weight + k / component size.
            u.join(a, b)
            root = u.find(a)
            u.t[root] = w + k / u.s[root]

    # Segmentation leaves many tiny regions; walking the edges in increasing
    # weight merges each undersized region with its most similar neighbour.
    for a, b, _ in edges:
        a, b = u.find(a), u.find(b)
        if a != b and (u.s[a] < min_size or u.s[b] < min_size):
            u.join(a, b)

    # Paint every pixel with the mean colour of its component.
    roots = np.fromiter((u.find(i) for i in range(num)), dtype=np.intp, count=num)
    pixels = img.reshape(num, -1)
    avg_color = np.zeros((num, pixels.shape[1]))
    # ``np.add.at`` accumulates repeated indices one by one, in pixel order.
    np.add.at(avg_color, roots, pixels / u.s[roots][:, None])
    return avg_color[roots].reshape(img.shape)


if __name__ == '__main__':
    sigma = 0.5
    kernel = 3
    K, min_size = 500, 50
    image = sio.imread("test_data/0010.jpg")
    # felzenszwalb as shipped with skimage
    seg1 = felzenszwalb(image, scale=K, sigma=sigma, min_size=min_size)
    # Python port of skimage's Cython felzenszwalb, with a different Gaussian blur
    seg2 = _felzenszwalb_cython(image, scale=K, sigma=sigma, kernel=kernel, min_size=min_size)
    # the implementation above; differs mainly in the 4-neighbourhood and
    # in returning averaged colours
    seg3 = felzenszwalb_test(image, sigma, kernel, K, min_size)

    fig = plt.figure()
    panels = [(221, image, "image"), (222, seg1, "seg1"),
              (223, seg2, "seg2"), (224, seg3, "seg3")]
    for position, picture, title in panels:
        a = fig.add_subplot(position)
        plt.imshow(picture)
        a.set_title(title)
    plt.show()
以下是skimage自带的felzenszwalb算法cython版转Python代码,更改了高斯模糊。
Cython 部分可以参考《Python 调用 C/C++ 实现卷积》中的介绍。
import numpy as np
import cv2


def find_root(forest, n):
    """Return the root of node ``n``.

    ``forest[i]`` holds the parent of node ``i``.  Parents are followed for
    as long as they have a smaller index than the current node; the first
    node whose parent is not smaller is returned as the root.
    """
    root = n
    while forest[root] < root:
        root = forest[root]
    return root


def set_root(forest, n, root):
    """Point every node on the path from ``n`` up to its root at ``root``.

    The path is walked with the same "parent has a smaller index" rule as
    ``find_root``; the node the walk stops at is re-pointed as well.
    """
    while forest[n] < n:
        j = forest[n]
        forest[n] = root
        n = j
    forest[n] = root


def join_trees(forest, n, m):
    """Join the trees containing nodes ``n`` and ``m``.

    The smaller of the two roots becomes the root of the merged tree, and
    both paths (from ``n`` and from ``m``) are re-pointed at it.  Nothing
    happens when ``n == m``.
    """
    if n != m:
        root = find_root(forest, n)
        root_m = find_root(forest, m)
        if root > root_m:
            root = root_m
        set_root(forest, n, root)
        set_root(forest, m, root)


def _felzenszwalb_cython(image, scale=1, sigma=0.8, kernel=3, min_size=20):
    """Felzenszwalb's efficient graph based segmentation for single or multiple channels.

    Produces an oversegmentation of a single or multi-channel image
    using a fast, minimum spanning tree based clustering on the image grid.
    The number of produced segments as well as their size can only be
    controlled indirectly through ``scale``. Segment size within an image can
    vary greatly depending on local contrast.

    Parameters
    ----------
    image : (N, M, C) ndarray
        Input image.
    scale : float, optional (default 1)
        Sets the observation level. Higher means larger clusters.
    sigma : float, optional (default 0.8)
        Width of Gaussian smoothing kernel used in preprocessing.
        Larger sigma gives smoother segment boundaries.
    kernel : int, optional (default 3)
        Side length of the Gaussian kernel handed to ``cv2.GaussianBlur``.
    min_size : int, optional (default 20)
        Minimum component size. Enforced using postprocessing.

    Returns
    -------
    segment_mask : (N, M) ndarray
        Integer mask indicating segment labels.
    """
    # ``np.float`` no longer exists in current NumPy releases; use the
    # explicit 64-bit float type instead.
    image = np.asanyarray(image, dtype=np.float64) / 255

    # rescale scale to behave like in reference implementation
    scale = float(scale) / 255.
    image = cv2.GaussianBlur(image, (kernel, kernel), sigma)

    # compute edge weights in 8 connectivity:
    down_diff = image[1:, :, :] - image[:-1, :, :]
    right_diff = image[:, 1:, :] - image[:, :-1, :]
    dright_diff = image[1:, 1:, :] - image[:-1, :-1, :]
    uright_diff = image[1:, :-1, :] - image[:-1, 1:, :]
    down_cost = np.sqrt(np.sum(down_diff * down_diff, axis=-1))
    right_cost = np.sqrt(np.sum(right_diff * right_diff, axis=-1))
    dright_cost = np.sqrt(np.sum(dright_diff * dright_diff, axis=-1))
    uright_cost = np.sqrt(np.sum(uright_diff * uright_diff, axis=-1))
    costs = np.hstack([
        right_cost.ravel(), down_cost.ravel(), dright_cost.ravel(),
        uright_cost.ravel()]).astype(np.float64)

    # compute edges between pixels (same order as the costs above):
    height, width = image.shape[:2]
    segments = np.arange(width * height, dtype=np.intp).reshape(height, width)
    down_edges = np.c_[segments[1:, :].ravel(), segments[:-1, :].ravel()]
    right_edges = np.c_[segments[:, 1:].ravel(), segments[:, :-1].ravel()]
    dright_edges = np.c_[segments[1:, 1:].ravel(), segments[:-1, :-1].ravel()]
    uright_edges = np.c_[segments[:-1, 1:].ravel(), segments[1:, :-1].ravel()]
    edges = np.vstack([right_edges, down_edges, dright_edges, uright_edges])

    # initialize data structures for segment size
    # and inner cost, then start greedy iteration over edges.
    edge_queue = np.argsort(costs)
    edges = np.ascontiguousarray(edges[edge_queue])
    costs = np.ascontiguousarray(costs[edge_queue])
    segments_p = np.arange(width * height, dtype=np.intp)  # union-find forest
    segment_size = np.ones(width * height, dtype=np.intp)

    # inner cost of segments
    cint = np.zeros(width * height)
    num_costs = costs.size

    for e in range(num_costs):
        seg0 = find_root(segments_p, edges[e][0])
        seg1 = find_root(segments_p, edges[e][1])
        if seg0 == seg1:
            continue
        inner_cost0 = cint[seg0] + scale / segment_size[seg0]
        inner_cost1 = cint[seg1] + scale / segment_size[seg1]
        if costs[e] < min(inner_cost0, inner_cost1):
            # update size and cost
            join_trees(segments_p, seg0, seg1)
            seg_new = find_root(segments_p, seg0)
            segment_size[seg_new] = segment_size[seg0] + segment_size[seg1]
            cint[seg_new] = costs[e]

    # postprocessing to remove small segments
    for e in range(num_costs):
        seg0 = find_root(segments_p, edges[e][0])
        seg1 = find_root(segments_p, edges[e][1])
        if seg0 == seg1:
            continue
        if segment_size[seg0] < min_size or segment_size[seg1] < min_size:
            join_trees(segments_p, seg0, seg1)
            seg_new = find_root(segments_p, seg0)
            segment_size[seg_new] = segment_size[seg0] + segment_size[seg1]

    # unravel the union find tree: keep replacing every entry by its parent
    # until nothing changes, i.e. every pixel points directly at its root
    flat = segments_p.ravel()
    old = np.zeros_like(flat)
    while (old != flat).any():
        old = flat
        flat = flat[flat]
    # relabel the surviving root ids as consecutive integers starting at 0
    flat = np.unique(flat, return_inverse=True)[1]
    return flat.reshape((height, width))