Reference only code reading


This post analyses the reference-only mode in the ControlNet extension of Stable Diffusion WebUI, walking through the code logic.

ControlNet Hook

The entry point inside controlnet_main_entry looks like:

# def controlnet_main_entry():
self.latest_network = UnetHook(lowvram=is_low_vram)
self.latest_network.hook(model=unet, sd_ldm=sd_ldm, control_params=forward_params, process=p,
                         batch_option_uint_separate=batch_option_uint_separate,
                         batch_option_style_align=batch_option_style_align)

What happens inside the hook function?

The core hook part:

def hook(model=unet, sd_ldm=sd_ldm, control_params=forward_params, process=p, ...):
    # ...
    model._original_forward = model.forward
    outer.original_forward = model.forward
    model.forward = forward_webui.__get__(model, UNetModel)
  • model is the original UNet model
  • outer is the newly created UnetHook object itself

So the original UNet model’s forward is hijacked by forward_webui, and the original forward is saved as UnetHook.original_forward.
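The hijack relies on Python’s descriptor protocol: forward_webui.__get__(model, UNetModel) binds the plain function forward_webui to the model instance so it can be assigned as a method. A minimal, self-contained sketch of the mechanism (toy class and values of my own, not the extension’s real code):

class ToyUNet:
    def forward(self, x):
        return x + 1

def forward_webui(self, x):
    # new behaviour that wraps the saved original forward
    return self._original_forward(x) * 2

unet = ToyUNet()
unet._original_forward = unet.forward                 # keep the bound original method
unet.forward = forward_webui.__get__(unet, ToyUNet)   # bind the new function to this instance

print(unet.forward(1))  # (1 + 1) * 2 = 4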

For ordinary ControlNet units, the control modules run inside this hooked forward function; their outputs are collected as total_controlnet_embedding and added to the original results in the U-Net middle and decoder blocks (not in the encoder blocks).
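Adding the control residuals at the middle and decoder blocks follows the usual ControlNet pattern, roughly like the toy sketch below (hypothetical shapes and a stand-in decoder_block of my own, not the extension’s actual forward):

import torch

def decoder_block(x):
    # stand-in for a real U-Net output block (hypothetical)
    return x[:, :4]

hs = [torch.randn(1, 4, 8, 8) for _ in range(3)]        # encoder skip connections
control = [torch.randn(1, 4, 8, 8) for _ in range(4)]   # control residuals, middle-block one last

h = torch.randn(1, 4, 8, 8)                              # middle-block output
h = h + control.pop()                                    # residual added after the middle block

for _ in range(3):                                       # decoder blocks
    h = torch.cat([h, hs.pop() + control.pop()], dim=1)  # residual added onto the skip connection
    h = decoder_block(h)

print(h.shape)  # torch.Size([1, 4, 8, 8])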

Hacks related to reference-only

Hook

For the reference-only part, however, there is no forward pass of control data; instead, the basic transformer blocks are hooked inside the hook function:

# def hook(...):
all_modules = torch_dfs(model)

if need_attention_hijack:
    attn_modules = [module for module in all_modules if isinstance(module, BasicTransformerBlock) or isinstance(module, BasicTransformerBlockSGM)]
    attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])

    for i, module in enumerate(attn_modules):
        if getattr(module, '_original_inner_forward_cn_hijack', None) is None:
            module._original_inner_forward_cn_hijack = module._forward
        module._forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
        module.bank = []
        module.style_cfgs = []
        module.attn_weight = float(i) / float(len(attn_modules))
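Here torch_dfs just flattens the module tree so that every BasicTransformerBlock can be found; a sketch of such a helper (my reconstruction, shown for completeness):

import torch

def torch_dfs(model: torch.nn.Module):
    # return the module itself plus all of its descendants, depth-first
    result = [model]
    for child in model.children():
        result += torch_dfs(child)
    return result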

Hacked basic transformer forward

The basic transformer module is hacked by replacing its inner forward with the function hacked_basic_transformer_inner_forward.

What does this function do?

self_attention_context = x_norm1
if outer.attention_auto_machine == AutoMachine.Write:
    if outer.attention_auto_machine_weight > self.attn_weight:
        self.bank.append(self_attention_context.detach().clone())
        self.style_cfgs.append(outer.current_style_fidelity)
if outer.attention_auto_machine == AutoMachine.Read:
    if len(self.bank) > 0:
        style_cfg = sum(self.style_cfgs) / float(len(self.style_cfgs))
        self_attn1_uc = self.attn1(x_norm1, context=torch.cat([self_attention_context] + self.bank, dim=1))
        self_attn1_c = self_attn1_uc.clone()
        if len(outer.current_uc_indices) > 0 and style_cfg > 1e-5:
            self_attn1_c[outer.current_uc_indices] = self.attn1(
                x_norm1[outer.current_uc_indices],
                context=self_attention_context[outer.current_uc_indices])
        self_attn1 = style_cfg * self_attn1_c + (1.0 - style_cfg) * self_attn1_uc
    self.bank = []
    self.style_cfgs = []

First, note that only the self-attention part is hacked; the cross-attention module is left as-is. There are two states: Read and Write.

  • Write: save x_norm1 into self.bank.
  • Read: concatenate the current self_attention_context with the values in self.bank (the x_norm1 saved during the Write pass) and use the result as the context for attn1. The unconditioned part is handled separately according to the style fidelity (see the shape sketch after this list).
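Concretely, the concatenation happens along the token dimension, so the query still comes from the current latent while the keys/values also cover the reference tokens (toy shapes, purely illustrative):

import torch

# hypothetical shapes: batch 2, 64 tokens, 320 channels
x_norm1 = torch.randn(2, 64, 320)             # current self-attention context
bank = [torch.randn(2, 64, 320)]              # context saved during the Write pass

context = torch.cat([x_norm1] + bank, dim=1)  # tokens are concatenated, not added
print(context.shape)                          # torch.Size([2, 128, 320])
# self.attn1(x_norm1, context=context): queries from x, keys/values also see the reference tokens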

Hacked forward

So what's inside the main hijacked forward function related to reference-only?

ref_xt is obtained from used_hint_cond_latent (the latent of the control image), and then:

  1. Set the auto machine to Write and call outer.original_forward() with x=ref_xt; that is, forward the latent control image so that each hacked block saves its result into self.bank.
  2. Set the auto machine to Read and run the normal UNet process with the real input x.

In other words, the reference context is first saved into self.bank, and then concatenated in as part of the context when attn1 runs during the real forward pass.

As you can see: self_attn1_uc = self.attn1(x_norm1, context=torch.cat([self_attention_context] + self.bank, dim=1))
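Putting the two passes together, the reference-only branch of the hijacked forward follows this pattern; the sketch below is a simplified reconstruction using the names from the excerpts above (the Outer stand-in and its dummy forward are mine, not the extension’s code):

import enum
import torch

class AutoMachine(enum.Enum):
    Read = 'Read'
    Write = 'Write'

class Outer:
    # stand-in for the UnetHook ('outer') object
    attention_auto_machine = AutoMachine.Read

    def original_forward(self, x):
        # stand-in for the real U-Net forward; each hacked block checks
        # outer.attention_auto_machine to decide whether to fill or consume its bank
        return x

outer = Outer()
ref_xt = torch.randn(1, 4, 8, 8)   # noised latent of the control image
x = torch.randn(1, 4, 8, 8)        # the real denoising input

# 1. Write pass: forward the reference latent so every hacked block fills its bank
outer.attention_auto_machine = AutoMachine.Write
outer.original_forward(ref_xt)

# 2. Read pass: normal forward; attn1 now also attends to the banked reference context
outer.attention_auto_machine = AutoMachine.Read
result = outer.original_forward(x)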
