Call chain: run_and_display -> ptp_utils.text2image_ldm_stable -> diffusion_step -> controller.step_callback -> AttentionControlEdit.local_blend -> class LocalBlend

How the attention maps are implemented

  1. self.step_store[key].append(attn_copy)  # stash each step's attention maps in a temporary dict
  2. self.attention_store[key][i] += self.step_store[key][i]  # a running dict accumulates the current step's maps
  3. attention_store.get_average_attention()  # divide every map in the dict by the total number of steps, giving the average attention map at each UNet location over all steps
  4. average the maps at each UNet location over the first dimension: torch.Size([1000, 32, 32, 77]) -> [32, 32, 77]
  5. normalize the resulting tensor, convert it to grayscale, then to an image (sketch below)
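
A minimal sketch of steps 3–5, assuming the dict layout above; aggregate_maps and token_map_to_image are my names for hypothetical helpers, not the repo's ("up_cross" follows the store's key naming):

import numpy as np
import torch
from PIL import Image

def get_average_attention(attention_store, cur_step):
    # step 3: divide each accumulated map by the number of completed steps
    return {key: [item / cur_step for item in maps]
            for key, maps in attention_store.items()}

def aggregate_maps(avg_attention, key="up_cross"):
    # step 4: stack maps shaped [heads*batch, 32, 32, 77] and average
    # over the first dimension -> [32, 32, 77]
    maps = torch.cat(avg_attention[key], dim=0)
    return maps.mean(dim=0)

def token_map_to_image(maps, token_idx):
    # step 5: min-max normalize one token's 32x32 map and turn it into
    # a grayscale image
    m = maps[:, :, token_idx]
    m = 255 * (m - m.min()) / (m.max() - m.min() + 1e-8)
    return Image.fromarray(m.cpu().numpy().astype(np.uint8), mode="L")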

How word replacement (replace) is implemented

  1. inds_replace = [i for i in range(len(words_y)) if words_y[i] != words_x[i]]
    the positions of the words that differ between the two prompts
  2. get_word_inds(x, i, tokenizer):
    gets the token indices of those differing words
  3. get_replacement_mapper_:
    returns a (77, 77) matrix that is diagonal away from the replaced positions
  4. in the replace path the mapper appears to go unused (see the sketch after this list)
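
A simplified sketch of the get_replacement_mapper_ idea, assuming each replaced word tokenizes to the same number of tokens in both prompts (the real version also redistributes weight when the token counts differ); get_word_inds is the repo helper mentioned above, while get_replacement_mapper_simple is my name for this sketch:

import torch

def get_replacement_mapper_simple(x, y, tokenizer, max_len=77):
    words_x, words_y = x.split(" "), y.split(" ")
    inds_replace = [i for i in range(len(words_y)) if words_y[i] != words_x[i]]
    mapper = torch.eye(max_len)  # identity: unchanged tokens keep their own attention
    for i in inds_replace:
        src = get_word_inds(x, i, tokenizer)  # token indices of the word in prompt x
        tgt = get_word_inds(y, i, tokenizer)  # token indices of the word in prompt y
        for s in src:
            mapper[s, s] = 0.0
        for s, t in zip(src, tgt):
            mapper[s, t] = 1.0  # route the source token's attention to the target token
    return mapper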

How refine is implemented

Why no image gets generated

generate the initial latent
generate the text embeddings
run the diffusion loop through the UNet
decode with the VAE (a sketch of these four stages follows)
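
A minimal end-to-end sketch of those four stages with a recent diffusers StableDiffusionPipeline (device placement, dtype and the safety checker are omitted); note how stage 3 involves both scale_model_input and the initial-latent scaling that the experiments below poke at:

import torch
from diffusers import StableDiffusionPipeline

model = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
prompt, num_steps, guidance_scale = ["a photo of a cat"], 50, 7.5

# 1. initial latent: Gaussian noise in the VAE's latent space
latents = torch.randn(1, model.unet.config.in_channels, 64, 64)

# 2. text embeddings: unconditional + conditional, for classifier-free guidance
tok = model.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
utok = model.tokenizer([""], padding="max_length", max_length=77, return_tensors="pt")
context = torch.cat([model.text_encoder(utok.input_ids)[0],
                     model.text_encoder(tok.input_ids)[0]])

# 3. diffusion loop through the UNet
model.scheduler.set_timesteps(num_steps)
latents = latents * model.scheduler.init_noise_sigma  # initial-latent scaling (cf. experiment 3)
for t in model.scheduler.timesteps:
    latent_in = model.scheduler.scale_model_input(torch.cat([latents] * 2), t)  # cf. experiment 2
    noise_uncond, noise_text = model.unet(latent_in, t, encoder_hidden_states=context).sample.chunk(2)
    noise = noise_uncond + guidance_scale * (noise_text - noise_uncond)
    latents = model.scheduler.step(noise, t, latents).prev_sample

# 4. VAE decode: latents -> pixel image in [0, 1]
image = model.vae.decode(latents / model.vae.config.scaling_factor).sample
image = (image / 2 + 0.5).clamp(0, 1)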

  1. Test with only the VAE swapped in
    it can generate a blurry portrait
    (image)

  2. Add latents_input = model.scheduler.scale_model_input(latents_input, t)
    the output turns plain gray
    (image)

  3. Change how the initial latent is generated
    the output becomes multicolored noise
    (image)

run_and_display

run_and_display(prompts, controller, latent=None, run_baseline=False, generator=None)
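
In the notebook this is roughly a thin wrapper over text2image_ldm_stable plus image display; the sketch below assumes the notebook's globals ldm_stable and EmptyControl, so treat it as approximate:

def run_and_display(prompts, controller, latent=None, run_baseline=False, generator=None):
    if run_baseline:
        # first show what the same latent produces without prompt-to-prompt editing
        images, latent = run_and_display(prompts, EmptyControl(), latent=latent,
                                         run_baseline=False, generator=generator)
    images, x_t = ptp_utils.text2image_ldm_stable(ldm_stable, prompts, controller,
                                                  latent=latent, generator=generator)
    ptp_utils.view_images(images)
    return images, x_t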

ptp_utils.text2image_ldm_stable

Runs the whole diffusion process.

def text2image_ldm_stable(
    model,                                        # Stable Diffusion pipeline (unet / vae / text_encoder / scheduler)
    prompt: List[str],
    controller,                                   # AttentionControl object that intercepts every attention map
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    generator: Optional[torch.Generator] = None,  # RNG for the initial latent
    latent: Optional[torch.FloatTensor] = None,   # reuse a latent to reproduce / edit an image
    low_resource: bool = False,
):

ptp_utils.register_attention_control

register_attention_control(model, controller)
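
The mechanism, condensed from ptp_utils: every attention module's forward is replaced by a wrapper that is standard attention except that the softmaxed map passes through the controller, which may record it (AttentionStore) or rewrite it (AttentionControlEdit). Attribute names follow the diffusers CrossAttention version the repo targets; controller is passed explicitly here for clarity, where the original captures it by closure:

import torch

def ca_forward(self, place_in_unet, controller):
    def forward(x, context=None, mask=None):
        is_cross = context is not None
        ctx = context if is_cross else x
        q = self.reshape_heads_to_batch_dim(self.to_q(x))
        k = self.reshape_heads_to_batch_dim(self.to_k(ctx))
        v = self.reshape_heads_to_batch_dim(self.to_v(ctx))
        sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
        attn = sim.softmax(dim=-1)
        attn = controller(attn, is_cross, place_in_unet)  # the hook point
        out = torch.einsum("b i j, b j d -> b i d", attn, v)
        return self.to_out(self.reshape_batch_dim_to_heads(out))
    return forward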

ptp_utils.diffusion_step

A single diffusion step

diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource=False)
controller.step_callback
step_callback(self, x_t)
AttentionControlEdit.local_blend
__call__(self, x_t, attention_store)
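
Condensed from ptp_utils (the low_resource branch, which runs the two UNet passes separately, is omitted): one classifier-free-guidance step, then the controller's latent-level callback, which is where LocalBlend rewrites x_t:

def diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource=False):
    latents_input = torch.cat([latents] * 2)  # uncond + text batches share one UNet pass
    noise_pred = model.unet(latents_input, t, encoder_hidden_states=context)["sample"]
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    latents = model.scheduler.step(noise_pred, t, latents)["prev_sample"]
    latents = controller.step_callback(latents)  # -> AttentionControlEdit -> LocalBlend.__call__
    return latents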

Understanding cross_replace_steps and self_replace_steps

self.cross_replace_alpha = ptp_utils.get_time_words_attention_alpha(
    prompts, num_steps, cross_replace_steps, tokenizer).to(device)
if type(self_replace_steps) is float:
    self_replace_steps = 0, self_replace_steps
self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])
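
Concretely, a float self_replace_steps is promoted to the interval (0, self_replace_steps) and then scaled by num_steps into a step range; for example:

num_steps = 50
self_replace_steps = 0.4                                         # float -> interval (0, 0.4)
num_self_replace = (int(num_steps * 0.0), int(num_steps * 0.4))  # (0, 20)
# self-attention maps are swapped only while
# num_self_replace[0] <= cur_step < num_self_replace[1],
# i.e. during the first 20 of the 50 diffusion steps here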

Understanding ptp_utils.get_time_words_attention_alpha

import torch
from typing import Dict, Tuple, Union

def get_time_words_attention_alpha(prompts, num_steps,
                                   cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]],
                                   tokenizer, max_num_words=77):
    # normalize the argument: a bare float means "apply to all words"
    if type(cross_replace_steps) is not dict:
        cross_replace_steps = {"default_": cross_replace_steps}
    if "default_" not in cross_replace_steps:
        cross_replace_steps["default_"] = (0., 1.)
    # one alpha per (step, edited prompt, token): 1 = inject source attention
    alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words)
    for i in range(len(prompts) - 1):
        alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"], i)
    # per-word overrides: find each keyword's token indices in every target
    # prompt and apply that word's own (start, end) window
    for key, item in cross_replace_steps.items():
        if key != "default_":
            inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))]
            for i, ind in enumerate(inds):
                if len(ind) > 0:
                    alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind)
    alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words)
    return alpha_time_words
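
For reference, the helper it calls looks roughly like this (same file): it zeroes a word's alpha outside the (start, end) window and sets it to 1 inside, so cross-attention injection for that word only happens during that fraction of the steps:

def update_alpha_time_word(alpha, bounds, prompt_ind, word_inds=None):
    # alpha: (num_steps + 1, len(prompts) - 1, max_num_words)
    if type(bounds) is float:
        bounds = 0, bounds
    start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0])
    if word_inds is None:
        word_inds = torch.arange(alpha.shape[2])  # default: every token
    alpha[:start, prompt_ind, word_inds] = 0
    alpha[start:end, prompt_ind, word_inds] = 1  # inject cross-attention in this window
    alpha[end:, prompt_ind, word_inds] = 0
    return alpha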