【AIGC第二十二篇】DragGAN:基于点拖拽进行图像编辑的GAN模型

DragGAN是由马普所提出的一种基于StyleGAN2的交互式图像编辑方法。本文将详细解读DragGAN,并通过代码进行说明。

DragGAN允许用户对生成的图像进行拖拽编辑,即通过定义多个控制点和对应的目标点,然后将控制点的内容拖拽到目标点。此外,用户还可以使用掩码来指示可编辑区域。DragGAN的引入为图像编辑领域带来了新的思路和方法,并具有广泛的应用前景。

 
 
 

1 总体思路

图1展示了DragGAN的处理流程,它以优化方式进行图像编辑操作,由两个子步骤组成:运动监督和点跟踪。

 

Fig.1 Overview of DragGAN Pipeline

第一个步骤是运动监督,其目标是实现图像中的移动效果。为了达到这个目标,DragGAN设计了一个目标函数,用于优化生成对抗网络(GAN)的潜码,以强制控制点向目标点移动。通过进行一次优化步骤,可以得到一个新的潜码和对应的新图像,其中的对象会发生轻微的移动。这个目标函数的关键在于确保移动后生成网络的深度特征在控制点附近邻域保持不变,对局部的特征监督会得到连续变化的潜码。

然而,运动监督步骤仅将每个控制点向其目标移动一小步,具体步长尚不确定,因为它受到复杂的优化动态的影响,因此对于不同的对象和部位会有所差异。

第二个步骤是点跟踪,其目的是解决目标移动后无法准确跟踪控制点的问题,这可能导致在下一个运动监督步骤中对错误的点进行监督,从而产生错误的结果。点跟踪步骤的重要性在于确保控制点的准确性,从而实现更精确的运动监督效果。

在点跟踪完成后,根据新的控制点和潜码,重复上述优化步骤,通常需要进行30-200次迭代。

2 具体方法

Fig.2 Method of DragGAN

2.1 运动监督

如图2所示,要将控制点 pi 移动到目标点 ti ,可以通过监督 pi (红色圆圈)周围的一小块向 ti 每次移动一小步(蓝色圆圈),然后不断迭代来完成,运动监督目标函数可定义为:

L=∑i=0n∑qi∈Ω1(pi,r1)‖sg(F(qi))−F(qi+di)‖1+λ‖(F−F0)⋅(1−M)‖1

其中, Ω1(pi,r1) 表示到 pi 距离小于 r1 的像素空间, F(q) 表示像素位置 q 处的特征向量, di=ti−pi‖ti−pi‖2 是从 pi 指向 ti 的归一化向量, F0 是与初始图像对应的特征图。为了平衡StyleGAN2生成网络的特征图的分辨率和判别能力,这里选择对生成网络第6个Block之后的特征进行运动监督。为了确保其与图像具有相同的分辨率,可以使用双线性插值来调整特征图的大小。

目标函数第一项约束移动后控制点附近邻域的特征响应保持不变。其中, sg(F(qi)) 是为了防止 qi+di 向 qi 移动,只允许 qi 向 qi+di 移动。但由于 qi+di 可能不是整数,这里需要通过双线性插值计算 F(qi+di) 。第二项约束掩码区域之外的特征图保持不变。

关于优化参数的选择,这里补充一些StyleGAN2的知识点:StyleGAN2引入了一种名为W+空间的特殊潜码表示方式,相比传统的Z空间,W+空间中的潜码w具有更高的维度,并与生成器网络的层次结构相对应,使得在W+空间中进行优化更容易控制图像的各种特征和风格。因此,相较于传统的Z空间,W+空间在图像编辑方面具有更强的能力,W+空间的优化更能保持生成图像的一致性和连续性。

然而,由于W+空间的维度较高,优化过程中潜码w可能会偏离原始分布,导致生成的图像与期望结果之间存在较大差距,这可能会对图像编辑的质量和准确性产生影响。

通过实验证明,图像的空间结构属性主要受到W+前6层参数的影响,而其他参数仅对外观产生影响。基于这一观察,我们选择只更新W+前6层参数,将其他层保持固定,以保持外观不变但改变结构。这种优化方法会导致图像内容发生轻微移动,以达到所期望的编辑效果。

