Contents
run_and_display->ptp_utils.text2image_ldm_stable->diffusion_step->controller.step_callback->AttentionControlEdit.local_blend->class LocalBlend
Attention map implementation
- self.step_store[key].append(attn_copy)  # store this step's attention map in a temporary dict
- self.attention_store[key][i] += self.step_store[key][i]  # a running dict accumulates each step's attention map
- attention_store.get_average_attention()  # divide every map in the dict by the total step count, giving the average attention map at each UNet location over all steps
- aggregate the attention maps from the UNet locations: torch.Size([1000, 32, 32, 77]) → [32, 32, 77]
- normalize the tensor, convert it to grayscale, then render it as an image
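The accumulate-then-average pattern above can be sketched in pure Python (hypothetical class name; plain lists of floats stand in for the torch tensors the real AttentionStore keeps, keyed by UNet location such as "down_cross" or "up_self"):

```python
class AttentionStoreSketch:
    """Accumulates per-step attention maps and averages them at the end."""

    def __init__(self):
        self.step_store = {}       # maps collected during the current step
        self.attention_store = {}  # running sum over all completed steps
        self.cur_step = 0

    def store(self, key, attn):
        # called once per attention layer within a diffusion step
        self.step_store.setdefault(key, []).append(attn)

    def between_steps(self):
        # fold the current step's maps into the running sum
        self.cur_step += 1
        if not self.attention_store:
            self.attention_store = self.step_store
        else:
            for key in self.attention_store:
                for i in range(len(self.attention_store[key])):
                    self.attention_store[key][i] = [
                        a + b for a, b in zip(self.attention_store[key][i],
                                              self.step_store[key][i])]
        self.step_store = {}

    def get_average_attention(self):
        # divide the accumulated sums by the number of steps
        return {key: [[v / self.cur_step for v in attn] for attn in maps]
                for key, maps in self.attention_store.items()}
```

With two steps stored for one key, `get_average_attention()` returns the element-wise mean of the two maps.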
Word replacement (replace) implementation
- inds_replace = [i for i in range(len(words_y)) if words_y[i] != words_x[i]]
  finds the positions where the two prompts' words differ
- get_word_inds(x, i, tokenizer):
  gets the token positions of the differing words
- get_replacement_mapper_:
  returns a (77, 77) diagonal matrix; the mapper seems unused in replace
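The word-level diff in the first bullet can be sketched as a small helper (hypothetical function name; this is the list comprehension from the source wrapped for clarity):

```python
def diff_word_inds(x, y):
    """Indices where prompts x and y differ, word by word.

    Assumes the two prompts have the same number of words, which is
    the precondition the replace edit requires.
    """
    words_x, words_y = x.split(" "), y.split(" ")
    assert len(words_x) == len(words_y), "replace needs equal-length prompts"
    return [i for i in range(len(words_y)) if words_y[i] != words_x[i]]

print(diff_word_inds("a photo of a cat", "a photo of a dog"))  # → [4]
```

Note this compares whitespace-split words, not tokenizer tokens; mapping word positions to token positions is what get_word_inds then does.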
Refine implementation
Why images fail to generate
generate the latent
generate the embedding
diffuse through the UNet
decode with the VAE
-
test with only the VAE modified
produces a blurry portrait
-
add latents_input = model.scheduler.scale_model_input(latents_input, t)
output goes straight to gray
-
change how the initial latent is generated
output becomes multicolored noise
run_and_display
run_and_display(prompts, controller, latent=None, run_baseline=False, generator=None)
ptp_utils.text2image_ldm_stable
runs the whole diffusion process
def text2image_ldm_stable(
    model,
    prompt: List[str],
    controller,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    generator: Optional[torch.Generator] = None,
    latent: Optional[torch.FloatTensor] = None,
    low_resource: bool = False,
):
ptp_utils.register_attention_control
register_attention_control(model, controller)
ptp_utils.diffusion_step
performs a single diffusion step
diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource=False)
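Inside diffusion_step, the unconditional and text-conditioned noise predictions are combined via classifier-free guidance before the scheduler step. The arithmetic itself is simple; a pure-Python sketch on scalars (the real code applies the same line to latent tensors):

```python
def classifier_free_guidance(noise_uncond, noise_text, guidance_scale=7.5):
    """Combine unconditional and text-conditioned noise predictions.

    Mirrors the guidance line in ptp_utils.diffusion_step: start from
    the unconditional prediction and push guidance_scale times further
    along the direction toward the text-conditioned prediction.
    """
    return noise_uncond + guidance_scale * (noise_text - noise_uncond)

# guidance_scale = 1.0 would return the text-conditioned prediction
# unchanged; larger scales extrapolate beyond it.
print(classifier_free_guidance(0.0, 1.0, guidance_scale=7.5))  # → 7.5
```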
controller.step_callback
step_callback(self, x_t)
AttentionControlEdit.local_blend
__call__(self, x_t, attention_store)
Understanding cross_replace_steps and self_replace_steps
self.cross_replace_alpha = ptp_utils.get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps, tokenizer).to(device)
if type(self_replace_steps) is float:
    self_replace_steps = 0, self_replace_steps
self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])
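The normalization above can be isolated into a small sketch (hypothetical function name): a bare float f is treated as the pair (0, f), and both fractions are then scaled to step indices, so self-attention maps are replaced for steps in [start, end).

```python
def self_replace_window(num_steps, self_replace_steps):
    """Convert self_replace_steps into a (start, end) step-index window.

    A bare float f means (0, f): replace self-attention from step 0
    up to (but not including) step int(f * num_steps).
    """
    if type(self_replace_steps) is float:
        self_replace_steps = 0, self_replace_steps
    return (int(num_steps * self_replace_steps[0]),
            int(num_steps * self_replace_steps[1]))

print(self_replace_window(50, 0.4))  # → (0, 20)
```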
Understanding ptp_utils.get_time_words_attention_alpha
def get_time_words_attention_alpha(prompts, num_steps,
                                   cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]],
                                   tokenizer, max_num_words=77):
    if type(cross_replace_steps) is not dict:
        cross_replace_steps = {"default_": cross_replace_steps}
    if "default_" not in cross_replace_steps:
        cross_replace_steps["default_"] = (0., 1.)
    alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words)
    for i in range(len(prompts) - 1):
        alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"], i)
    for key, item in cross_replace_steps.items():
        if key != "default_":
            inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))]
            for i, ind in enumerate(inds):
                if len(ind) > 0:
                    alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind)
    alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words)
    return alpha_time_words
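The per-word time schedule that update_alpha_time_word fills in can be sketched in pure Python (hypothetical function name, no torch): for a given word, alpha is 1 during the fraction [start, end) of the diffusion steps and 0 outside it, so cross-attention injection only happens inside that window.

```python
def alpha_schedule(num_steps, bounds):
    """Return a 0/1 mask of length num_steps + 1 for one word.

    bounds is either a float f (treated as (0, f)) or a (start, end)
    pair of fractions of the total step count. Alpha is 1 for steps in
    [start * num_steps, end * num_steps) and 0 elsewhere.
    """
    if isinstance(bounds, float):
        bounds = (0.0, bounds)
    start = int(bounds[0] * num_steps)
    end = int(bounds[1] * num_steps)
    return [1 if start <= t < end else 0 for t in range(num_steps + 1)]

# With cross_replace_steps = 0.8 and 10 steps, attention is injected
# for the first 8 steps only.
print(alpha_schedule(10, 0.8))  # → [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]
```

In the real function this mask has an extra word dimension: the "default_" entry sets the schedule for every token, and per-word dict entries then overwrite the columns returned by get_word_inds.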