Stable-Diffusion-webUI Code Reading 01: Starting from Launch
Because of my internship work, I decided to spend a few days reading through the stable-diffusion-webui code.
This article follows a Zhihu column and adds some understanding of my own; many thanks to the original author! Zhihu column: 自动做游戏:AI技术落地于游戏开发 - 知乎 (zhihu.com)
My recent work focuses on applying the OneFlow framework to accelerate SD and on supporting and adapting different samplers, so I read the code with that as my entry point.
Since I am just getting started, a lot of this is rough; corrections and criticism are very welcome!
The webui itself moves quickly; this article is based on AUTOMATIC1111's 1.4.1, at https://github.com/AUTOMATIC1111/stable-diffusion-webui.
webui, launch! (webui.sh)
Starting stable-diffusion:
bash webui.sh -f
The -f flag is there because SD refuses to start as the root user, while the server my company rents happens to be configured with root privileges. Adding -f solves this painlessly, and as an intern I would rather not fiddle with the server configuration.
Detecting the system and environment
# If run from macOS, load defaults from webui-macos-env.sh
if [[ "$OSTYPE" == "darwin"* ]]; then
if [[ -f webui-macos-env.sh ]]
then
source ./webui-macos-env.sh
fi
fi
This checks whether the current system is macOS; if it is, the script looks for webui-macos-env.sh and sources it to load the macOS defaults.
After that, on every platform, the script continues with:
if [[ -f webui-user.sh ]]
then
source ./webui-user.sh
fi
i.e. it checks for webui-user.sh and sources it if present, so the user's own settings are applied on top of the defaults.
webui-macos-env.sh
#!/bin/bash
####################################################################
# macOS defaults #
# Please modify webui-user.sh to change these instead of this file #
####################################################################
if [[ -x "$(command -v python3.10)" ]]
then
python_cmd="python3.10"
fi
export install_dir="$HOME"
export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate"
export TORCH_COMMAND="pip install torch==2.0.1 torchvision==0.15.2"
export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git"
export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71"
export PYTORCH_ENABLE_MPS_FALLBACK=1
####################################################################
Note the comment block at the top: these values should be changed in webui-user.sh, not in this file.
The default variables configured here include:
- The python environment: python_cmd is set to python3.10 if that executable is available.
- The default install_dir, which points to $HOME.
- COMMANDLINE_ARGS, extra command-line arguments:
  - --skip-torch-cuda-test: skip the check that torch can access CUDA; with an NVIDIA GPU this flag is unnecessary.
  - --upcast-sampling: upcast sampling. It usually has the same effect as --no-half (which keeps the model from being cast to 16-bit floats) but works better while using less memory.
  - --no-half-vae: do not cast the VAE model to 16-bit floats.
  - --use-cpu interrogate: run the interrogate module on the CPU.
- TORCH_COMMAND: which torch/torchvision versions to install.
- K_DIFFUSION:
  - the repository URL
  - the hash of the specific pinned commit
- PYTORCH_ENABLE_MPS_FALLBACK=1: enable PyTorch's MPS fallback. Here MPS is Apple's Metal backend (Metal Performance Shaders), not NVIDIA's Multi-Process Service: with the fallback enabled, operators that have no MPS implementation are run on the CPU instead of raising an error (see the sketch below).
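A minimal sketch of what that fallback flag changes, assuming an Apple Silicon machine with a recent PyTorch build (illustration only, not webui code):
# Illustration only: how PYTORCH_ENABLE_MPS_FALLBACK behaves on Apple Silicon.
import os
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")  # set before importing torch to be safe
import torch

device = "mps" if torch.backends.mps.is_available() else "cpu"
x = torch.randn(2, 3, device=device)
# Operators without an MPS kernel normally raise NotImplementedError;
# with the fallback enabled, PyTorch silently runs them on the CPU instead.
print(x.device)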
webui-user.sh
#!/bin/bash
#########################################################
# Uncomment and change the variables below to your need:#
#########################################################
# Install directory without trailing slash
#install_dir="/home/$(whoami)"
# Name of the subdirectory
#clone_dir="stable-diffusion-webui"
# Commandline arguments for webui.py, for example: export COMMANDLINE_ARGS="--medvram --opt-split-attention"
#export COMMANDLINE_ARGS=""
# python3 executable
#python_cmd="python3"
# git executable
#export GIT="git"
# python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv)
#venv_dir="venv"
# script to launch to start the app
#export LAUNCH_SCRIPT="launch.py"
# install command for torch
#export TORCH_COMMAND="pip install torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113"
# Requirements file to use for stable-diffusion-webui
#export REQS_FILE="requirements_versions.txt"
# Fixed git repos
#export K_DIFFUSION_PACKAGE=""
#export GFPGAN_PACKAGE=""
# Fixed git commits
#export STABLE_DIFFUSION_COMMIT_HASH=""
#export CODEFORMER_COMMIT_HASH=""
#export BLIP_COMMIT_HASH=""
# Uncomment to enable accelerated launch
#export ACCELERATE="True"
# Uncomment to disable TCMalloc
#export NO_TCMALLOC="True"
###########################################
As you can see, everything here is commented out. This is where the user sets parameters: the python command, the git command, the venv location, and so on.
Default variable values
# Set defaults
# Install directory without trailing slash
if [[ -z "${install_dir}" ]]
then
install_dir="$(pwd)"
fi
# Name of the subdirectory (defaults to stable-diffusion-webui)
if [[ -z "${clone_dir}" ]]
then
clone_dir="stable-diffusion-webui"
fi
# python3 executable
if [[ -z "${python_cmd}" ]]
then
python_cmd="python3"
fi
# git executable
if [[ -z "${GIT}" ]]
then
export GIT="git"
fi
# python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv)
if [[ -z "${venv_dir}" ]]
then
venv_dir="venv"
fi
if [[ -z "${LAUNCH_SCRIPT}" ]]
then
LAUNCH_SCRIPT="launch.py"
fi
These set defaults for the install directory, the subdirectory name, the python command, the git command, the venv location, and the launch script.
The content largely mirrors webui-user.sh; for a detailed explanation of each startup parameter, see this Bilibili article: 最详细的WEBUI启动参数 - 哔哩哔哩 (bilibili.com)
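As a side note, COMMANDLINE_ARGS set in webui-user.sh is not parsed by the shell script itself; it is read later on the Python side. Roughly, the pattern looks like the sketch below (the real plumbing lives in modules/launch_utils.py and modules/cmd_args.py and may differ in detail):
# Rough sketch of the idea: merge COMMANDLINE_ARGS into sys.argv so that flags set in
# webui-user.sh behave exactly like flags typed on the command line.
import os
import shlex
import sys

sys.argv += shlex.split(os.environ.get("COMMANDLINE_ARGS", ""))
print(sys.argv)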
Root-permission check
# this script cannot be run as root by default
can_run_as_root=0
# read any command line flags to the webui.sh script
while getopts "f" flag > /dev/null 2>&1
do
case ${flag} in
f) can_run_as_root=1;;
*) break;;
esac
done
# Disable sentry logging
export ERROR_REPORTING=FALSE
# Do not reinstall existing pip packages on Debian/Ubuntu
export PIP_IGNORE_INSTALLED=0
# Pretty print
delimiter="################################################################"
printf "\n%s\n" "${delimiter}"
printf "\e[1m\e[32mInstall script for stable-diffusion + Web UI\n"
printf "\e[1m\e[34mTested on Debian 11 (Bullseye)\e[0m"
printf "\n%s\n" "${delimiter}"
# Do not run as root
if [[ $(id -u) -eq 0 && can_run_as_root -eq 0 ]]
then
printf "\n%s\n" "${delimiter}"
printf "\e[1m\e[31mERROR: This script must not be launched as root, aborting...\e[0m"
printf "\n%s\n" "${delimiter}"
exit 1
else
printf "\n%s\n" "${delimiter}"
printf "Running on \e[1m\e[32m%s\e[0m user" "$(whoami)"
printf "\n%s\n" "${delimiter}"
fi
if [[ $(getconf LONG_BIT) = 32 ]]
then
printf "\n%s\n" "${delimiter}"
printf "\e[1m\e[31mERROR: Unsupported Running on a 32bit OS\e[0m"
printf "\n%s\n" "${delimiter}"
exit 1
fi
if [[ -d .git ]]
then
printf "\n%s\n" "${delimiter}"
printf "Repo already cloned, using it as install directory"
printf "\n%s\n" "${delimiter}"
install_dir="${PWD}/../"
clone_dir="${PWD##*/}"
fi
So adding -f is all that is needed: getopts parses the -f flag, the script reads the user id to determine whether it is running as root, and the two together decide whether the program may continue.
Besides the root-user check, 32-bit systems are also rejected, and if a .git directory is present the current checkout is reused as the install directory.
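For readers more comfortable with Python, here is a rough equivalent of those two guards (illustration only; the script does this in bash):
# Rough Python equivalents of the root-user and 32-bit checks above.
import os
import platform

if os.geteuid() == 0:
    raise SystemExit("This script must not be launched as root (webui.sh accepts -f to override).")

if platform.architecture()[0] != "64bit":
    raise SystemExit("Running on a 32-bit OS is unsupported.")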
Getting GPU info, and checking git and python
# Check prerequisites
gpu_info=$(lspci 2>/dev/null | grep -E "VGA|Display")
case "$gpu_info" in
*"Navi 1"*)
export HSA_OVERRIDE_GFX_VERSION=10.3.0
if [[ -z "${TORCH_COMMAND}" ]]
then
pyv="$(${python_cmd} -c 'import sys; print(".".join(map(str, sys.version_info[0:2])))')"
if [[ $(bc <<< "$pyv <= 3.10") -eq 1 ]]
then
# Navi users will still use torch 1.13 because 2.0 does not seem to work.
export TORCH_COMMAND="pip install torch==1.13.1+rocm5.2 torchvision==0.14.1+rocm5.2 --index-url https://download.pytorch.org/whl/rocm5.2"
else
printf "\e[1m\e[31mERROR: RX 5000 series GPUs must be using at max python 3.10, aborting...\e[0m"
exit 1
fi
fi
;;
*"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0
;;
*"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0
printf "\n%s\n" "${delimiter}"
printf "Experimental support for Renoir: make sure to have at least 4GB of VRAM and 10GB of RAM or enable cpu mode: --use-cpu all --no-half"
printf "\n%s\n" "${delimiter}"
;;
*)
;;
esac
if ! echo "$gpu_info" | grep -q "NVIDIA";
then
if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]]
then
export TORCH_COMMAND="pip install torch==2.0.1+rocm5.4.2 torchvision==0.15.2+rocm5.4.2 --index-url https://download.pytorch.org/whl/rocm5.4.2"
fi
fi
for preq in "${GIT}" "${python_cmd}"
do
if ! hash "${preq}" &>/dev/null
then
printf "\n%s\n" "${delimiter}"
printf "\e[1m\e[31mERROR: %s is not installed, aborting...\e[0m" "${preq}"
printf "\n%s\n" "${delimiter}"
exit 1
fi
done
This part detects the GPU from lspci output: for AMD cards (Navi 1/2, Renoir) it sets HSA_OVERRIDE_GFX_VERSION and, if TORCH_COMMAND is not already set, a ROCm-specific torch install command; at the end it verifies that git and python are actually installed. It is all architecture-level environment setup, so I will only skim it.
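The same idea expressed in Python, as a sketch (the package versions below merely mirror the bash case statement and are not a maintained list):
# Sketch: read lspci output and pick a torch install command depending on the GPU vendor.
import subprocess

lspci = subprocess.run(["lspci"], capture_output=True, text=True).stdout
gpu_info = "\n".join(l for l in lspci.splitlines() if "VGA" in l or "Display" in l)

if "NVIDIA" in gpu_info:
    torch_command = "pip install torch torchvision"  # default CUDA wheels
elif "AMD" in gpu_info:
    torch_command = ("pip install torch==2.0.1+rocm5.4.2 torchvision==0.15.2+rocm5.4.2 "
                     "--index-url https://download.pytorch.org/whl/rocm5.4.2")
else:
    torch_command = "pip install torch torchvision"  # CPU / unknown GPU fallback
print(torch_command)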
Creating or activating the Python venv
if ! "${python_cmd}" -c "import venv" &>/dev/null
then
printf "\n%s\n" "${delimiter}"
printf "\e[1m\e[31mERROR: python3-venv is not installed, aborting...\e[0m"
printf "\n%s\n" "${delimiter}"
exit 1
fi
cd "${install_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/, aborting...\e[0m" "${install_dir}"; exit 1; }
if [[ -d "${clone_dir}" ]]
then
cd "${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; }
else
printf "\n%s\n" "${delimiter}"
printf "Clone stable-diffusion-webui"
printf "\n%s\n" "${delimiter}"
"${GIT}" clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git "${clone_dir}"
cd "${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; }
fi
if [[ -z "${VIRTUAL_ENV}" ]];
then
printf "\n%s\n" "${delimiter}"
printf "Create and activate python venv"
printf "\n%s\n" "${delimiter}"
cd "${install_dir}"/"${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; }
if [[ ! -d "${venv_dir}" ]]
then
"${python_cmd}" -m venv "${venv_dir}"
first_launch=1
fi
# shellcheck source=/dev/null
if [[ -f "${venv_dir}"/bin/activate ]]
then
source "${venv_dir}"/bin/activate
else
printf "\n%s\n" "${delimiter}"
printf "\e[1m\e[31mERROR: Cannot activate python venv, aborting...\e[0m"
printf "\n%s\n" "${delimiter}"
exit 1
fi
else
printf "\n%s\n" "${delimiter}"
printf "python venv already activate: ${VIRTUAL_ENV}"
printf "\n%s\n" "${delimiter}"
fi
This mainly checks whether a Python virtual environment already exists; if not, one is created. As you can see, the packages end up under ./venv by default.
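The same thing can be done with Python's standard venv module; a small sketch (not webui code):
# Create ./venv if it does not exist yet -- the equivalent of `python3 -m venv venv`.
import venv
from pathlib import Path

venv_dir = Path("venv")
if not venv_dir.exists():
    venv.EnvBuilder(with_pip=True).create(venv_dir)
# Activation in bash is `source venv/bin/activate`; from outside bash you would simply
# run the venv's own interpreter, e.g. venv/bin/python launch.py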
Configuring TCMalloc on Linux
# Try using TCMalloc on Linux
prepare_tcmalloc() {
if [[ "${OSTYPE}" == "linux"* ]] && [[ -z "${NO_TCMALLOC}" ]] && [[ -z "${LD_PRELOAD}" ]]; then
TCMALLOC="$(PATH=/usr/sbin:$PATH ldconfig -p | grep -Po "libtcmalloc(_minimal|)\.so\.\d" | head -n 1)"
if [[ ! -z "${TCMALLOC}" ]]; then
echo "Using TCMalloc: ${TCMALLOC}"
export LD_PRELOAD="${TCMALLOC}"
else
printf "\e[1m\e[31mCannot locate TCMalloc (improves CPU memory usage)\e[0m\n"
fi
fi
}
On Linux, if neither NO_TCMALLOC nor LD_PRELOAD is set, the script searches the libraries known to ldconfig for a libtcmalloc shared object and assigns its name to LD_PRELOAD, so that the library is preloaded into the Python process launched afterwards.
TCMalloc is a memory allocator used on Linux to improve CPU memory usage.
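A quick way to confirm from inside the running Python process that the preload actually took effect (Linux only; a throwaway check, not part of webui):
# Check whether a libtcmalloc shared object is mapped into the current process.
from pathlib import Path

maps = Path("/proc/self/maps").read_text()
print("tcmalloc loaded:", "libtcmalloc" in maps)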
Checking for accelerate
KEEP_GOING=1
export SD_WEBUI_RESTART=tmp/restart
while [[ "$KEEP_GOING" -eq "1" ]]; do
if [[ ! -z "${ACCELERATE}" ]] && [ ${ACCELERATE}="True" ] && [ -x "$(command -v accelerate)" ]; then
printf "\n%s\n" "${delimiter}"
printf "Accelerating launch.py..."
printf "\n%s\n" "${delimiter}"
prepare_tcmalloc
accelerate launch --num_cpu_threads_per_process=6 "${LAUNCH_SCRIPT}" "$@"
else
printf "\n%s\n" "${delimiter}"
printf "Launching launch.py..."
printf "\n%s\n" "${delimiter}"
prepare_tcmalloc
"${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
fi
if [[ ! -f tmp/restart ]]; then
KEEP_GOING=0
fi
done
accelerate is a library released by Hugging Face for distributed execution and is used here to speed up PyTorch: if ACCELERATE is set to True and the accelerate command is available, launch.py is started through accelerate launch; otherwise it is started with plain python. The enclosing while loop keeps relaunching launch.py as long as the tmp/restart file exists, which is how the "restart UI" feature works.
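The loop itself is essentially a tiny supervisor. The same idea in Python, purely as an illustration (the real loop is the bash above):
# Relaunch launch.py until no restart has been requested via the tmp/restart sentinel file.
import os
import subprocess
import sys

os.environ["SD_WEBUI_RESTART"] = "tmp/restart"
while True:
    subprocess.run([sys.executable, "launch.py", *sys.argv[1:]])
    if not os.path.exists("tmp/restart"):
        break  # no restart requested -> exit for good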
launch.py
# launch.py
from modules import launch_utils
args = launch_utils.args
python = launch_utils.python
git = launch_utils.git
index_url = launch_utils.index_url
dir_repos = launch_utils.dir_repos
commit_hash = launch_utils.commit_hash
git_tag = launch_utils.git_tag
run = launch_utils.run
is_installed = launch_utils.is_installed
repo_dir = launch_utils.repo_dir
run_pip = launch_utils.run_pip
check_run_python = launch_utils.check_run_python
git_clone = launch_utils.git_clone
git_pull_recursive = launch_utils.git_pull_recursive
run_extension_installer = launch_utils.run_extension_installer
prepare_environment = launch_utils.prepare_environment
configure_for_tests = launch_utils.configure_for_tests
start = launch_utils.start
def main():
if not args.skip_prepare_environment:
prepare_environment()
if args.test_server:
configure_for_tests()
start()
if __name__ == "__main__":
main()
Finally we reach more familiar Python. The LAUNCH_SCRIPT mentioned earlier defaults to this file. The main() function does three things:
- prepare the environment
- configure for tests (only when args.test_server is set)
- start
For the environment preparation and test configuration, launch_utils.py is the place to look; since my focus is on the model-inference side, I will skip the details and only give a conceptual sketch below.
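To give a feel for what prepare_environment() is responsible for, here is a purely conceptual sketch using only the standard library (the helper name and the real logic in launch_utils.py differ):
# Conceptual sketch: check that a required package is importable, install it if not.
import importlib.util
import subprocess
import sys

def is_installed(package: str) -> bool:
    return importlib.util.find_spec(package) is not None

if not is_installed("torch"):
    # In the real script, the install command comes from TORCH_COMMAND set in webui.sh / webui-user.sh.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "torch", "torchvision"])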
start
# modules/launch_utils.py
def start():
print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}")
import webui
if '--nowebui' in sys.argv:
webui.api_only()
else:
webui.webui()
This is modules/launch_utils.py, i.e. the start() that launch.py just called. As you can see, there are two launch modes:
- API-only launch (--nowebui)
- WebUI launch
API-only launch
# webui.py
def api_only():
initialize()
app = FastAPI()
setup_middleware(app)
api = create_api(app)
modules.script_callbacks.app_started_callback(None, app)
print(f"Startup time: {startup_timer.summary()}.")
api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
In initialize(), the models and other basic parameters and strategies are set up; I will not go through them one by one. After that a FastAPI app is created, middleware is installed, the API routes are registered via create_api(), and the server is launched on 127.0.0.1:7861 by default (or according to cmd_opts.listen and cmd_opts.port). The logic is fairly simple, so I will not dig deeper here.
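Once the API server is up, it can be exercised with any HTTP client. A minimal client sketch, assuming the server listens on 127.0.0.1:7861 and the requests package is installed (the /sdapi/v1/txt2img endpoint is among those registered by create_api):
# Minimal API client sketch: request one image and save it to disk.
import base64
import requests

payload = {"prompt": "a photo of a cat", "steps": 20}
resp = requests.post("http://127.0.0.1:7861/sdapi/v1/txt2img", json=payload, timeout=600)
resp.raise_for_status()
with open("out.png", "wb") as f:
    f.write(base64.b64decode(resp.json()["images"][0]))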
WebUI launch
Gradio
The SD WebUI is built on the open-source project Gradio; here is a quick-start tutorial: Gradio:轻松实现AI算法可视化部署 - 知乎 (zhihu.com)
Note that Gradio's default port is 7860.
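For readers new to Gradio, a self-contained toy demo showing the pattern that modules/ui.py builds on (gr.Blocks, components, event handlers, launch):
# Toy Gradio demo, unrelated to SD itself, just to show the Blocks pattern used by create_ui().
import gradio as gr

def echo(prompt: str) -> str:
    return f"you typed: {prompt}"

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    output = gr.Textbox(label="Output")
    run = gr.Button("Generate")
    run.click(fn=echo, inputs=prompt, outputs=output)

demo.launch(server_port=7860)  # 7860 is also Gradio's default port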
Building the webui window
def webui():
launch_api = cmd_opts.api
initialize()
while 1:
if shared.opts.clean_temp_dir_at_start:
ui_tempdir.cleanup_tmpdr()
startup_timer.record("cleanup temp dir")
modules.script_callbacks.before_ui_callback()
startup_timer.record("scripts before_ui_callback")
shared.demo = modules.ui.create_ui()
startup_timer.record("create ui")
if not cmd_opts.no_gradio_queue:
shared.demo.queue(64)
gradio_auth_creds = list(get_gradio_auth_creds()) or None
app, local_url, share_url = shared.demo.launch(
share=cmd_opts.share,
server_name=server_name,
server_port=cmd_opts.port,
ssl_keyfile=cmd_opts.tls_keyfile,
ssl_certfile=cmd_opts.tls_certfile,
ssl_verify=cmd_opts.disable_tls_verify,
debug=cmd_opts.gradio_debug,
auth=gradio_auth_creds,
inbrowser=cmd_opts.autolaunch and os.getenv('SD_WEBUI_RESTARTING ') != '1',
prevent_thread_lock=True,
allowed_paths=cmd_opts.gradio_allowed_path,
app_kwargs={
"docs_url": "/docs",
"redoc_url": "/redoc",
},
)
if cmd_opts.add_stop_route:
app.add_route("/_stop", stop_route, methods=["POST"])
# after initial launch, disable --autolaunch for subsequent restarts
cmd_opts.autolaunch = False
startup_timer.record("gradio launch")
# gradio uses a very open CORS policy via app.user_middleware, which makes it possible for
# an attacker to trick the user into opening a malicious HTML page, which makes a request to the
# running web ui and do whatever the attacker wants, including installing an extension and
# running its code. We disable this here. Suggested by RyotaK.
app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware']
setup_middleware(app)
modules.progress.setup_progress_api(app)
modules.ui.setup_ui_api(app)
if launch_api:
create_api(app)
ui_extra_networks.add_pages_to_demo(app)
startup_timer.record("add APIs")
with startup_timer.subcategory("app_started_callback"):
modules.script_callbacks.app_started_callback(shared.demo, app)
timer.startup_record = startup_timer.dump()
print(f"Startup time: {startup_timer.summary()}.")
if cmd_opts.subpath:
redirector = FastAPI()
redirector.get("/")
gradio.mount_gradio_app(redirector, shared.demo, path=f"/{cmd_opts.subpath}")
try:
while True:
server_command = shared.state.wait_for_server_command(timeout=5)
if server_command:
if server_command in ("stop", "restart"):
break
else:
print(f"Unknown server command: {server_command}")
except KeyboardInterrupt:
print('Caught KeyboardInterrupt, stopping...')
server_command = "stop"
if server_command == "stop":
print("Stopping server...")
# If we catch a keyboard interrupt, we want to stop the server and exit.
shared.demo.close()
break
print('Restarting UI...')
shared.demo.close()
time.sleep(0.5)
startup_timer.reset()
modules.script_callbacks.app_reload_callback()
startup_timer.record("app reload callback")
modules.script_callbacks.script_unloaded_callback()
startup_timer.record("scripts unloaded callback")
initialize_rest(reload_script_modules=True)
The create_ui method (together with launch) is what builds the UI. Let's take a quick look at how the UI maps onto the code: each with gr.Blocks block in modules/ui.py corresponds to one tab of the WebUI, such as txt2img, img2img, and so on.
Let's open one of these blocks, for example txt2img:
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False)
txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False)
with FormRow(variant='compact', elem_id="txt2img_extra_networks", visible=False) as extra_networks:
from modules import ui_extra_networks
extra_networks_ui = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'txt2img')
with gr.Row().style(equal_height=False):
with gr.Column(variant='compact', elem_id="txt2img_settings"):
modules.scripts.scripts_txt2img.prepare_ui()
for category in ordered_ui_categories():
if category == "sampler":
steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img")
elif category == "dimensions":
with FormRow():
with gr.Column(elem_id="txt2img_column_size", scale=4):
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
with gr.Column(elem_id="txt2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn", label="Switch dims")
if opts.dimensions_and_batch_together:
with gr.Column(elem_id="txt2img_column_batch"):
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
elif category == "cfg":
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale")
elif category == "seed":
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img')
elif category == "checkboxes":
with FormRow(elem_classes="checkboxes-row", variant="compact"):
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces")
tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False)
elif category == "hires_fix":
with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options:
with FormRow(elem_id="txt2img_hires_fix_row1", variant="compact"):
hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps")
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
with FormRow(elem_id="txt2img_hires_fix_row2", variant="compact"):
hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x")
hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container:
hr_sampler_index = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + [x.name for x in samplers_for_img2img], value="Use same sampler", type="index")
with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
with gr.Column(scale=80):
with gr.Row():
hr_prompt = gr.Textbox(label="Hires prompt", elem_id="hires_prompt", show_label=False, lines=3, placeholder="Prompt for hires fix pass.\nLeave empty to use the same prompt as in first pass.", elem_classes=["prompt"])
with gr.Column(scale=80):
with gr.Row():
hr_negative_prompt = gr.Textbox(label="Hires negative prompt", elem_id="hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt for hires fix pass.\nLeave empty to use the same negative prompt as in first pass.", elem_classes=["prompt"])
elif category == "batch":
if not opts.dimensions_and_batch_together:
with FormRow(elem_id="txt2img_column_batch"):
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
elif category == "override_settings":
with FormRow(elem_id="txt2img_override_settings_row") as row:
override_settings = create_override_settings_dropdown('txt2img', row)
elif category == "scripts":
with FormGroup(elem_id="txt2img_script_container"):
custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
else:
modules.scripts.scripts_txt2img.setup_ui_for_section(category)
hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
for component in hr_resolution_preview_inputs:
event = component.release if isinstance(component, gr.Slider) else component.change
event(
fn=calc_resolution_hires,
inputs=hr_resolution_preview_inputs,
outputs=[hr_final_resolution],
show_progress=False,
)
event(
None,
_js="onCalcResolutionHires",
inputs=hr_resolution_preview_inputs,
outputs=[],
show_progress=False,
)
txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
txt2img_args = dict(
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
_js="submit",
inputs=[
dummy_component,
txt2img_prompt,
txt2img_negative_prompt,
txt2img_prompt_styles,
steps,
sampler_index,
restore_faces,
tiling,
batch_count,
batch_size,
cfg_scale,
seed,
subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
height,
width,
enable_hr,
denoising_strength,
hr_scale,
hr_upscaler,
hr_second_pass_steps,
hr_resize_x,
hr_resize_y,
hr_sampler_index,
hr_prompt,
hr_negative_prompt,
override_settings,
] + custom_inputs,
outputs=[
txt2img_gallery,
generation_info,
html_info,
html_log,
],
show_progress=False,
)
txt2img_prompt.submit(**txt2img_args)
submit.click(**txt2img_args)
res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('txt2img')}", inputs=None, outputs=None, show_progress=False)
restore_progress_button.click(
fn=progress.restore_progress,
_js="restoreProgressTxt2img",
inputs=[dummy_component],
outputs=[
txt2img_gallery,
generation_info,
html_info,
html_log,
],
show_progress=False,
)
txt_prompt_img.change(
fn=modules.images.image_data,
inputs=[
txt_prompt_img
],
outputs=[
txt2img_prompt,
txt_prompt_img
],
show_progress=False,
)
enable_hr.change(
fn=lambda x: gr_show(x),
inputs=[enable_hr],
outputs=[hr_options],
show_progress = False,
)
txt2img_paste_fields = [
(txt2img_prompt, "Prompt"),
(txt2img_negative_prompt, "Negative prompt"),
(steps, "Steps"),
(sampler_index, "Sampler"),
(restore_faces, "Face restoration"),
(cfg_scale, "CFG scale"),
(seed, "Seed"),
(width, "Size-1"),
(height, "Size-2"),
(batch_size, "Batch size"),
(subseed, "Variation seed"),
(subseed_strength, "Variation seed strength"),
(seed_resize_from_w, "Seed resize from-1"),
(seed_resize_from_h, "Seed resize from-2"),
(txt2img_prompt_styles, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
(denoising_strength, "Denoising strength"),
(enable_hr, lambda d: "Denoising strength" in d),
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
(hr_scale, "Hires upscale"),
(hr_upscaler, "Hires upscaler"),
(hr_second_pass_steps, "Hires steps"),
(hr_resize_x, "Hires resize-1"),
(hr_resize_y, "Hires resize-2"),
(hr_sampler_index, "Hires sampler"),
(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" else gr.update()),
(hr_prompt, "Hires prompt"),
(hr_negative_prompt, "Hires negative prompt"),
(hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
*modules.scripts.scripts_txt2img.infotext_fields
]
parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings)
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
paste_button=txt2img_paste, tabname="txt2img", source_text_component=txt2img_prompt, source_image_component=None,
))
txt2img_preview_params = [
txt2img_prompt,
txt2img_negative_prompt,
steps,
sampler_index,
cfg_scale,
seed,
width,
height,
]
token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter])
negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter])
ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
As you can see, the nesting here is fairly messy, but fortunately the component names and elem_ids correspond to the buttons and controls visible in the webui, so you can locate things by searching for keywords or via JS in the browser's developer tools.
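For example, a small helper script for that kind of search, run from the repository root (the target elem_id below is just an example):
# Find where a UI component with a given elem_id is defined in the Python sources.
from pathlib import Path

target = "txt2img_cfg_scale"  # any elem_id you can see in the browser's dev tools
for path in Path("modules").rglob("*.py"):
    text = path.read_text(encoding="utf-8", errors="ignore")
    for lineno, line in enumerate(text.splitlines(), 1):
        if target in line:
            print(f"{path}:{lineno}: {line.strip()}")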