这里有个疑问:基于点拖拽进行图像编辑,只让控制点附近的特征保持不变,反过来找到的潜码,为什么生成新图像,却能够整体移动,并且还能保持一致性和连贯性,而不只是控制点附近像素移动。
def motion_supervison(handle_points, target_points, F, r1, device):
    loss = 0
    n = len(handle_points)
    for i in range(n):
        target2handle = target_points[i] - handle_points[i]
        d_i = target2handle / (torch.norm(target2handle) + 1e-7) # 归一化方向向量
        if torch.norm(d_i) > torch.norm(target2handle):
            d_i = target2handle
        # 指定待修改处的圆形掩码
        mask = utils.create_circular_mask(
            F.shape[2], F.shape[3], center=handle_points[i].tolist(), radius=r1
        ).to(device)
        coordinates = torch.nonzero(mask).float()  # shape [num_points, 2]
        
        # Shift the coordinates in the direction d_i
        shifted_coordinates = coordinates + d_i[None]
        h, w = F.shape[2], F.shape[3]

        # Extract features in the mask region and compute the loss
        F_qi = F[:, :, mask]  # shape: [C, H*W]

        # Sample shifted patch from F
        normalized_shifted_coordinates = shifted_coordinates.clone()
        normalized_shifted_coordinates[:, 0] = (2.0 * shifted_coordinates[:, 0] / (h - 1)) - 1  # for height
        normalized_shifted_coordinates[:, 1] = (2.0 * shifted_coordinates[:, 1] / (w - 1)) - 1  # for width
        # Add extra dimensions for batch and channels (required by grid_sample)
        normalized_shifted_coordinates = normalized_shifted_coordinates.unsqueeze(0).unsqueeze(0)  # shape [1, 1, num_points, 2]
        normalized_shifted_coordinates = normalized_shifted_coordinates.flip(-1)  # grid_sample expects [x, y] instead of [y, x]
        normalized_shifted_coordinates = normalized_shifted_coordinates.clamp(-1, 1)

        # Use grid_sample to interpolate the feature map F at the shifted patch coordinates
        F_qi_plus_di = torch.nn.functional.grid_sample(
            F, normalized_shifted_coordinates, mode="bilinear", align_corners=True
        )
        # Output has shape [1, C, 1, num_points] so squeeze it
        F_qi_plus_di = F_qi_plus_di.squeeze(2)  # shape [1, C, num_points]

        loss += torch.nn.functional.l1_loss(F_qi.detach(), F_qi_plus_di)
    return loss



2.2 点跟踪

根据运动监督会产生新的潜码 ,然后生成新的图像 ,但是控制点移动位置可能并不精确,因此在每一次控制点移动后需要重新定位每个控制点 ,以便能够准确跟踪对象上的相应点。

由于GAN的特征具有很强的判别能力,所以无需应用复杂的跟踪方法,仅通过通过特征图的最近邻搜索就可以有效地进行点跟踪。

具体来说,将初始控制点 pi 的特征表示为 fi=F0(pi) ,将 pi 周围的块表示为:

Ω2(pi,r2)={(x,y)||x−xp,i|<r2,|y−yp,i∣<r2}

然后通过在 Ω2(pi,r2)中搜索 fi 的最近邻点来获得跟踪点:

pi:=arg⁡minqi∈Ω2(pi,r2)‖F′(qi)−fi‖1

为了更准确的进行点跟踪,可以将特征图双线性插值与图像相同的大小。

对于存在多个控制点的情况,对每个点应用同样的处理方法,接下来根据跟踪的点继续进行运动监督。

def point_tracking(
    F: torch.Tensor,
    F0: torch.Tensor,
    handle_points: torch.Tensor,
    handle_points0: torch.Tensor,
    r2: int = 3,
    device: torch.device = torch.device("cuda"),
) -> torch.Tensor:

    n = handle_points.shape[0]  # Number of handle points
    new_handle_points = torch.zeros_like(handle_points)

    for i in range(n):
        # Compute the patch around the handle point
        patch = utils.create_square_mask(
            F.shape[2], F.shape[3], center=handle_points[i].tolist(), radius=r2
        ).to(device)

        # Find indices where the patch is True
        patch_coordinates = torch.nonzero(patch)  # shape [num_points, 2]

        # Extract features in the patch
        F_qi = F[:, :, patch_coordinates[:, 0], patch_coordinates[:, 1]]
        # Extract feature of the initial handle point
        f_i = F0[:, :, handle_points0[i][0].long(), handle_points0[i][1].long()]

        # Compute the L1 distance between the patch features and the initial handle point feature
        distances = torch.norm(F_qi - f_i[:, :, None], p=1, dim=1)
        # 最近邻查找
        # Find the new handle point as the one with minimum distance
        min_index = torch.argmin(distances)
        new_handle_points[i] = patch_coordinates[min_index]

    return new_handle_points

参考文献

    • Drag Your GAN: Interactive Point-based Manipulation on the Generative Image Manifold
    • DragDiffusion: Harnessing Diffusion Models for Interactive Point-based Image Editing
 转:https://zhuanlan.zhihu.com/p/640748168?utm_id=0
posted @   rmticocean  阅读(133)  评论(0编辑  收藏  举报
编辑推荐:
· 基于Microsoft.Extensions.AI核心库实现RAG应用
· Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· Linux系列:如何用 C#调用 C方法造成内存泄露
阅读排行:
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 单元测试从入门到精通
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)
· winform 绘制太阳,地球,月球 运作规律
点击右上角即可分享
微信分享提示