图像处理3 Felzenszwalb算法的Python实现

介绍

算法原理在上一篇随笔中已有详细介绍,见:

图像处理2 基于图的图像分割算法

实现和效果

 

 

# coding:utf8
import cv2
import numpy as np
from skimage import io as sio
from skimage.segmentation import felzenszwalb
import matplotlib.pyplot as plt
from _felzenszwalb_cy import _felzenszwalb_cython


class _Universe(object):
    """Disjoint-set forest used by ``felzenszwalb_test``.

    Union by rank, without path compression. The size (``s``) and
    threshold (``t``) entries are only meaningful at root nodes.
    """

    def __init__(self, n, k):
        self.f = np.arange(n)  # parent pointer; a root points to itself
        self.r = np.zeros_like(self.f)  # rank of each root (union-by-rank)
        self.s = np.ones(n)  # number of pixels in the component rooted here
        self.t = np.ones(n) * k  # merge threshold Int(C) + k / |C|; Int = 0 initially
        self.k = k

    def find(self, x):
        """Return the root of the tree containing node ``x``."""
        # Iterative walk: same result as a recursive find, no recursion limit.
        while x != self.f[x]:
            x = self.f[x]
        return x

    def join(self, a, b):
        """Merge the trees rooted at ``a`` and ``b`` (both must be roots)."""
        if self.r[a] > self.r[b]:
            self.f[b] = a
            self.s[a] += self.s[b]
        else:
            self.f[a] = b
            self.s[b] += self.s[a]
            if self.r[a] == self.r[b]:
                self.r[b] += 1


def _grid_edges(img):
    """Return ``(src, dst, weight)`` arrays for the 4-neighbour pixel grid.

    ``img`` must be (H, W, C). Every neighbouring pair is listed twice, once
    per direction, in the order to-left, to-right, to-up, to-down. The weight
    is the Euclidean distance between the two pixels' channel vectors. Node
    ids are flat row-major pixel indices.
    """
    height, width = img.shape[:2]
    index = np.arange(height * width).reshape(height, width)
    horizontal = np.sqrt(((img[:, 1:] - img[:, :-1]) ** 2).sum(axis=2))
    vertical = np.sqrt(((img[1:] - img[:-1]) ** 2).sum(axis=2))
    src = np.concatenate([
        index[:, 1:].ravel(), index[:, :-1].ravel(),
        index[1:].ravel(), index[:-1].ravel(),
    ])
    dst = np.concatenate([
        index[:, :-1].ravel(), index[:, 1:].ravel(),
        index[:-1].ravel(), index[1:].ravel(),
    ])
    weight = np.concatenate([
        horizontal.ravel(), horizontal.ravel(),
        vertical.ravel(), vertical.ravel(),
    ])
    return src, dst, weight


def felzenszwalb_test(img, sigma, kernel, k, min_size):
    """Segment an image with Felzenszwalb's graph-based algorithm (4-neighbour grid).

    Parameters
    ----------
    img : array_like, (H, W) or (H, W, C)
        Input image. Assumed 8-bit (0..255), since it is divided by 255.
    sigma : float
        Standard deviation passed to ``cv2.GaussianBlur``.
    kernel : int
        Gaussian kernel size (``kernel x kernel``); cv2 requires odd and positive.
    k : float
        Scale parameter; larger values favour larger segments. It is divided
        by 255 to match the reference implementation's scaling.
    min_size : int
        Components smaller than this are merged into a neighbour in a
        postprocessing pass.

    Returns
    -------
    ndarray of float64, same shape as the blurred image
        Every pixel replaced by the mean blurred colour (in [0, 1] for 8-bit
        input) of its segment.
    """
    # Smooth first, then compute pixel distances.
    img = np.asanyarray(img, dtype=np.float64) / 255

    # rescale k to behave like in the reference implementation
    k = float(k) / 255.
    img = cv2.GaussianBlur(img, (kernel, kernel), sigma)
    # cv2 returns 2-D output for single-channel input; add an explicit channel
    # axis so the distance and colour code below works for any channel count.
    was_2d = img.ndim == 2
    if was_2d:
        img = img[:, :, np.newaxis]
    height, width, channels = img.shape
    num = height * width

    # Sort edges by dissimilarity, ascending. A stable sort keeps equal-weight
    # edges in construction order (same order as a stable list.sort).
    src, dst, weight = _grid_edges(img)
    order = np.argsort(weight, kind="stable")
    edges = list(zip(src[order].tolist(), dst[order].tolist(),
                     weight[order].tolist()))

    u = _Universe(num, k)
    for p, q, w in edges:
        a, b = u.find(p), u.find(q)
        if a != b and w <= min(u.t[a], u.t[b]):
            # Merge, then set the new component's threshold to
            # Int(C) + k / |C|, where Int(C) is the current (largest) edge weight.
            u.join(a, b)
            a = u.find(a)
            u.t[a] = w + k / u.s[a]

    # Postprocessing: merge any component smaller than min_size with its
    # neighbour across the cheapest remaining edge.
    for p, q, _ in edges:
        a, b = u.find(p), u.find(q)
        if a != b and (u.s[a] < min_size or u.s[b] < min_size):
            u.join(a, b)

    # Replace each pixel by the mean colour of its component. Accumulating
    # pixel / |C| per pixel, in index order, reproduces the original arithmetic.
    roots = np.fromiter((u.find(i) for i in range(num)), dtype=np.intp, count=num)
    pixels = img.reshape(num, channels)
    avg_color = np.zeros((num, channels))
    np.add.at(avg_color, roots, pixels / u.s[roots][:, np.newaxis])
    out = avg_color[roots].reshape(height, width, channels)
    return out[:, :, 0] if was_2d else out


if __name__ == '__main__':
    # Demo: run three Felzenszwalb implementations on one image and show the
    # input and the three results in a 2x2 grid.
    sigma, kernel = 0.5, 3
    K, min_size = 500, 50
    image = sio.imread("test_data/0010.jpg")

    # seg1: scikit-image's built-in felzenszwalb.
    seg1 = felzenszwalb(image, scale=K, sigma=sigma, min_size=min_size)
    # seg2: Python port of scikit-image's Cython implementation, using cv2 blur.
    seg2 = _felzenszwalb_cython(image, scale=K, sigma=sigma, kernel=kernel,
                                min_size=min_size)
    # seg3: this file's implementation; differs mainly in using a 4-neighbour
    # grid and returning mean segment colours instead of labels.
    seg3 = felzenszwalb_test(image, sigma, kernel, K, min_size)

    fig = plt.figure()
    panels = (
        (221, image, "image"),
        (222, seg1, "seg1"),
        (223, seg2, "seg2"),
        (224, seg3, "seg3"),
    )
    for position, data, title in panels:
        ax = fig.add_subplot(position)
        plt.imshow(data)
        ax.set_title(title)
    plt.show()

 

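以下是将skimage自带的felzenszwalb算法的Cython版本转写而成的Python代码,其中把高斯模糊改为了cv2.GaussianBlur。上面的脚本通过 `from _felzenszwalb_cy import _felzenszwalb_cython` 导入它,因此需要把这段代码保存为同一目录下的 `_felzenszwalb_cy.py`。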

Cython的用法可以参考《Python 调用 C/C++实现卷积》一文中的介绍。

 

import numpy as np
import cv2

def find_root(forest, n):
    """Return the root of the tree that contains node ``n``.

    Parent pointers are followed while they point to a smaller index; the
    first node whose parent index is not smaller than its own is the root.
    This relies on the forest invariant kept by ``set_root``/``join_trees``:
    a tree's root is its smallest node.
    """
    current = n
    parent = forest[current]
    while parent < current:
        current = parent
        parent = forest[current]
    return current


def set_root(forest, n, root):
    """Point every node on the path from ``n`` to its current root at ``root``.

    Walks the parent pointers from ``n`` while they decrease, overwriting
    each visited node's parent with ``root``. The last node on the path (the
    old root) is overwritten too, so two trees can be re-rooted together.
    """
    node = n
    while True:
        parent = forest[node]
        forest[node] = root
        if parent >= node:
            break
        node = parent

def join_trees(forest, n, m):
    """Merge the trees containing nodes ``n`` and ``m``.

    Both trees are re-rooted at the smaller of their two roots. Every node on
    the path from ``n`` (and from ``m``) to its old root is pointed directly
    at the new root. Does nothing when ``n == m``; note that this compares
    the nodes themselves, not their roots.
    """
    if n == m:
        return
    new_root = min(find_root(forest, n), find_root(forest, m))
    set_root(forest, n, new_root)
    set_root(forest, m, new_root)

def _felzenszwalb_cython(image, scale=1, sigma=0.8, kernel=3, min_size=20):
    """Felzenszwalb's efficient graph based segmentation for
    single or multiple channels.

    Pure-Python port of scikit-image's Cython implementation, with the
    Gaussian pre-smoothing done by ``cv2.GaussianBlur``.

    Produces an oversegmentation of a single or multi-channel image
    using a fast, minimum spanning tree based clustering on the image grid.
    The number of produced segments as well as their size can only be
    controlled indirectly through ``scale``. Segment size within an image can
    vary greatly depending on local contrast.

    Parameters
    ----------
    image : (N, M) or (N, M, C) ndarray
        Input image. Assumed 8-bit (0..255), since it is divided by 255.
    scale : float, optional (default 1)
        Sets the observation level. Higher means larger clusters.
    sigma : float, optional (default 0.8)
        Standard deviation of the Gaussian smoothing used in preprocessing.
        Larger sigma gives smoother segment boundaries.
    kernel : int, optional (default 3)
        Size of the (kernel x kernel) Gaussian kernel; cv2 requires an odd,
        positive value.
    min_size : int, optional (default 20)
        Minimum component size. Enforced using postprocessing.

    Returns
    -------
    segment_mask : (N, M) ndarray
        Integer mask indicating segment labels, numbered consecutively from 0.
    """
    # np.float was removed in NumPy 1.24; np.float64 is the same type.
    image = np.asanyarray(image, dtype=np.float64) / 255

    # rescale scale to behave like in reference implementation
    scale = float(scale) / 255.
    image = cv2.GaussianBlur(image, (kernel, kernel), sigma)
    # cv2 returns 2-D output for single-channel input; restore the channel
    # axis so the channel-wise distances below also work for grayscale images.
    if image.ndim == 2:
        image = image[:, :, np.newaxis]

    def _dist(a, b):
        # Euclidean distance between corresponding pixels' channel vectors.
        diff = a - b
        return np.sqrt(np.sum(diff * diff, axis=-1))

    # compute edge weights in 8 connectivity, listing each undirected pair
    # once: right, down, diagonal down-right, diagonal up-right.
    down_cost = _dist(image[1:, :, :], image[:-1, :, :])
    right_cost = _dist(image[:, 1:, :], image[:, :-1, :])
    dright_cost = _dist(image[1:, 1:, :], image[:-1, :-1, :])
    uright_cost = _dist(image[1:, :-1, :], image[:-1, 1:, :])
    costs = np.hstack([
        right_cost.ravel(), down_cost.ravel(), dright_cost.ravel(),
        uright_cost.ravel()]).astype(np.float64)

    # compute edges between pixels (same order as the costs above):
    height, width = image.shape[:2]
    segments = np.arange(width * height, dtype=np.intp).reshape(height, width)
    down_edges = np.c_[segments[1:, :].ravel(), segments[:-1, :].ravel()]
    right_edges = np.c_[segments[:, 1:].ravel(), segments[:, :-1].ravel()]
    dright_edges = np.c_[segments[1:, 1:].ravel(), segments[:-1, :-1].ravel()]
    uright_edges = np.c_[segments[:-1, 1:].ravel(), segments[1:, :-1].ravel()]
    edges = np.vstack([right_edges, down_edges, dright_edges, uright_edges])

    # Sort edges by cost, then greedily merge components in that order.
    edge_queue = np.argsort(costs)
    edges = np.ascontiguousarray(edges[edge_queue])
    costs = np.ascontiguousarray(costs[edge_queue])
    # Union-find forest over flat pixel indices; each pixel starts as its own root.
    segments_p = np.arange(width * height, dtype=np.intp)

    segment_size = np.ones(width * height, dtype=np.intp)

    # inner cost of segments (largest edge weight merged so far)
    cint = np.zeros(width * height)
    num_costs = costs.size

    for e in range(num_costs):
        seg0 = find_root(segments_p, edges[e][0])
        seg1 = find_root(segments_p, edges[e][1])
        if seg0 == seg1:
            continue
        inner_cost0 = cint[seg0] + scale / segment_size[seg0]
        inner_cost1 = cint[seg1] + scale / segment_size[seg1]
        if costs[e] < min(inner_cost0, inner_cost1):
            # update size and cost
            join_trees(segments_p, seg0, seg1)
            seg_new = find_root(segments_p, seg0)
            segment_size[seg_new] = segment_size[seg0] + segment_size[seg1]
            cint[seg_new] = costs[e]

    # postprocessing to remove small segments: merge along the cheapest edges
    for e in range(num_costs):
        seg0 = find_root(segments_p, edges[e][0])
        seg1 = find_root(segments_p, edges[e][1])
        if seg0 == seg1:
            continue
        if segment_size[seg0] < min_size or segment_size[seg1] < min_size:
            join_trees(segments_p, seg0, seg1)
            seg_new = find_root(segments_p, seg0)
            segment_size[seg_new] = segment_size[seg0] + segment_size[seg1]

    # unravel the union find tree: repeatedly jump to the parent's parent
    # until every node points straight at its root.
    flat = segments_p.ravel()
    old = np.zeros_like(flat)
    while (old != flat).any():
        old = flat
        flat = flat[flat]
    # Relabel roots to consecutive integers 0..n_segments-1.
    flat = np.unique(flat, return_inverse=True)[1]
    return flat.reshape((height, width))

 

posted on 2018-08-12 23:39  1357  阅读(7354)  评论(2)  编辑  收藏  举报

导航