Lucidrains-系列项目源码解析-四十四-

Lucidrains 系列项目源码解析(四十四)

.\lucidrains\toolformer-pytorch\toolformer_pytorch\tools.py

# 导入所需的库
import os

# 尝试导入所需的库,如果导入失败则输出错误信息并退出程序
try:
    # 从dotenv库中导入load_dotenv函数
    from dotenv import load_dotenv
    load_dotenv()

    # 导入requests、calendar、wolframalpha、datetime、AutoModelForSeq2SeqLM、AutoTokenizer、pow、truediv、mul、add、sub等库
    import requests
    import calendar
    import wolframalpha
    import datetime
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
    from operator import pow, truediv, mul, add, sub

    # 可选导入
    from googleapiclient.discovery import build

# 如果导入失败,则输出错误信息并退出程序
except ImportError:
    print('please run `pip install tools-requirements.txt` first at project directory')
    exit()

'''
Calendar

使用Python的datetime和calendar库来获取当前日期。

input - 无

output - 一个字符串,表示当前日期。
'''
def Calendar():
    # 获取当前时间
    now = datetime.datetime.now()
    # 返回当前日期的字符串表示
    return f'Today is {calendar.day_name[now.weekday()]}, {calendar.month_name[now.month]} {now.day}, {now.year}.'


'''
Wikipedia Search

使用ColBERTv2来检索维基百科文档。

input_query - 一个字符串,输入查询(例如"what is a dog?")
k - 要检索的文档数量

output - 一个字符串列表,每个字符串是一个维基百科文档

改编自Stanford的DSP: https://github.com/stanfordnlp/dsp/
也可参考: https://github.com/lucabeetz/dsp
'''
class ColBERTv2:
    def __init__(self, url: str):
        self.url = url

    def __call__(self, query, k=10):
        topk = colbertv2_get_request(self.url, query, k)

        topk = [doc['text'] for doc in topk]
        return topk

# 发送ColBERTv2请求
def colbertv2_get_request(url: str, query: str, k: int):
    payload = {'query': query, 'k': k}
    res = requests.get(url, params=payload)

    topk = res.json()['topk'][:k]
    return topk

# 维基百科搜索函数
def WikiSearch(
    input_query: str,
    url: str = 'http://ec2-44-228-128-229.us-west-2.compute.amazonaws.com:8893/api/search',
    k: int = 10
):
    retrieval_model = ColBERTv2(url)
    output = retrieval_model(input_query, k)
    return output

'''
Machine Translation - NLLB-600M

使用HuggingFace的transformers库将输入查询翻译成英文。

input_query - 一个字符串,输入查询(例如"what is a dog?")

output - 一个字符串,翻译后的输入查询。
'''
def MT(input_query: str, model_name: str = "facebook/nllb-200-distilled-600M"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    input_ids = tokenizer(input_query, return_tensors='pt')
    outputs = model.generate(
        **input_ids,
        forced_bos_token_id=tokenizer.lang_code_to_id["eng_Latn"], 
        )
    output = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    return output


'''
Calculator

计算数学表达式的结果。

input_query - 一个字符串,输入的数学表达式(例如"400/1400")

output - 一个浮点数,计算结果

改编自: https://levelup.gitconnected.com/3-ways-to-write-a-calculator-in-python-61642f2e4a9a 
'''
def Calculator(input_query: str):
    operators = {
        '+': add,
        '-': sub,
        '*': mul,
        '/': truediv
        }
    if input_query.isdigit():
        return float(input_query)
    for c in operators.keys():
        left, operator, right = input_query.partition(c)
        if operator in operators:
            return round(operators[operator](Calculator(left), Calculator(right)), 2)


# 其他可选工具


'''
Wolfram Alpha Calculator

pip install wolframalpha

使用Wolfram Alpha API计算输入查询。

input_query - 一个字符串,输入查询(例如"what is 2 + 2?")

output - 一个字符串,输入查询的答案

wolfarm_alpha_appid - 你的Wolfram Alpha API密钥
'''
def WolframAlphaCalculator(input_query: str):
    wolfram_alpha_appid = os.environ.get('WOLFRAM_ALPHA_APPID')
    wolfram_client = wolframalpha.Client(wolfram_alpha_appid)
    res = wolfram_client.query(input_query)
    assumption = next(res.pods).text
    answer = next(res.results).text
    return f'Assumption: {assumption} \nAnswer: {answer}'


'''
Google Search

使用Google的自定义搜索API来检索Google搜索结果。

input_query - 要搜索的查询。
# The number of results to return for the Google Custom Search API
num_results - The number of results to return.
# Your Google API key for accessing Google Custom Search API
api_key - Your Google API key.
# Your Google Custom Search Engine ID for identifying the custom search engine
cse_id - Your Google Custom Search Engine ID.

# A function to perform a custom search using Google Custom Search API
# Returns a list of dictionaries, each dictionary representing a Google Search result
'''
def custom_search(query, api_key, cse_id, **kwargs):
    # Build a service object for the Google Custom Search API
    service = build("customsearch", "v1", developerKey=api_key)
    # Execute the search query and retrieve the results
    res = service.cse().list(q=query, cx=cse_id, **kwargs).execute()
    return res['items']

# A function to perform a Google search using the custom_search function
def google_search(input_query: str, num_results: int = 10):
    # Retrieve Google API key and Custom Search Engine ID from environment variables
    api_key = os.environ.get('GOOGLE_API_KEY')
    cse_id = os.environ.get('GOOGLE_CSE_ID')

    metadata_results = []
    # Perform custom search using custom_search function
    results = custom_search(input_query, num=num_results, api_key=api_key, cse_id=cse_id)
    # Extract relevant metadata from search results
    for result in results:
        metadata_result = {
            "snippet": result["snippet"],
            "title": result["title"],
            "link": result["link"],
        }
        metadata_results.append(metadata_result)
    return metadata_results

'''
Bing Search

Uses Bing's Custom Search API to retrieve Bing Search results.

input_query: The query to search for.
bing_subscription_key: Your Bing API key.
num_results: The number of results to return.

output: A list of dictionaries, each dictionary is a Bing Search result
'''
# A function to retrieve Bing search results using Bing's Custom Search API
def _bing_search_results(
    search_term: str,
    bing_subscription_key: str,
    count: int,
    url: str = "https://api.bing.microsoft.com/v7.0/search"
):
    headers = {"Ocp-Apim-Subscription-Key": bing_subscription_key}
    params = {
        "q": search_term,
        "count": count,
        "textDecorations": True,
        "textFormat": "HTML",
    }
    # Make a GET request to Bing API to retrieve search results
    response = requests.get(
        url, headers=headers, params=params
    )
    response.raise_for_status()
    search_results = response.json()
    return search_results["webPages"]["value"]

# A function to perform a Bing search using the _bing_search_results function
def bing_search(
    input_query: str,
    num_results: int = 10
):
    # Retrieve Bing API key from environment variables
    bing_subscription_key = os.environ.get("BING_API_KEY")
    metadata_results = []
    # Perform Bing search using _bing_search_results function
    results = _bing_search_results(input_query, bing_subscription_key, count=num_results)
    # Extract relevant metadata from search results
    for result in results:
        metadata_result = {
            "snippet": result["snippet"],
            "title": result["name"],
            "link": result["url"],
        }
        metadata_results.append(metadata_result)
    return metadata_results

# Main function to demonstrate the usage of various search functions
if __name__ == '__main__':
 
    print(Calendar()) # Outputs a string, the current date

    print(Calculator('400/1400')) # For Optional Basic Calculator

    print(WikiSearch('What is a dog?')) # Outputs a list of strings, each string is a Wikipedia document

    print(MT("Un chien c'est quoi?")) # What is a dog?

    # Optional Tools

    print(WolframAlphaCalculator('What is 2 + 2?')) # 4

    print(google_search('What is a dog?')) 
    # Outputs a list of dictionaries, each dictionary is a Google Search result

    print(bing_search('What is a dog?')) 
    # Outputs a list of dictionaries, each dictionary is a Bing Search result

.\lucidrains\toolformer-pytorch\toolformer_pytorch\__init__.py

# 从 toolformer_pytorch.palm 模块中导入 PaLM 类
from toolformer_pytorch.palm import PaLM

# 从 toolformer_pytorch.toolformer_pytorch 模块中导入以下函数和类
from toolformer_pytorch.toolformer_pytorch import (
    Toolformer,  # 导入 Toolformer 类
    filter_tokens_with_api_response,  # 导入 filter_tokens_with_api_response 函数
    sample,  # 导入 sample 函数
    sample_with_api_call,  # 导入 sample_with_api_call 函数
    has_api_calls,  # 导入 has_api_calls 函数
    invoke_tools,  # 导入 invoke_tools 函数
    replace_all_but_first  # 导入 replace_all_but_first 函数
)

TPDNE (wip)

Thispersondoesnotexist went down, so this time, while building it back up, I am going to open source all of it. I'll try to make it modular enough so anyone can deploy their own ever-dreaming GAN (or soon to be 1-2 step DDPM) to be public facing

I may also take some time to do something I've always wanted. To 'Perfuse' my dog into the machine and have it dream her up forever to the public.

Explained

The site is hosted on Hetzner on a 100$ / month GPU server. Images are generated live, so people, try as they might, cannot exhaust the amount of faces they experience. Through this, they gain an intuition for how vast the latent space of these neural networks are. It also allowed me to explain it to laypeople as having an 'artificial intelligence endlessly dreaming', without it having to be an exaggeration.

How was this feasible without scaling issues? Well, the site is actually a magic trick. Each user, when refreshing the page, actually sees the same image at any point in time. Images are replaced every 250ms, below the human reaction time. By the time the user studies the face and refreshes, the next face will be there, but it is the same face that everyone experiences around the world at the same time.

The model itself was trained by Tero Karras under the name StyleGAN 2.

Install

$ pip install TPDNE-utils

Usage

from TPDNE_utils import sample_image_and_save_repeatedly

# some function that returns a sampled image in the form of a 3 dimensional ndarray

def generate_image():
    import numpy as np
    return np.random.randn(1024, 1024, 3)

# saves a new sampled image every 250ms as out/sampled.jpeg

sample_image_and_save_repeatedly(generate_image, 'out/sampled')

# use nginx to serve out/sampled.jpeg
# optionally put behind cloudflare

Todo

Citations

@inproceedings{Karras2020ada,
    title     = {Training Generative Adversarial Networks with Limited Data},
    author    = {Tero Karras and Miika Aittala and Janne Hellsten and Samuli Laine and Jaakko Lehtinen and Timo Aila},
    booktitle = {Proc. NeurIPS},
    year      = {2020}
}

.\lucidrains\TPDNE\setup.py

# 导入设置工具和查找包的函数
from setuptools import setup, find_packages

# 设置包的信息
setup(
  name = 'TPDNE-utils',  # 包的名称
  packages = find_packages(exclude=[]),  # 查找所有包
  version = '0.0.11',  # 版本号
  license='MIT',  # 许可证
  description = 'TPDNE',  # 描述
  include_package_data = True,  # 包含所有数据文件
  author = 'Phil Wang',  # 作者
  author_email = 'lucidrains@gmail.com',  # 作者邮箱
  long_description_content_type = 'text/markdown',  # 长描述内容类型
  url = 'https://github.com/lucidrains/TPDNE',  # 项目链接
  keywords = [
    'thispersondoesnotexist'  # 关键词
  ],
  install_requires = [  # 安装依赖
    'beartype',
    'einops>=0.6',
    'jinja2',
    'numpy',
    'pillow'
  ],
  classifiers=[  # 分类
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\TPDNE\TPDNE_utils\tpdne.py

# 导入必要的库
import os
import sys
import numpy as np
from time import time, sleep
from pathlib import Path
from functools import wraps
from PIL import Image

# 导入第三方库
from beartype import beartype
from beartype.typing import Callable, Optional
from einops import rearrange, repeat
from jinja2 import Environment, FileSystemLoader

# 获取当前脚本路径和父目录
script_path = Path(__file__)
current_dir = script_path.parents[0]
# 设置模板环境
environment = Environment(loader = FileSystemLoader(str(current_dir)))

# 获取模板文件
nginx_template = environment.get_template('nginx.conf.tmpl')
systemd_service_template = environment.get_template('tpdne.service.tmpl')

# 定义辅助函数
def exists(val):
    return val is not None

# 处理图像张量的函数
def auto_handle_image_tensor(t):
    if t.ndim == 4:
        t = t[0]  # 假设批次是第一个维度并取第一个样本

    if t.ndim == 2:
        t = rearrange(t, 'h w -> h w 1')  # 假设是灰度图像

    if t.shape[0] <= 3:
        t = rearrange(t, 'c h w -> h w c')  # 通道在前

    assert t.shape[-1] <= 3, 'image tensor must be returned in the shape (height, width, channels), where channels is either 3 or 1'

    if t.shape[-1] == 1:
        t = repeat(t, 'h w 1 -> h w c', c = 3)  # 处理单通道图像

    # 处理缩放
    if t.dtype == np.float:
        has_negatives = np.any(t < 0)

        if has_negatives:
            t = t * 127.5 + 128
        else:
            t = t * 255

        t = t.astype(np.uint8)

    return t.clip(0, 255)

# 主函数
@beartype
def sample_image_and_save_repeatedly(
    fn: Callable[..., np.ndarray],         # 返回形状为 (3, <width>, <height>) 的数组的函数
    output_path: str = './out/random',     # 输出图像的路径,不包括扩展名(将保存为 webp 格式)
    *,
    call_every_ms: int = 250,              # 采样频率
    tmp_dir: str = '/tmp',                 # 存储临时图像的目录
    num_rotated_tmp_images: int = 10,
    image_format: str = 'jpeg',
    verbose: bool = True,
    quality = 99,
    resize_image_to: Optional[int] = None,
    generate_favicon: bool = True,
    favicon_size: int = 32,
    generate_nginx_conf: bool = True,
    symbolic_link_nginx_conf: bool = True,
    nginx_sites_available_path: str = '/etc/nginx/sites-available',
    nginx_conf_filename = 'default',
    generate_systemd_service_conf: bool = False,
    systemd_service_path: str = '/etc/systemd/system',
    systemd_service_name = 'tpdne',
    domain_name = '_'
):
    assert 0 < quality <= 100
    assert favicon_size in {16, 32}
    assert image_format in {'jpeg', 'png', 'webp'}

    tmp_dir = Path(tmp_dir)
    output_path = Path(output_path)

    assert output_path.suffix == '', 'output path suffix will be automatically determined by `image_format` keyword arg'

    output_path = output_path.with_suffix(f'.{image_format}')

    call_every_seconds = call_every_ms / 1000

    assert tmp_dir.is_dir()
    root = output_path.parents[0]
    root.mkdir(parents = True, exist_ok = True)

    tmp_image_index = 0

    # 链接 nginx
    if generate_nginx_conf:
        nginx_sites_path = Path(nginx_sites_available_path)
        nginx_sites_conf_path = nginx_sites_path / nginx_conf_filename

        assert nginx_sites_path.is_dir()

        nginx_conf_text = nginx_template.render(
            root = str(root.resolve()),
            index = output_path.name,
            server_name = domain_name
        )

        tmp_conf_path = Path(tmp_dir / 'nginx.server.conf')
        tmp_conf_path.write_text(nginx_conf_text)

        print(f'nginx server conf generated at {str(tmp_conf_path)}')

        if symbolic_link_nginx_conf:
            os.system(f'ln -nfs {str(tmp_conf_path)} {nginx_sites_conf_path}')

            print(f'nginx conf linked to {nginx_sites_conf_path}\nrun `systemctl reload nginx` for it to be in effect')
    # 如果需要生成 systemd 服务配置文件,并且当前不是在 systemd 中启动
    if generate_systemd_service_conf and not exists(os.getenv('LAUNCHED_FROM_SYSTEMD', None)):

        # 设置 systemd 服务路径
        systemd_service_path = Path(systemd_service_path)
        # 设置 systemd 服务配置文件路径
        systemd_service_conf_path = systemd_service_path / f'{systemd_service_name}.service'

        # 断言 systemd 服务路径是一个目录
        assert systemd_service_path.is_dir()

        # 使用 systemd 服务模板渲染 systemd 配置文本
        systemd_conf_text = systemd_service_template.render(
            working_directory = str(current_dir.resolve()),
            python_executable = sys.executable,
            script_path = str(script_path.resolve())
        )

        # 创建临时服务路径,写入 systemd 配置文本
        tmp_service_path = Path(tmp_dir / 'tpdne.services')
        tmp_service_path.write_text(systemd_conf_text)

        # 创建符号链接,将临时服务路径链接到 systemd 服务配置文件路径
        os.system(f'ln -nfs {str(tmp_service_path)} {str(systemd_service_conf_path)}')

        # 打印提示信息
        print(f'service {systemd_service_name}.service created at {str(systemd_service_conf_path)}')
        print(f'run `systemctl enable {systemd_service_name}.service` to start this script')
        print(f'then run `systemctl status {systemd_service_name}.service` to see the status')
        # 退出程序
        exit()

    # 在一个无限循环中调用函数 `fn`
    while True:
        start = time()
        # 调用函数 `fn` 获取图像张量
        image_tensor = fn()

        # 对图像张量进行处理
        image_tensor = auto_handle_image_tensor(image_tensor)

        # 计算临时图像索引
        tmp_image_index = (tmp_image_index + 1) % num_rotated_tmp_images
        tmp_path = str(tmp_dir / f'{tmp_image_index}.{image_format}')

        # 使用 PIL 创建图像对象
        pil_image = Image.fromarray(image_tensor, 'RGB')

        # 如果存在 resize_image_to 参数,对图像进行缩放
        if exists(resize_image_to):
            pil_image = pil_image.resize((resize_image_to, resize_image_to))

        # 根据图像格式设置不同的参数
        image_save_kwargs = dict()

        if image_format == 'jpeg':
            image_save_kwargs = dict(optimize = True, progressive = True)
        elif image_format == 'webp':
            image_save_kwargs = dict(format = 'webp')

        # 保存图像到临时路径
        pil_image.save(tmp_path, quality = quality, **image_save_kwargs)

        # 创建符号链接,将临时图像路径链接到输出路径
        os.system(f'ln -nfs {tmp_path} {output_path}')

        # 如果需要生成 favicon
        if generate_favicon:
            tmp_favicon_path = str(tmp_dir / f'favicon_{tmp_image_index}.png')
            output_favicon_path = output_path.parents[0] / 'favicon.png'

            # 缩小图像为 favicon 大小
            small_pil_image = pil_image.resize((favicon_size, favicon_size))
            small_pil_image.save(tmp_favicon_path)
            os.system(f'ln -nfs {tmp_favicon_path} {output_favicon_path}')

        # 计算执行时间
        elapsed = time() - start

        # 如果 verbose 为 True,打印执行时间和路径信息
        if verbose:
            print(f'{elapsed:.3f}s - tmp image at {tmp_path}, output image at {output_path}')

        # 确保至少每隔 `call_every_seconds` 秒生成一次图像
        if elapsed >= call_every_seconds:
            continue

        # 休眠直到下一次生成图像的时间点
        sleep(call_every_seconds - elapsed)

.\lucidrains\TPDNE\TPDNE_utils\__init__.py

# 从 TPDNE_utils.tpdne 模块中导入 sample_image_and_save_repeatedly 函数
from TPDNE_utils.tpdne import sample_image_and_save_repeatedly

trRosetta - Pytorch

Implementation of trRosetta and trDesign for Pytorch, made into a convenient package, for protein structure prediction and design. The concept of trDesign will also be abstracted into a wrapper in this repository, so that it can be applied to Alphafold2 once it is replicated. Please join the efforts there if you would like to see this happen!

The original repository can be found here

Update - Xander has released trDesign for Pytorch!

Install

$ pip install tr-rosetta-pytorch

Usage

As a command-line tool, to run a structure prediction

$ tr_rosetta <input-file.a3m>

Code

import torch
from tr_rosetta_pytorch import trRosettaNetwork

model = trRosettaNetwork(
    filters = 64,
    kernel = 3,
    num_layers = 61
).cuda()

x = torch.randn(1, 526, 140, 140).cuda()

theta, phi, distance, omega = model(x)

Citations

@article {Yang1496,
    author = {Yang, Jianyi and Anishchenko, Ivan and Park, Hahnbeom and Peng, Zhenling and Ovchinnikov, Sergey and Baker, David},
    title = {Improved protein structure prediction using predicted interresidue orientations},
    URL = {https://www.pnas.org/content/117/3/1496},
    eprint = {https://www.pnas.org/content/117/3/1496.full.pdf},
    journal = {Proceedings of the National Academy of Sciences}
}
@article {Anishchenko2020.07.22.211482,
    author = {Anishchenko, Ivan and Chidyausiku, Tamuka M. and Ovchinnikov, Sergey and Pellock, Samuel J. and Baker, David},
    title = {De novo protein design by deep network hallucination},
    URL = {https://www.biorxiv.org/content/early/2020/07/23/2020.07.22.211482},
    eprint = {https://www.biorxiv.org/content/early/2020/07/23/2020.07.22.211482.full.pdf},
    journal = {bioRxiv}
}

.\lucidrains\tr-rosetta-pytorch\setup.py

# 导入设置工具和查找包工具
from setuptools import setup, find_packages

# 设置包的信息
setup(
  name = 'tr-rosetta-pytorch',  # 包的名称
  packages = find_packages(),  # 查找所有包
  include_package_data = True,  # 包含所有数据文件
  entry_points={  # 设置入口点
    'console_scripts': [  # 控制台脚本
      'tr_rosetta = tr_rosetta_pytorch.cli:predict',  # 脚本名称和执行函数
    ],
  },
  version = '0.0.3',  # 版本号
  license='MIT',  # 许可证
  description = 'trRosetta - Pytorch',  # 描述
  author = 'Phil Wang',  # 作者
  author_email = 'lucidrains@gmail.com',  # 作者邮箱
  url = 'https://github.com/lucidrains/tr-rosetta-pytorch',  # 项目链接
  keywords = [  # 关键词
    'artificial intelligence',
    'protein folding',
    'protein design'
  ],
  install_requires=[  # 安装依赖
    'einops>=0.3',
    'fire',
    'numpy',
    'torch>=1.6'
  ],
  classifiers=[  # 分类器
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\tr-rosetta-pytorch\tr_rosetta_pytorch\cli.py

# 导入必要的库
import fire
import torch
import tarfile
import numpy as np
from pathlib import Path

# 导入自定义模块
from tr_rosetta_pytorch.tr_rosetta_pytorch import trRosettaNetwork
from tr_rosetta_pytorch.utils import preprocess, d

# 定义路径常量
CURRENT_PATH = Path(__file__).parent
DEFAULT_MODEL_PATH = CURRENT_PATH / 'models'
MODEL_PATH =  DEFAULT_MODEL_PATH / 'models.tar.gz'
MODEL_FILES = [*Path(DEFAULT_MODEL_PATH).glob('*.pt')]

# 如果模型文件未解压,则解压
if len(MODEL_FILES) == 0:
    tar = tarfile.open(str(MODEL_PATH))
    tar.extractall(DEFAULT_MODEL_PATH)
    tar.close()

# 预测函数
@torch.no_grad()
def get_ensembled_predictions(input_file, output_file=None, model_dir=DEFAULT_MODEL_PATH):
    # 创建 trRosettaNetwork 实例
    net = trRosettaNetwork()
    # 预处理输入文件
    i = preprocess(input_file)

    # 如果未指定输出文件,则根据输入文件生成默认输出文件名
    if output_file is None:
        input_path = Path(input_file)
        output_file = f'{input_path.parents[0] / input_path.stem}.npz'

    outputs = []
    model_files = [*Path(model_dir).glob('*.pt')]

    # 如果找不到模型文件,则抛出异常
    if len(model_files) == 0:
        raise 'No model files can be found'

    # 遍历模型文件,加载模型并进行预测
    for model_file in model_files:
        net.load_state_dict(torch.load(model_file, map_location=torch.device(d())))
        net.to(d()).eval()
        output = net(i)
        outputs.append(output)

    # 对模型输出进行平均处理
    averaged_outputs = [torch.stack(model_output).mean(dim=0).cpu().numpy().squeeze(0).transpose(1,2,0) for model_output in zip(*outputs)]
    # 创建包含预测结果的字典
    output_dict = dict(zip(['theta', 'phi', 'dist', 'omega'], averaged_outputs))
    # 保存预测结果到输出文件
    np.savez_compressed(output_file, **output_dict)
    print(f'predictions for {input_file} saved to {output_file}')

# 定义命令行接口
def predict():
    fire.Fire(get_ensembled_predictions)

.\lucidrains\tr-rosetta-pytorch\tr_rosetta_pytorch\tr_rosetta_pytorch.py

# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块和 einsum 函数
from torch import nn, einsum
# 从 torch.nn.functional 中导入 F 模块
import torch.nn.functional as F

# 定义 ELU 激活函数
def elu():
    return nn.ELU(inplace=True)

# 定义 Instance Normalization 层
def instance_norm(filters, eps=1e-6, **kwargs):
    return nn.InstanceNorm2d(filters, affine=True, eps=eps, **kwargs)

# 定义卷积层
def conv2d(in_chan, out_chan, kernel_size, dilation=1, **kwargs):
    # 计算填充大小
    padding = dilation * (kernel_size - 1) // 2
    return nn.Conv2d(in_chan, out_chan, kernel_size, padding=padding, dilation=dilation, **kwargs)

# 定义 trRosettaNetwork 类,继承自 nn.Module
class trRosettaNetwork(nn.Module):
    # 初始化函数
    def __init__(self, filters=64, kernel=3, num_layers=61):
        super().__init__()
        self.filters = filters
        self.kernel = kernel
        self.num_layers = num_layers

        # 第一个块
        self.first_block = nn.Sequential(
            conv2d(442 + 2 * 42, filters, 1),
            instance_norm(filters),
            elu()
        )

        # 带有不同扩张率的残差块堆叠
        cycle_dilations = [1, 2, 4, 8, 16]
        dilations = [cycle_dilations[i % len(cycle_dilations)] for i in range(num_layers)]

        self.layers = nn.ModuleList([nn.Sequential(
            conv2d(filters, filters, kernel, dilation=dilation),
            instance_norm(filters),
            elu(),
            nn.Dropout(p=0.15),
            conv2d(filters, filters, kernel, dilation=dilation),
            instance_norm(filters)
        ) for dilation in dilations])

        self.activate = elu()

        # 转换为角度图和距离图
        self.to_prob_theta = nn.Sequential(conv2d(filters, 25, 1), nn.Softmax(dim=1))
        self.to_prob_phi = nn.Sequential(conv2d(filters, 13, 1), nn.Softmax(dim=1))
        self.to_distance = nn.Sequential(conv2d(filters, 37, 1), nn.Softmax(dim=1))
        self.to_prob_bb = nn.Sequential(conv2d(filters, 3, 1), nn.Softmax(dim=1))
        self.to_prob_omega = nn.Sequential(conv2d(filters, 25, 1), nn.Softmax(dim=1))
 
    # 前向传播函数
    def forward(self, x):
        x = self.first_block(x)

        for layer in self.layers:
            x = self.activate(x + layer(x))
        
        prob_theta = self.to_prob_theta(x)      # 角度图 theta
        prob_phi = self.to_prob_phi(x)          # 角度图 phi

        x = 0.5 * (x + x.permute((0,1,3,2)))    # 对称化

        prob_distance = self.to_distance(x)     # 距离图
        # prob_bb = self.to_prob_bb(x)            # beta-链配对(未使用)
        prob_omega = self.to_prob_omega(x)      # 角度图 omega

        return prob_theta, prob_phi, prob_distance, prob_omega

.\lucidrains\tr-rosetta-pytorch\tr_rosetta_pytorch\utils.py

# 导入所需的库
import string
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

# 定义函数d,用于确定张量所在的设备(CPU或CUDA)
def d(tensor=None):
    if tensor is None:
        return 'cuda' if torch.cuda.is_available() else 'cpu'
    return 'cuda' if tensor.is_cuda else 'cpu'

# 解析A3M文件并将字母转换为0到20的整数
def parse_a3m(filename):
    # 创建字母表转换表,将小写字母转换为空格
    table = str.maketrans(dict.fromkeys(string.ascii_lowercase))
    # 读取A3M文件中的序列并进行转换
    seqs = [line.strip().translate(table) for line in open(filename, 'r') if line[0] != '>']
    # 创建氨基酸字母表和MSA矩阵
    alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-"), dtype='|S1').view(np.uint8)
    msa = np.array([list(s) for s in seqs], dtype='|S1').view(np.uint8)

    # 将字母转换为数字
    for i in range(alphabet.shape[0]):
        msa[msa == alphabet[i]] = i

    # 将所有未知字符视为间隔
    msa[msa > 20] = 20
    return msa

# 将1-hot MSA转换为PSSM
def msa2pssm(msa1hot, w):
    beff = w.sum()
    f_i = (w[:, None, None] * msa1hot).sum(dim=0) / beff + 1e-9
    h_i = (-f_i * torch.log(f_i)).sum(dim=1)
    return torch.cat((f_i, h_i[:, None]), dim=1)

# 根据截断值重新加权MSA
def reweight(msa1hot, cutoff):
    id_min = msa1hot.shape[1] * cutoff
    id_mtx = torch.einsum('ikl,jkl->ij', msa1hot, msa1hot)
    id_mask = id_mtx > id_min
    w = 1. / id_mask.float().sum(dim=-1)
    return w

# 快速DCA(Direct Coupling Analysis)缩减协方差矩阵求逆
def fast_dca(msa1hot, weights, penalty = 4.5):
    device = msa1hot.device
    nr, nc, ns = msa1hot.shape
    x = msa1hot.view(nr, -1)
    num_points = weights.sum() - torch.sqrt(weights.mean())

    mean = (x * weights[:, None]).sum(dim=0, keepdims=True) / num_points
    x = (x - mean) * torch.sqrt(weights[:, None])

    cov = (x.t() @ x) / num_points
    cov_reg = cov + torch.eye(nc * ns).to(device) * penalty / torch.sqrt(weights.sum())

    inv_cov = torch.inverse(cov_reg)
    x1 = inv_cov.view(nc, ns, nc, ns)
    x2 = x1.transpose(1, 2).contiguous()
    features = x2.reshape(nc, nc, ns * ns)

    x3 = torch.sqrt((x1[:, :-1, :, :-1] ** 2).sum(dim=(1, 3))) * (1 - torch.eye(nc).to(device))
    apc = x3.sum(dim=0, keepdims=True) * x3.sum(dim=1, keepdims=True) / x3.sum()
    contacts = (x3 - apc) * (1 - torch.eye(nc).to(device))
    return torch.cat((features, contacts[:, :, None]), dim=2)

# 预处理函数,将MSA文件转换为适用于神经网络的输入
def preprocess(msa_file, wmin=0.8, ns=21):
    a3m = torch.from_numpy(parse_a3m(msa_file)).long()
    nrow, ncol = a3m.shape

    msa1hot = F.one_hot(a3m, ns).float().to(d())
    w = reweight(msa1hot, wmin).float().to(d())

    # 1D序列特征
    f1d_seq = msa1hot[0, :, :20].float()
    f1d_pssm = msa2pssm(msa1hot, w)

    f1d = torch.cat((f1d_seq, f1d_pssm), dim=1)
    f1d = f1d[None, :, :].reshape((1, ncol, 42))

    # 2D序列特征
    f2d_dca = fast_dca(msa1hot, w) if nrow > 1 else torch.zeros((ncol, ncol, 442)).float().to(d())
    f2d_dca = f2d_dca[None, :, :, :]

    f2d = torch.cat((
        f1d[:, :, None, :].repeat(1, 1, ncol, 1), 
        f1d[:, None, :, :].repeat(1, ncol, 1, 1),
        f2d_dca
    ), dim=-1)

    f2d = f2d.view(1, ncol, ncol, 442 + 2*42)
    return f2d.permute((0, 3, 2, 1))

.\lucidrains\tr-rosetta-pytorch\tr_rosetta_pytorch\__init__.py

# 从 tr_rosetta_pytorch 模块中导入 trRosettaNetwork 类
from tr_rosetta_pytorch.tr_rosetta_pytorch import trRosettaNetwork

Tranception - Pytorch (wip)

Implementation of Tranception, an attention network, paired with retrieval, that is SOTA for protein fitness prediction. The Transformer architecture is inspired by Primer, and uses ALiBi relative positional encoding

Install

$ pip install tranception-pytorch

Usage

import torch
from tranception_pytorch import Tranception

model = Tranception(
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 64
)

amino_acids = torch.randint(0, 21, (1, 512))

logits = model(amino_acids) # (1, 512, 21)

Todo

Citations

@article{Notin2022TranceptionPF,
  title   = {Tranception: protein fitness prediction with autoregressive transformers and inference-time retrieval},
  author  = {Pascal Notin and Mafalda Dias and Jonathan Frazer and Javier Marchena-Hurtado and Aidan N. Gomez and Debora S. Marks and Yarin Gal},
  journal = {ArXiv},
  year    = {2022},
  volume  = {abs/2205.13760}
}

.\lucidrains\tranception-pytorch\setup.py

# 导入设置工具和查找包工具
from setuptools import setup, find_packages

# 设置包的信息
setup(
  # 包名
  name = 'tranception-pytorch',
  # 查找所有包,不排除任何包
  packages = find_packages(exclude=[]),
  # 版本号
  version = '0.0.8',
  # 许可证
  license='MIT',
  # 描述
  description = 'Tranception - Pytorch',
  # 作者
  author = 'Phil Wang',
  # 作者邮箱
  author_email = 'lucidrains@gmail.com',
  # 长描述内容类型
  long_description_content_type = 'text/markdown',
  # 项目链接
  url = 'https://github.com/lucidrains/tranception-pytorch',
  # 关键词
  keywords = [
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention mechanism',
    'protein fitness'
  ],
  # 安装依赖
  install_requires=[
    'einops>=0.4',
    'einops-exts',
    'torch>=1.6',
  ],
  # 分类
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\tranception-pytorch\tranception_pytorch\tranception_pytorch.py

# 导入 math 模块
import math
# 导入 torch 模块
import torch
# 导入 torch.nn.functional 模块,并重命名为 F
import torch.nn.functional as F
# 从 torch 模块中导入 nn、einsum 模块
from torch import nn, einsum
# 从 einops 模块中导入 rearrange 函数
from einops import rearrange
# 从 einops_exts 模块中导入 rearrange_many 函数
from einops_exts import rearrange_many
# 从 einops.layers.torch 模块中导入 Rearrange 类
from einops.layers.torch import Rearrange

# 辅助函数

# 判断变量是否存在的函数
def exists(val):
    return val is not None

# 如果变量存在则返回该变量,否则返回默认值的函数
def default(val, d):
    return val if exists(val) else d

# 相对位置偏置

# 自定义类 LearnedAlibiPosBias 继承自 nn.Module
class LearnedAlibiPosBias(nn.Module):
    # 初始化函数
    def __init__(self, heads):
        super().__init__()
        self.heads = heads
        # 计算斜率并转换为张量
        slopes = torch.Tensor(self._get_slopes(heads))
        slopes = rearrange(slopes, 'h -> h 1 1')
        self.slopes = nn.Parameter(slopes)
        # 注册缓冲区 bias
        self.register_buffer('bias', None, persistent = False)

    # 获取相对位置偏置的函数
    def get_bias(self, i, j, device):
        i_arange = torch.arange(i, device = device)
        j_arange = torch.arange(j, device = device)
        bias = -torch.abs(rearrange(j_arange, 'j -> 1 1 j') - rearrange(i_arange, 'i -> 1 i 1'))
        return bias

    # 静态方法,用于获取斜率
    @staticmethod
    def _get_slopes(heads):
        def get_slopes_power_of_2(n):
            start = (2**(-2**-(math.log2(n)-3)))
            ratio = start
            return [start*ratio**i for i in range(n)]

        if math.log2(heads).is_integer():
            return get_slopes_power_of_2(heads)

        closest_power_of_2 = 2 ** math.floor(math.log2(heads))
        return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads-closest_power_of_2]

    # 前向传播函数
    def forward(self, qk_sim):
        h, i, j, device = *qk_sim.shape[-3:], qk_sim.device

        if exists(self.bias) and self.bias.shape[-1] >= j:
            return self.bias[..., :i, :j]

        bias = self.get_bias(i, j, device)
        bias = bias * self.slopes

        num_heads_unalibied = h - bias.shape[0]
        bias = F.pad(bias, (0, 0, 0, 0, 0, num_heads_unalibied))
        self.register_buffer('bias', bias, persistent = False)

        return bias

# 辅助类

# 自定义类 ReluSquared 继承自 nn.Module
class ReluSquared(nn.Module):
    """ found with neural architecture search in Primer paper """
    # 前向传播函数
    def forward(self, x):
        return F.relu(x) ** 2

# 定义 FeedForward 函数
def FeedForward(dim, mult = 4):
    hidden_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden_dim),
        ReluSquared(),
        nn.Linear(hidden_dim, dim)
    )

# 自定义类 DepthwiseConv1d 继承自 nn.Module
class DepthwiseConv1d(nn.Module):
    # 初始化函数
    def __init__(self, dim, kernel_size, causal = True):
        super().__init__()
        assert (kernel_size % 2) == 1

        self.padding = (kernel_size - 1, 0) if causal else (kernel_size // 2, kernel_size // 2)
        self.conv = nn.Conv1d(dim, dim, kernel_size = kernel_size, groups = dim)

    # 前向传播函数
    def forward(self, x):
        x = F.pad(x, self.padding)
        return self.conv(x)

# 自定义类 Attention 继承自 nn.Module
class Attention(nn.Module):
    # 初始化函数
    def __init__(
        self,
        *,
        dim,
        heads = 8,
        dim_head = 64,
        causal = False,
        ds_conv_kernel_sizes = (0, 3, 5, 7) # heads were grouped into 4 groups and given a depthwise conv after the queries / keys / values projection
    ):
        # 调用父类的初始化方法
        super().__init__()
        # 设置头数等于卷积核大小的组数,确保头数大于等于组数且头数能被组数整除
        self.groups = len(ds_conv_kernel_sizes)
        assert heads >= self.groups and (heads % self.groups) == 0, f'heads must be greater than {self.groups} and divisible by {self.groups}'

        # 设置缩放因子为头尺寸的负平方根
        self.scale = dim_head ** -0.5
        # 是否使用因果卷积
        self.causal = causal

        self.heads = heads
        self.heads_per_group = heads // self.groups

        inner_dim = heads * dim_head

        # 对输入进行 LayerNorm
        self.norm = nn.LayerNorm(dim)

        # 用 1x1 卷积层将输入转换为查询、键、值
        self.to_qkv = nn.Conv1d(dim, inner_dim * 3, 1, bias = False)

        # 使用不同卷积核大小的深度卷积进行 4 组头的处理
        self.qkv_ds_convs = nn.ModuleList([])

        for _ in range(3): # for queries, keys, values
            ds_convs = nn.ModuleList([])

            for kernel_size in ds_conv_kernel_sizes:
                if kernel_size == 0:
                    ds_convs.append(nn.Identity())
                    continue

                ds_convs.append(DepthwiseConv1d(dim_head * self.heads_per_group, kernel_size, causal = causal))

            self.qkv_ds_convs.append(ds_convs)

        # 为 4 组头学习位置偏置
        self.learned_alibi_pos_biases = nn.ModuleList([LearnedAlibiPosBias(heads = self.heads_per_group) for _ in range(self.groups)])

        # 输出投影
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        device, heads_per_group = x.device, self.heads_per_group

        # 对输入进行 LayerNorm,并重新排列维度
        x = self.norm(x)
        x = rearrange(x, 'b n d -> b d n')

        # 将输入转换为查询、键、值
        q, k, v = self.to_qkv(x).chunk(3, dim = 1)

        # 重新排列查询、键、值的维度
        q, k, v = rearrange_many((q, k, v), 'b (h d) n -> b h d n', h = self.heads)

        # 对分组头应用因果深度卷积
        def apply_causal_ds_conv_to_grouped_heads(args):
            projs, ds_convs = args
            batch = projs.shape[0]

            projs = rearrange_many(projs.split(heads_per_group, dim = 1), 'b h d n -> b (h d) n')
            conv_out = [fn(t) for fn, t in zip(ds_convs, projs)]
            conv_out = map(lambda t: rearrange(t, 'b (h d) n -> b h d n', h = heads_per_group), conv_out)
            conv_out = torch.cat(tuple(conv_out), dim = 1)
            return rearrange(conv_out, 'b h d n -> b h n d')

        q, k, v = map(apply_causal_ds_conv_to_grouped_heads, zip((q, k, v), self.qkv_ds_convs))

        # 缩放和计算相似度
        q = q * self.scale
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        # 对 4 组头应用学习的位置偏置
        grouped_sims = sim.split(self.heads // self.groups, dim = 1)
        grouped_sims = [(alibi(sim_group) + sim_group) for alibi, sim_group in zip(self.learned_alibi_pos_biases, grouped_sims)]
        sim = torch.cat(grouped_sims, dim = 1)

        # 因果掩码
        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # 注意力机制
        attn = sim.softmax(dim = -1)
        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        # 合并头
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
# 定义一个名为 Tranception 的类,继承自 nn.Module
class Tranception(nn.Module):
    # 初始化函数,接受一系列参数
    def __init__(
        self,
        *,
        dim,  # 特征维度
        depth,  # 模型深度
        num_tokens = 21,  # 标记数量,默认为 21
        heads = 8,  # 多头注意力机制中的头数,默认为 8
        dim_head = 64,  # 每个头的维度,默认为 64
        ff_mult = 4,  # FeedForward 层的倍数,默认为 4
        ds_conv_kernel_sizes = (0, 3, 5, 7),  # 下采样卷积的内核大小,默认为 (0, 3, 5, 7)
        causal = True  # 是否使用因果卷积,默认为 True
    ):
        super().__init__()  # 调用父类的初始化函数
        self.token_emb = nn.Embedding(num_tokens, dim)  # 创建一个标记嵌入层

        self.layers = nn.ModuleList([])  # 创建一个空的模块列表
        for _ in range(depth):  # 根据深度循环
            self.layers.append(nn.ModuleList([  # 向模块列表中添加模块列表
                Attention(dim = dim, heads = heads, dim_head = dim_head, ds_conv_kernel_sizes = ds_conv_kernel_sizes, causal = causal),  # 添加注意力层
                FeedForward(dim, mult = ff_mult)  # 添加前馈神经网络层
            ]))

        self.to_logits = nn.Sequential(  # 创建一个序列模块
            nn.LayerNorm(dim),  # 添加层归一化层
            nn.Linear(dim, num_tokens)  # 添加线性层
        )

    # 前向传播函数,接受输入 x 和掩码 mask,默认为 None
    def forward(
        self,
        x,
        mask = None
    ):
        x = self.token_emb(x)  # 将输入 x 通过标记嵌入层

        for attn, ff in self.layers:  # 遍历模块列表中的模块
            x = attn(x) + x  # 执行注意力层并将结果与输入相加
            x = ff(x) + x  # 执行前馈神经网络层并将结果与输入相加

        return self.to_logits(x)  # 返回经过线性层处理后的结果

.\lucidrains\tranception-pytorch\tranception_pytorch\__init__.py

# 从 tranception_pytorch.tranception_pytorch 模块中导入 Tranception 类
from tranception_pytorch.tranception_pytorch import Tranception

Transformer in Transformer

Implementation of Transformer in Transformer, pixel level attention paired with patch level attention for image classification, in Pytorch.

AI Coffee Break with Letitia

Install

$ pip install transformer-in-transformer

Usage

import torch
from transformer_in_transformer import TNT

tnt = TNT(
    image_size = 256,       # size of image
    patch_dim = 512,        # dimension of patch token
    pixel_dim = 24,         # dimension of pixel token
    patch_size = 16,        # patch size
    pixel_size = 4,         # pixel size
    depth = 6,              # depth
    num_classes = 1000,     # output number of classes
    attn_dropout = 0.1,     # attention dropout
    ff_dropout = 0.1        # feedforward dropout
)

img = torch.randn(2, 3, 256, 256)
logits = tnt(img) # (2, 1000)

Citations

@misc{han2021transformer,
    title   = {Transformer in Transformer}, 
    author  = {Kai Han and An Xiao and Enhua Wu and Jianyuan Guo and Chunjing Xu and Yunhe Wang},
    year    = {2021},
    eprint  = {2103.00112},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}

.\lucidrains\transformer-in-transformer\setup.py

# 导入设置工具和查找包工具
from setuptools import setup, find_packages

# 设置包的元数据
setup(
  name = 'transformer-in-transformer',  # 包名
  packages = find_packages(),  # 查找所有包
  version = '0.1.2',  # 版本号
  license='MIT',  # 许可证
  description = 'Transformer in Transformer - Pytorch',  # 描述
  author = 'Phil Wang',  # 作者
  author_email = 'lucidrains@gmail.com',  # 作者邮箱
  url = 'https://github.com/lucidrains/transformer-in-transformer',  # 项目链接
  keywords = [  # 关键词列表
    'artificial intelligence',
    'deep learning',
    'transformer',
    'image classification'
  ],
  install_requires=[  # 安装依赖
    'einops>=0.3',
    'torch>=1.6'
  ],
  classifiers=[  # 分类器列表
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\transformer-in-transformer\transformer_in_transformer\tnt.py

# 导入 torch 库
import torch
# 导入 torch 中的函数库
import torch.nn.functional as F
# 从 torch 中导入 nn 和 einsum 模块
from torch import nn, einsum

# 从 einops 中导入 rearrange 和 repeat 函数
from einops import rearrange, repeat
# 从 einops.layers.torch 中导入 Rearrange 类

# 辅助函数

# 判断值是否存在
def exists(val):
    return val is not None

# 如果值存在则返回该值,否则返回默认值
def default(val, d):
    return val if exists(val) else d

# 判断值是否可以被除数整除
def divisible_by(val, divisor):
    return (val % divisor) == 0

# 计算展开后的输出尺寸
def unfold_output_size(image_size, kernel_size, stride, padding):
    return int(((image_size - kernel_size + (2 * padding)) / stride) + 1)

# 类

# 预处理层
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        # 使用 LayerNorm 对输入进行归一化
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        # 对输入进行归一化后,传入下一层处理
        return self.fn(self.norm(x), **kwargs)

# 前馈神经网络
class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        # 神经网络结构:全连接层 -> GELU 激活函数 -> Dropout -> 全连接层
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim)
        )

    def forward(self, x):
        # 前馈神经网络的前向传播
        return self.net(x)

# 注意力机制
class Attention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        heads = 8,
        dim_head = 64,
        dropout = 0.
    ):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads =  heads
        self.scale = dim_head ** -0.5

        # 将输入转换为查询、键、值
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        # 输出层结构:全连接层 -> Dropout
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        b, n, d, h = *x.shape, self.heads
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))

        # 计算注意力分数
        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
        attn = sim.softmax(dim = -1)

        # 计算输出
        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
        return self.to_out(out)

# 主类

class TNT(nn.Module):
    def __init__(
        self,
        *,
        image_size,
        patch_dim,
        pixel_dim,
        patch_size,
        pixel_size,
        depth,
        num_classes,
        channels = 3,
        heads = 8,
        dim_head = 64,
        ff_dropout = 0.,
        attn_dropout = 0.,
        unfold_args = None
    # 初始化函数,设置模型参数
    ):
        # 调用父类的初始化函数
        super().__init__()
        # 检查图像大小是否能被分块大小整除
        assert divisible_by(image_size, patch_size), 'image size must be divisible by patch size'
        # 检查分块大小是否能被像素大小整除
        assert divisible_by(patch_size, pixel_size), 'patch size must be divisible by pixel size for now'

        # 计算分块令牌的数量
        num_patch_tokens = (image_size // patch_size) ** 2

        # 设置模型参数
        self.image_size = image_size
        self.patch_size = patch_size
        self.patch_tokens = nn.Parameter(torch.randn(num_patch_tokens + 1, patch_dim))

        # 设置默认的展开参数
        unfold_args = default(unfold_args, (pixel_size, pixel_size, 0))
        unfold_args = (*unfold_args, 0) if len(unfold_args) == 2 else unfold_args
        kernel_size, stride, padding = unfold_args

        # 计算像素宽度和像素数量
        pixel_width = unfold_output_size(patch_size, kernel_size, stride, padding)
        num_pixels = pixel_width ** 2

        # 定义将像素转换为令牌的模块
        self.to_pixel_tokens = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> (b h w) c p1 p2', p1 = patch_size, p2 = patch_size),
            nn.Unfold(kernel_size = kernel_size, stride = stride, padding = padding),
            Rearrange('... c n -> ... n c'),
            nn.Linear(channels * kernel_size ** 2, pixel_dim)
        )

        # 初始化分块位置编码和像素位置编码
        self.patch_pos_emb = nn.Parameter(torch.randn(num_patch_tokens + 1, patch_dim))
        self.pixel_pos_emb = nn.Parameter(torch.randn(num_pixels, pixel_dim))

        # 创建模型层
        layers = nn.ModuleList([])
        for _ in range(depth):

            # 定义将像素转换为分块的模块
            pixel_to_patch = nn.Sequential(
                nn.LayerNorm(pixel_dim),
                Rearrange('... n d -> ... (n d)'),
                nn.Linear(pixel_dim * num_pixels, patch_dim),
            )

            # 添加模型层
            layers.append(nn.ModuleList([
                PreNorm(pixel_dim, Attention(dim = pixel_dim, heads = heads, dim_head = dim_head, dropout = attn_dropout)),
                PreNorm(pixel_dim, FeedForward(dim = pixel_dim, dropout = ff_dropout)),
                pixel_to_patch,
                PreNorm(patch_dim, Attention(dim = patch_dim, heads = heads, dim_head = dim_head, dropout = attn_dropout)),
                PreNorm(patch_dim, FeedForward(dim = patch_dim, dropout = ff_dropout)),
            ]))

        # 设置模型层和 MLP 头部
        self.layers = layers
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, num_classes)
        )

    # 前向传播函数
    def forward(self, x):
        # 获取输入张量的形状和模型参数
        b, _, h, w, patch_size, image_size = *x.shape, self.patch_size, self.image_size
        # 检查输入的高度和宽度是否能被分块大小整除
        assert divisible_by(h, patch_size) and divisible_by(w, patch_size), f'height {h} and width {w} of input must be divisible by the patch size'

        # 计算分块的数量
        num_patches_h = h // patch_size
        num_patches_w = w // patch_size
        n = num_patches_w * num_patches_h

        # 将输入张量转换为像素令牌和分块令牌
        pixels = self.to_pixel_tokens(x)
        patches = repeat(self.patch_tokens[:(n + 1)], 'n d -> b n d', b = b)

        # 添加分块位置编码和像素位置编码
        patches += rearrange(self.patch_pos_emb[:(n + 1)], 'n d -> () n d')
        pixels += rearrange(self.pixel_pos_emb, 'n d -> () n d')

        # 遍历模型层,进行注意力和前馈计算
        for pixel_attn, pixel_ff, pixel_to_patch_residual, patch_attn, patch_ff in self.layers:

            pixels = pixel_attn(pixels) + pixels
            pixels = pixel_ff(pixels) + pixels

            patches_residual = pixel_to_patch_residual(pixels)

            patches_residual = rearrange(patches_residual, '(b h w) d -> b (h w) d', h = num_patches_h, w = num_patches_w)
            patches_residual = F.pad(patches_residual, (0, 0, 1, 0), value = 0) # cls token gets residual of 0
            patches = patches + patches_residual

            patches = patch_attn(patches) + patches
            patches = patch_ff(patches) + patches

        # 提取分类令牌并通过 MLP 头部进行分类预测
        cls_token = patches[:, 0]
        return self.mlp_head(cls_token)

.\lucidrains\transformer-in-transformer\transformer_in_transformer\__init__.py

# 从transformer_in_transformer包中导入TNT类
from transformer_in_transformer.tnt import TNT

Transframer - Pytorch (wip)

Implementation of Transframer, Deepmind's U-net + Transformer architecture for up to 30 seconds video generation, in Pytorch

The gist of the paper is the usage of a Unet as a multi-frame encoder, along with a regular transformer decoder cross attending and predicting the rest of the frames. The author builds upon his prior work where images are encoded as sparse discrete cosine transform (DCT) sequences.

I will deviate from the implementation in this paper, using a hierarchical autoregressive transformer, and just a regular resnet block in place of the NF-net block (this design choice is just Deepmind reusing their own code, as NF-net was developed at Deepmind by Brock et al).

Update: On further meditation, there is nothing new in this paper except for generative modeling on DCT representations

Appreciation

  • This work would not be possible without the generous sponsorship from Stability AI, as well as my other sponsors

Todo

Citations

@article{Nash2022TransframerAF,
    title   = {Transframer: Arbitrary Frame Prediction with Generative Models},
    author  = {Charlie Nash and Jo{\~a}o Carreira and Jacob Walker and Iain Barr and Andrew Jaegle and Mateusz Malinowski and Peter W. Battaglia},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2203.09494}
}

.\lucidrains\transframer-pytorch\setup.py

# 导入设置工具和查找包的函数
from setuptools import setup, find_packages

# 设置包的元数据
setup(
  name = 'transframer-pytorch',  # 包的名称
  packages = find_packages(exclude=[]),  # 查找所有包
  version = '0.0.1',  # 版本号
  license='MIT',  # 许可证
  description = 'Transframer - Pytorch',  # 描述
  author = 'Phil Wang',  # 作者
  author_email = 'lucidrains@gmail.com',  # 作者邮箱
  long_description_content_type = 'text/markdown',  # 长描述内容类型
  url = 'https://github.com/lucidrains/transframer-pytorch',  # 项目链接
  keywords = [  # 关键词列表
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention mechanism',
    'unets',
    'video generation'
  ],
  install_requires=[  # 安装依赖
    'einops>=0.4',
    'kornia',
    'torch>=1.6',
  ],
  classifiers=[  # 分类器列表
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\transframer-pytorch\transframer_pytorch\transframer_pytorch.py

# 从 math 模块中导入 sqrt 和 pi 函数
# 从 functools 模块中导入 partial 函数
import torch
# 从 torch.nn.functional 模块中导入 F
import torch.nn.functional as F
# 从 torch.fft 模块中导入 fft 和 irfft 函数
from torch.fft import fft, irfft
# 从 torch 模块中导入 nn 和 einsum 函数
from torch import nn, einsum
# 从 einops 模块中导入 rearrange 和 repeat 函数
from einops import rearrange, repeat
# 从 kornia.color.ycbcr 模块中导入 rgb_to_ycbcr 和 ycbcr_to_rgb 函数

# helpers

# 定义 exists 函数,判断值是否存在
def exists(val):
    return val is not None

# 定义 default 函数,如果值存在则返回该值,否则返回默认值
def default(val, d):
    return val if exists(val) else d

# tensor helpers

# 定义 l2norm 函数,对张量进行 L2 归一化
def l2norm(t):
    return F.normalize(t, dim = -1)

# dct related encoding / decoding functions

# 定义 dct 函数,进行离散余弦变换
# 函数来源于 https://github.com/zh217/torch-dct/blob/master/torch_dct/_dct.py
# 修复了大多数 torch 版本 > 1.9 的问题,使用最新的 fft 和 irfft
def dct(x, norm = None):
    shape, dtype, device = x.shape, x.dtype, x.device
    N = shape[-1]

    x = rearrange(x.contiguous(), '... n -> (...) n')

    v = torch.cat([x[:, ::2], x[:, 1::2].flip((1,))], dim = 1)

    vc = torch.view_as_real(fft(v, dim=1))

    k = -torch.arange(N, dtype = dtype, device = device) * pi / (2 * N)
    k = rearrange(k, 'n -> 1 n')

    v = vc[:, :, 0] * k.cos() - vc[:, :, 1] * k.sin()

    if norm == 'ortho':
        v[:, 0] /= sqrt(N) * 2
        v[:, 1:] /= sqrt(N / 2) * 2

    v *= 2
    return v.view(*shape)

# 定义 idct 函数,进行逆离散余弦变换
def idct(x, norm = None):
    shape, dtype, device = x.shape, x.dtype, x.device
    N = shape[-1]

    x_v = rearrange(x.contiguous(), '... n -> (...) n') / 2

    if norm == 'ortho':
        x_v[:, 0] *= sqrt(N) * 2
        x_v[:, 1:] *= sqrt(N / 2) * 2

    k = torch.arange(N, dtype = dtype, device = device) * pi / (2 * N)
    k = rearrange(k, 'n -> 1 n')
    w_r = torch.cos(k)
    w_i = torch.sin(k)

    v_t_r = x_v
    v_t_i = torch.cat([x_v[:, :1] * 0, -x_v.flip((1,))[:, :-1]], dim = 1)

    v_r = v_t_r * w_r - v_t_i * w_i
    v_i = v_t_r * w_i + v_t_i * w_r

    v = torch.stack((v_r, v_i), dim = -1)

    v = irfft(torch.view_as_complex(v), n = N, dim = 1)
    x = torch.zeros_like(v)
    x[:, ::2] += v[:, :N - (N // 2)]
    x[:, 1::2] += v.flip((1,))[:, :N // 2]

    return x.view(*shape)

# 定义 dct_2d 函数,对二维张量进行离散余弦变换
def dct_2d(x, norm = None):
    dct_ = partial(dct, norm = norm)
    x1 = dct_(x)
    x2 = dct_(rearrange(x1, '... h w -> ...  w h'))
    return rearrange(x2, '... h w -> ... w h')

# 定义 idct_2d 函数,对二维张量进行逆离散余弦变换
def idct_2d(x, norm = None):
    idct_ = partial(idct, norm = norm)
    x1 = idct_(x)
    x2 = idct_(rearrange(x1, '... h w -> ... w h'))
    return rearrange(x2, '... h w -> ... w h')

# 定义 blockify 函数,将张量分块
def blockify(x, block_size = 8):
    assert block_size in {8, 16}
    return rearrange(x, 'b c (h bs1) (w bs2) -> (b h w) c bs1 bs2', bs1 = block_size, bs2 = block_size)

# 定义 deblockify 函数,将分块的张量还原为原始形状
def deblockify(x, h, w, block_size = 8):
    assert block_size in {8, 16}
    return rearrange(x, '(b h w) c bs1 bs2 -> b c (h bs1) (w bs2)', h = h, w = w)

# final functions from rgb -> dct and back

# 定义 images_to_dct 函数,将图像转换为离散余弦变换
def images_to_dct(images):
    raise NotImplementedError

# 定义 dct_to_images 函数,将离散余弦��换转换为图像
def dct_to_images(images):
    raise NotImplementedError

# feedforward

# 定义 FeedForward 类,包含线性层和 GELU 激活函数
def FeedForward(
    dim,
    *,
    mult = 4.
):
    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias = False),
        nn.GELU(),
        nn.LayerNorm(inner_dim),  # from normformer paper
        nn.Linear(inner_dim, dim, bias = False)
    )

# attention, what else?
# here we will use one headed key / values (as described in paper, from Noam Shazeer) - along with cosine sim attention

# 定义 Attention 类,包含多头注意力机制
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        dim_head = 64,
        heads = 8,
        scale = 10,
        causal = False,
        norm_context = False
    ):
        super().__init__()
        self.heads = heads
        self.scale = scale
        self.causal = causal

        self.norm = nn.LayerNorm(dim)
        self.norm_context = nn.LayerNorm(dim) if norm_context else nn.Identity()

        self.to_q = nn.Linear(dim, dim_head * heads, bias = False)
        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
        self.to_out = nn.Linear(dim_head * heads, dim, bias = False)
    # 定义一个前向传播函数,接受输入 x,上下文 context 和上下文掩码 context_mask
    def forward(
        self,
        x,
        context = None,
        context_mask = None
    ):
        # 获取头数 h,缩放因子 scale,是否因果 causal,设备信息 device
        h, scale, causal, device = self.heads, self.scale, self.causal, x.device

        # 对输入 x 进行归一化处理
        x = self.norm(x)

        # 如果存在上下文 context,则使用上下文,否则使用输入 x 作为上下文
        context = default(context, x)

        # 将输入 x 转换为查询向量 q,并重新排列维度
        q = self.to_q(x)
        q = rearrange(q, 'b n (h d) -> b h n d', h = h)

        # 如果存在上下文,则对上下文进行归一化处理
        if exists(context):
            context = self.norm_context(context)

        # 将上下文转换为键值对 k, v,并按最后一个维度分割成两部分
        k, v = self.to_kv(context).chunk(2, dim = -1)

        # 对查询向量 q 和键向量 k 进行 L2 归一化
        q, k = map(l2norm, (q, k))

        # 计算查询向量 q 和键向量 k 之间的相似度矩阵 sim
        sim = einsum('b h i d, b j d -> b h i j', q, k) * self.scale

        # 计算掩码值,用于在相似度矩阵中进行掩码操作
        mask_value = -torch.finfo(sim.dtype).max

        # 如果存在上下文掩码,则对相似度矩阵进行掩码操作
        if exists(context_mask):
            context_mask = rearrange(context_mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(context_mask, mask_value)

        # 如果是因果注意力机制,则对相似度矩阵进行因果掩码操作
        if causal:
            i, j = sim.shape[-2:]
            causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, mask_value)

        # 对相似度矩阵进行 softmax 操作,得到注意力权重
        attn = sim.softmax(dim = -1)

        # 根据注意力权重计算输出向量 out
        out = einsum('b h i j, b j d -> b h i d', attn, v)

        # 重新排列输出向量的维度
        out = rearrange(out, 'b h n d -> b n (h d)')
        # 返回输出向量
        return self.to_out(out)
# 定义一个名为 Block 的类,继承自 nn.Module
class Block(nn.Module):
    # 初始化函数,接受输入维度 dim、输出维度 dim_out 和分组数 groups
    def __init__(
        self,
        dim,
        dim_out,
        groups = 8
    ):
        super().__init__()
        # 创建一个卷积层,输入维度为 dim,输出维度为 dim_out,卷积核大小为 3,填充为 1
        self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1)
        # 创建一个 GroupNorm 层,分组数为 groups,输出维度为 dim_out
        self.norm = nn.GroupNorm(groups, dim_out)
        # 创建一个 SiLU 激活函数层
        self.act = nn.SiLU()

    # 前向传播函数,接受输入 x
    def forward(self, x):
        # 对输入 x 进行卷积操作
        x = self.proj(x)
        # 对卷积结果进行 GroupNorm 操作
        x = self.norm(x)
        # 对 GroupNorm 结果进行 SiLU 激活函数操作
        return self.act(x)

# 定义一个名为 ResnetBlock 的类,继承自 nn.Module
class ResnetBlock(nn.Module):
    # 初始化函数,接受输入维度 dim、输出维度 dim_out 和分组数 groups
    def __init__(
        self,
        dim,
        dim_out,
        groups = 8
    ):
        super().__init__()
        # 创建两个 Block 实例,分别作为 ResNet 块的两个子块
        self.block1 = Block(dim, dim_out, groups = groups)
        self.block2 = Block(dim_out, dim_out, groups = groups)
        # 如果输入维度和输出维度不相等,则创建一个卷积层,否则创建一个恒等映射层
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    # 前向传播函数,接受输入 x
    def forward(self, x):
        # 对输入 x 进行第一个子块的操作
        h = self.block1(x)
        # 对第一个子块的输出进行第二个子块的操作
        h = self.block2(h)
        # 返回第一个子块的输出与输入 x 经过卷积的结果的和
        return h + self.res_conv(x)

# 定义一个名为 UnetTransformerBlock 的类,继承自 nn.Module
class UnetTransformerBlock(nn.Module):
    # 初始化函数,接受输入维度 dim、注意力头维度 dim_head 和注意力头数 heads
    def __init__(
        self,
        dim,
        *,
        dim_head = 32,
        heads = 8
    ):
        super().__init__()
        # 创建一个 Attention 层,输入维度为 dim,注意力头维度为 dim_head,注意力头数为 heads
        self.attn = Attention(dim = dim, dim_head = dim_head, heads = heads)
        # 创建一个 FeedForward 层,输入维度为 dim
        self.ff = FeedForward(dim = dim)

    # 前向传播函数,接受输入 x
    def forward(self, x):
        # 保存输入 x 的原始形状
        orig_shape = x.shape
        # 将输入 x 重排列为 'b c ...' 的形式
        x = rearrange(x, 'b c ... -> b (...) c')

        # 对输入 x 进行注意力操作并加上原始输入 x
        x = self.attn(x) + x
        # 对加上注意力结果的 x 进行 FeedForward 操作并加上原始输入 x
        x = self.ff(x) + x

        # 将 x 重排列为 'b n c' 的形式,再将其形状恢复为原始形状
        x = rearrange(x, 'b n c -> b c n')
        return x.reshape(*orig_shape)

# 定义一个名为 Unet 的类,继承自 nn.Module
class Unet(nn.Module):
    # 初始化函数,接受输入维度 dim、输出维度 dim_out、注意力参数 attn_kwargs
    def __init__(
        self,
        dim,
        *,
        dim_mults = (1, 2, 3, 4),
        dim_out,
        **attn_kwargs
    ):
        super().__init__()
        # 创建一个输出维度为 dim_out 的卷积层
        self.to_out = nn.Conv2d(dim, dim_out, 1)
        # 计算多层次维度倍增后的维度列表 dims
        dims = [dim, *map(lambda t: t * dim, dim_mults)]
        # 计算每一层次的维度对 dim_pairs
        dim_pairs = tuple(zip(dims[:-1], dims[1:]))
        # 中间维度为 dims 的最后一个元素
        mid_dim = dims[-1]

        # 创建下采样和上采样的模块列表
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])

        # 创建中间的 ResNet 块
        self.mid = ResnetBlock(mid_dim, mid_dim)

        # 遍历每一层次的维度对
        for dim_in, dim_out in dim_pairs:
            # 对每一层次创建下采样模块列表
            self.downs.append(nn.ModuleList([
                ResnetBlock(dim_in, dim_in),
                UnetTransformerBlock(dim_in, **attn_kwargs),
                nn.Conv2d(dim_in, dim_out, 3, 2, 1)
            ]))

            # 对每一层次创建上采样模块列表
            self.ups.insert(0, nn.ModuleList([
                ResnetBlock(dim_out * 2, dim_out),
                UnetTransformerBlock(dim_out, **attn_kwargs),
                nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1)
            ]))

    # 前向传播函数,接受输入 x
    def forward(self, x):
        # 保存每个下采样阶段的隐藏状态
        hiddens = []

        # 对每个下采样阶段的模块进行操作
        for block, attn_block, downsample in self.downs:
            x = block(x)
            x = attn_block(x)
            x = downsample(x)
            hiddens.append(x)

        # 对中间的 ResNet 块进行操作
        x = self.mid(x)

        # 对每个上采样阶段的模块进行操作
        for block, attn_block, upsample in self.ups:
            x = torch.cat((x, hiddens.pop()), dim = 1)
            x = block(x)
            x = attn_block(x)
            x = upsample(x)

        # 对输出进行卷积操作并重排列输出形状
        out = self.to_out(x)
        return rearrange(out, 'b c h w -> b (h w) c')

# 定义一个名为 Transframer 的类,继承自 nn.Module
class Transframer(nn.Module):
    # 初始化函数,接受参数 unet、dim、depth、max_channels、max_positions、max_values、image_size、block_size、dim_head、heads、ff_mult 和 ignore_index
    def __init__(
        self,
        *,
        unet: Unet,
        dim,
        depth,
        max_channels,
        max_positions,
        max_values,
        image_size,
        block_size = 8,
        dim_head = 32,
        heads = 8,
        ff_mult = 4.,
        ignore_index = -100
    ):
        # 调用父类的构造函数
        super().__init__()
        # 初始化 UNet 模型
        self.unet = unet

        # 初始化起始标记
        self.start_token = nn.Parameter(torch.randn(dim))

        # 初始化块位置嵌入
        self.block_pos_emb = nn.Parameter(torch.randn(2, (image_size // block_size), dim))

        # 初始化通道嵌入
        self.channels = nn.Embedding(max_channels, dim)
        # 初始化位置嵌入
        self.positions = nn.Embedding(max_positions, dim)
        # 初始化值嵌入
        self.values = nn.Embedding(max_values, dim)

        # 初始化后处理层的 LayerNorm
        self.postemb_norm = nn.LayerNorm(dim) # 在 Bloom 和 YaLM 中为了稳定性而完成

        # 初始化层列表
        self.layers = nn.ModuleList([])

        # 循环创建深度个层
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, dim_head = dim_head, heads = heads, causal = True),
                Attention(dim, dim_head = dim_head, heads = heads, norm_context = True),
                FeedForward(dim, mult = ff_mult)
            ]))

        # 初始化最终层的 LayerNorm
        self.final_norm = nn.LayerNorm(dim)

        # 为最终预测给通道和位置提供单独的嵌入

        # 初始化轴向通道嵌入
        self.axial_channels = nn.Embedding(max_channels, dim)
        # 初始化轴向位置嵌入
        self.axial_positions = nn.Embedding(max_positions, dim)

        # 初始化轴向注意力机制
        self.axial_attn = Attention(dim, dim_head = dim_head,  heads = heads, causal = True)
        # 初始化轴向前馈网络
        self.axial_ff = FeedForward(dim, mult = ff_mult)

        # 初始化轴向最终层的 LayerNorm
        self.axial_final_norm = nn.LayerNorm(dim)

        # 投影到逻辑回归

        # 线性变换到通道的逻辑回归
        self.to_channel_logits = nn.Linear(dim, max_channels)
        # 线性变换到位置的逻辑回归
        self.to_position_logits = nn.Linear(dim, max_positions)
        # 线性变换到值的逻辑回归
        self.to_value_logits = nn.Linear(dim, max_values)

        # 设置忽略索引
        self.ignore_index = ignore_index

    # 获取块位置嵌入
    def get_block_pos_emb(self):
        block_pos_emb_h, block_pos_emb_w = self.block_pos_emb.unbind(dim = 0)
        block_pos_emb = rearrange(block_pos_emb_h, 'h d -> h 1 d') + rearrange(block_pos_emb_w, 'w d -> 1 w d')
        return rearrange(block_pos_emb, '... d -> (...) d')

    # 前向传播���数
    def forward(
        self,
        x,
        context_frames,
        return_loss = False
        ):
        # 断言输入张量 x 的最后一个维度为 3
        assert x.shape[-1] == 3

        # 使用上下文帧生成编码
        encoded = self.unet(context_frames)

        # 获取批次大小
        batch = x.shape[0]

        # 将输入张量 x 拆分为通道、位置和数值
        channels, positions, values = x.unbind(dim=-1)

        # 获取通道嵌入
        channel_emb = self.channels(channels)
        # 获取位置嵌入
        position_emb = self.positions(positions)
        # 获取数值嵌入
        value_emb = self.values(values)

        # 将通道、位置和数值嵌入相加得到总嵌入
        embed = channel_emb + position_emb + value_emb

        # 在嵌入前添加起始标记
        start_token = repeat(self.start_token, 'd -> b 1 d', b=batch)
        embed = torch.cat((start_token, embed), dim=1)

        # 如果需要返回损失,则截取嵌入的最后一个元素
        if return_loss:
            embed = embed[:, :-1]

        # 对嵌入进行后处理归一化
        embed = self.postemb_norm(embed)

        # 注意力层 + 交叉注意力层
        for attn, cross_attn, ff in self.layers:
            embed = attn(embed) + embed
            embed = cross_attn(embed, encoded) + embed
            embed = ff(embed) + embed

        # 对最终嵌入进行归一化
        embed = self.final_norm(embed)

        # 进行轴向注意力,从通道 + 位置 + 数值的总嵌入到下一个通道 -> 下一个位置
        axial_channels_emb = self.axial_channels(channels)
        axial_positions_emb = self.axial_positions(positions)

        # 将嵌入与轴向嵌入堆叠
        embed = torch.stack((embed, axial_channels_emb, axial_positions_emb), dim=-2)

        # 重新排列嵌入
        embed = rearrange(embed, 'b m n d -> (b m) n d')

        # 轴向注意力层
        embed = self.axial_attn(embed) + embed
        embed = self.axial_ff(embed) + embed

        # 对轴向最终嵌入进行归一化
        embed = self.axial_final_norm(embed)

        # 重新排列嵌入
        embed = rearrange(embed, '(b m) n d -> b m n d', b=batch)

        # 分离通道、位置和数值嵌入
        pred_channel_embed, pred_position_embed, pred_value_embed = embed.unbind(dim=-2)

        # 转换为 logits

        channel_logits = self.to_channel_logits(pred_channel_embed)
        position_logits = self.to_position_logits(pred_position_embed)
        value_logits = self.to_value_logits(pred_value_embed)

        # 如果不需要返回损失,则返回通道 logits、位置 logits 和���值 logits
        if not return_loss:
            return channel_logits, position_logits, value_logits

        # 重新排列 logits
        channel_logits, position_logits, value_logits = map(lambda t: rearrange(t, 'b n c -> b c n'), (channel_logits, position_logits, value_logits))

        # 交叉熵损失函数
        ce = partial(F.cross_entropy, ignore_index=self.ignore_index)

        # 计算通道、位置和数值的损失
        channel_loss = ce(channel_logits, channels)
        position_loss = ce(position_logits, positions)
        value_loss = ce(value_logits, values)

        # 返回平均损失
        return (channel_loss + position_loss + value_loss) / 3

.\lucidrains\transframer-pytorch\transframer_pytorch\__init__.py

# 从 transframer_pytorch.transframer_pytorch 模块中导入 Transframer 和 Unet 类
from transframer_pytorch.transframer_pytorch import Transframer, Unet

TransGanFormer (wip)

Implementation of TransGanFormer, an all-attention GAN that combines the finding from the recent GansFormer and TransGan paper. It will also contain a bunch of tricks I have picked up building transformers and GANs for the last year or so, including efficient linear attention and pixel level attention.

Install

$ pip install transganformer

Usage

$ transganformer --data ./path/to/data

Citations

@misc{jiang2021transgan,
    title   = {TransGAN: Two Transformers Can Make One Strong GAN}, 
    author  = {Yifan Jiang and Shiyu Chang and Zhangyang Wang},
    year    = {2021},
    eprint  = {2102.07074},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{hudson2021generative,
    title   = {Generative Adversarial Transformers}, 
    author  = {Drew A. Hudson and C. Lawrence Zitnick},
    year    = {2021},
    eprint  = {2103.01209},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}

.\lucidrains\transganformer\setup.py

# 导入 sys 模块
import sys
# 从 setuptools 模块中导入 setup 和 find_packages 函数
from setuptools import setup, find_packages

# 将 'transganformer' 目录添加到 sys.path 的最前面
sys.path[0:0] = ['transganformer']
# 从 version 模块中导入 __version__ 变量
from version import __version__

# 设置包的元数据
setup(
  # 包名为 'transganformer'
  name = 'transganformer',
  # 查找所有包
  packages = find_packages(),
  # 设置入口点,命令行脚本为 'transganformer'
  entry_points={
    'console_scripts': [
      'transganformer = transganformer.cli:main',
    ],
  },
  # 设置版本号为导入的 __version__ 变量
  version = __version__,
  # 设置许可证为 MIT
  license='MIT',
  # 设置描述为 'TransGanFormer'
  description = 'TransGanFormer',
  # 设置作者为 'Phil Wang'
  author = 'Phil Wang',
  # 设置作者邮箱为 'lucidrains@gmail.com'
  author_email = 'lucidrains@gmail.com',
  # 设置项目 URL 为 'https://github.com/lucidrains/transganformer'
  url = 'https://github.com/lucidrains/transganformer',
  # 设置关键词列表
  keywords = [
    'artificial intelligence',
    'deep learning',
    'generative adversarial networks',
    'transformers',
    'attention-mechanism'
  ],
  # 设置依赖包列表
  install_requires=[
    'einops>=0.3',
    'fire',
    'kornia',
    'numpy',
    'pillow',
    'retry',
    'torch>=1.6',
    'torchvision',
    'tqdm'
  ],
  # 设置分类列表
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\transganformer\transganformer\cli.py

# 导入所需的库
import os
import fire
import random
from retry.api import retry_call
from tqdm import tqdm
from datetime import datetime
from functools import wraps
from transganformer import Trainer, NanException

import torch
import torch.multiprocessing as mp
import torch.distributed as dist

import numpy as np

# 检查值是否存在
def exists(val):
    return val is not None

# 如果值存在则返回该值,否则返回默认值
def default(val, d):
    return val if exists(val) else d

# 将元素转换为列表
def cast_list(el):
    return el if isinstance(el, list) else [el]

# 生成带时间戳的文件名
def timestamped_filename(prefix = 'generated-'):
    now = datetime.now()
    timestamp = now.strftime("%m-%d-%Y_%H-%M-%S")
    return f'{prefix}{timestamp}'

# 设置随机种子
def set_seed(seed):
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)

# 运行训练过程
def run_training(rank, world_size, model_args, data, load_from, new, num_train_steps, name, seed):
    is_main = rank == 0
    is_ddp = world_size > 1

    if is_ddp:
        set_seed(seed)
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'
        dist.init_process_group('nccl', rank=rank, world_size=world_size)

        print(f"{rank + 1}/{world_size} process initialized.")

    model_args.update(
        is_ddp = is_ddp,
        rank = rank,
        world_size = world_size
    )

    model = Trainer(**model_args)

    if not new:
        model.load(load_from)
    else:
        model.clear()

    model.set_data_src(data)

    for _ in tqdm(range(num_train_steps - model.steps), initial = model.steps, total = num_train_steps, mininterval=10., desc=f'{name}<{data}>'):
        retry_call(model.train, tries=3, exceptions=NanException)
        if is_main and _ % 50 == 0:
            model.print_log()

    model.save(model.checkpoint_num)

    if is_ddp:
        dist.destroy_process_group()

# 从文件夹中训练模型
def train_from_folder(
    data = './data',
    results_dir = './results',
    models_dir = './models',
    name = 'default',
    new = False,
    load_from = -1,
    image_size = 32,
    fmap_max = 512,
    transparent = False,
    greyscale = False,
    batch_size = 10,
    gradient_accumulate_every = 4,
    num_train_steps = 150000,
    learning_rate = 2e-4,
    save_every = 1000,
    evaluate_every = 1000,
    generate = False,
    generate_types = ['default', 'ema'],
    generate_interpolation = False,
    aug_test = False,
    aug_prob=None,
    aug_types=['cutout', 'translation'],
    dataset_aug_prob=0.,
    interpolation_num_steps = 100,
    save_frames = False,
    num_image_tiles = None,
    num_workers = None,
    multi_gpus = False,
    calculate_fid_every = None,
    calculate_fid_num_images = 12800,
    clear_fid_cache = False,
    seed = 42,
    amp = False,
    show_progress = False,
):
    num_image_tiles = default(num_image_tiles, 4 if image_size > 512 else 8)

    model_args = dict(
        name = name,
        results_dir = results_dir,
        models_dir = models_dir,
        batch_size = batch_size,
        gradient_accumulate_every = gradient_accumulate_every,
        image_size = image_size,
        num_image_tiles = num_image_tiles,
        num_workers = num_workers,
        fmap_max = fmap_max,
        transparent = transparent,
        greyscale = greyscale,
        lr = learning_rate,
        save_every = save_every,
        evaluate_every = evaluate_every,
        aug_prob = aug_prob,
        aug_types = cast_list(aug_types),
        dataset_aug_prob = dataset_aug_prob,
        calculate_fid_every = calculate_fid_every,
        calculate_fid_num_images = calculate_fid_num_images,
        clear_fid_cache = clear_fid_cache,
        amp = amp
    )
    # 如果需要生成样本图片
    if generate:
        # 创建训练器对象
        model = Trainer(**model_args)
        # 从指定路径加载模型
        model.load(load_from)
        # 生成带时间戳的文件名
        samples_name = timestamped_filename()
        # 获取模型的检查点编号
        checkpoint = model.checkpoint_num
        # 生成样本图片并返回结果目录
        dir_result = model.generate(samples_name, num_image_tiles, checkpoint, generate_types)
        # 打印生成的样本图片目录
        print(f'sample images generated at {dir_result}')
        return

    # 如果需要生成插值图片
    if generate_interpolation:
        # 创建训练器对象
        model = Trainer(**model_args)
        # 从指定路径加载模型
        model.load(load_from)
        # 生成带时间戳的文件名
        samples_name = timestamped_filename()
        # 生成插值图片
        model.generate_interpolation(samples_name, num_image_tiles, num_steps = interpolation_num_steps, save_frames = save_frames)
        # 打印生成的插值图片目录
        print(f'interpolation generated at {results_dir}/{name}/{samples_name}')
        return

    # 如果需要展示训练进度
    if show_progress:
        # 创建训练器对象
        model = Trainer(**model_args)
        # 展示训练进度
        model.show_progress(num_images=num_image_tiles, types=generate_types)
        return

    # 获取当前可用的 GPU 数量
    world_size = torch.cuda.device_count()

    # 如果只有一个 GPU 或者不使用多 GPU 训练
    if world_size == 1 or not multi_gpus:
        # 单 GPU 训练
        run_training(0, 1, model_args, data, load_from, new, num_train_steps, name, seed)
        return

    # 使用多 GPU 训练
    mp.spawn(run_training,
        args=(world_size, model_args, data, load_from, new, num_train_steps, name, seed),
        nprocs=world_size,
        join=True)
# 定义主函数
def main():
    # 使用 Fire 库将 train_from_folder 函数转换为命令行接口
    fire.Fire(train_from_folder)

.\lucidrains\transganformer\transganformer\diff_augment.py

# 导入random和torch模块
import random
import torch
import torch.nn.functional as F

# 定义数据增强函数DiffAugment,接受输入x和增强类型types
def DiffAugment(x, types=[]):
    # 遍历增强类型
    for p in types:
        # 遍历对应增强函数列表
        for f in AUGMENT_FNS[p]:
            # 对输入x应用增强函数f
            x = f(x)
    # 返回增强后的数据x
    return x.contiguous()

# 定义随机亮度增强函数
def rand_brightness(x):
    # 生成随机亮度增强值,应用到输入x上
    x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
    return x

# 定义随机饱和度增强函数
def rand_saturation(x):
    # 计算输入x的均值,对每个像素应用随机饱和度增强
    x_mean = x.mean(dim=1, keepdim=True)
    x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
    return x

# 定义随机对比度增强函数
def rand_contrast(x):
    # 计算输入x的均值,对每个像素应用随机对比度增强
    x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
    x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
    return x

# 定义随机平移增强函数
def rand_translation(x, ratio=0.125):
    # 计算平移范围,生成随机平移值,对输入x进行平移操作
    shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
    translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(x.size(2), dtype=torch.long, device=x.device),
        torch.arange(x.size(3), dtype=torch.long, device=x.device),
    )
    grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
    grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
    x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
    x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
    return x

# 定义随机偏移增强函数
def rand_offset(x, ratio=1, ratio_h=1, ratio_v=1):
    # 计算偏移范围,生成随机偏移值,对输入x进行偏移操作
    w, h = x.size(2), x.size(3)

    imgs = []
    for img in x.unbind(dim = 0):
        max_h = int(w * ratio * ratio_h)
        max_v = int(h * ratio * ratio_v)

        value_h = random.randint(0, max_h) * 2 - max_h
        value_v = random.randint(0, max_v) * 2 - max_v

        if abs(value_h) > 0:
            img = torch.roll(img, value_h, 2)

        if abs(value_v) > 0:
            img = torch.roll(img, value_v, 1)

        imgs.append(img)

    return torch.stack(imgs)

# 定义水平偏移增强函数
def rand_offset_h(x, ratio=1):
    return rand_offset(x, ratio=1, ratio_h=ratio, ratio_v=0)

# 定义垂直偏移增强函数
def rand_offset_v(x, ratio=1):
    return rand_offset(x, ratio=1, ratio_h=0, ratio_v=ratio)

# 定义随机遮挡增强函数
def rand_cutout(x, ratio=0.5):
    # 计算遮挡尺寸,生成随机遮挡位置,对输入x进行遮挡操作
    cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
    offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
        torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
    )
    grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
    grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
    mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
    mask[grid_batch, grid_x, grid_y] = 0
    x = x * mask.unsqueeze(1)
    return x

# 定义增强函数字典,包含不同类型的增强函数列表
AUGMENT_FNS = {
    'color':        [rand_brightness, rand_saturation, rand_contrast],
    'offset':       [rand_offset],
    'offset_h':     [rand_offset_h],
    'offset_v':     [rand_offset_v],
    'translation':  [rand_translation],
    'cutout':       [rand_cutout],
}

.\lucidrains\transganformer\transganformer\transganformer.py

# 导入所需的库
import os
import json
import multiprocessing
from random import random
import math
from math import log2, floor, sqrt, log, pi
from functools import partial
from contextlib import contextmanager, ExitStack
from pathlib import Path
from shutil import rmtree

import torch
from torch.cuda.amp import autocast, GradScaler
from torch.optim import Adam
from torch import nn, einsum
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.autograd import grad as torch_grad
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP

from PIL import Image
import torchvision
from torchvision import transforms
from kornia import filter2D

from transganformer.diff_augment import DiffAugment
from transganformer.version import __version__

from tqdm import tqdm
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange

# 断言CUDA是否可用
assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed.'

# 常量定义
NUM_CORES = multiprocessing.cpu_count()
EXTS = ['jpg', 'jpeg', 'png']

# 辅助函数

# 判断值是否存在
def exists(val):
    return val is not None

# 返回默认值
def default(val, d):
    return val if exists(val) else d

# 空上下文管理器
@contextmanager
def null_context():
    yield

# 合并多个上下文管理器
def combine_contexts(contexts):
    @contextmanager
    def multi_contexts():
        with ExitStack() as stack:
            yield [stack.enter_context(ctx()) for ctx in contexts]
    return multi_contexts

# 判断是否为2的幂
def is_power_of_two(val):
    return log2(val).is_integer()

# 设置模型参数是否需要梯度
def set_requires_grad(model, bool):
    for p in model.parameters():
        p.requires_grad = bool

# 无限循环生成器
def cycle(iterable):
    while True:
        for i in iterable:
            yield i

# 如果值为NaN,则抛出异常
def raise_if_nan(t):
    if torch.isnan(t):
        raise NanException

# 梯度累积上下文管理器
def gradient_accumulate_contexts(gradient_accumulate_every, is_ddp, ddps):
    if is_ddp:
        num_no_syncs = gradient_accumulate_every - 1
        head = [combine_contexts(map(lambda ddp: ddp.no_sync, ddps))] * num_no_syncs
        tail = [null_context]
        contexts =  head + tail
    else:
        contexts = [null_context] * gradient_accumulate_every

    for context in contexts:
        with context():
            yield

# 分块评估
def evaluate_in_chunks(max_batch_size, model, *args):
    split_args = list(zip(*list(map(lambda x: x.split(max_batch_size, dim=0), args))))
    chunked_outputs = [model(*i) for i in split_args]
    if len(chunked_outputs) == 1:
        return chunked_outputs[0]
    return torch.cat(chunked_outputs, dim=0)

# 球面线性插值
def slerp(val, low, high):
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)
    omega = torch.acos((low_norm * high_norm).sum(1))
    so = torch.sin(omega)
    res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
    return res

# 安全除法
def safe_div(n, d):
    try:
        res = n / d
    except ZeroDivisionError:
        prefix = '' if int(n >= 0) else '-'
        res = float(f'{prefix}inf')
    return res

# 辅助类

# NaN异常类
class NanException(Exception):
    pass

# 指数移动平均类
class EMA():
    def __init__(self, beta):
        super().__init__()
        self.beta = beta
    def update_average(self, old, new):
        if not exists(old):
            return new
        return old * self.beta + (1 - self.beta) * new

# 随机应用类
class RandomApply(nn.Module):
    def __init__(self, prob, fn, fn_else = lambda x: x):
        super().__init__()
        self.fn = fn
        self.fn_else = fn_else
        self.prob = prob
    def forward(self, x):
        fn = self.fn if random() < self.prob else self.fn_else
        return fn(x)

# 残差连接类
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
    # 定义一个前向传播函数,接受输入 x 和其他关键字参数
    def forward(self, x, **kwargs):
        # 调用 self.fn 函数进行前向传播,得到输出 out
        out = self.fn(x, **kwargs)

        # 如果输出是一个元组
        if isinstance(out, tuple):
            # 将元组拆分为 out 和 latent 两部分
            out, latent = out
            # 将输入 x 和 out 相加,得到 ret
            ret = (out + x, latent)
            # 返回 ret
            return ret

        # 如果输出不是元组,则将输入 x 和输出 out 相加,返回结果
        return x + out
class SumBranches(nn.Module):
    # 定义一个类,用于将多个分支的输出求和
    def __init__(self, branches):
        super().__init__()
        self.branches = nn.ModuleList(branches)

    def forward(self, x):
        # 对每个分支的输出进行映射并求和
        return sum(map(lambda fn: fn(x), self.branches))

# attention and transformer modules

class ChanNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        # 计算输入张量 x 的标准差和均值,进行归一化处理
        std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) / (std + self.eps) * self.g + self.b

class PreNorm(nn.Module):
    def __init__(self, dim, fn, dim_context = None):
        super().__init__()
        self.norm = ChanNorm(dim)
        self.norm_context = ChanNorm(dim_context) if exists(dim_context) else None
        self.fn = fn

    def forward(self, x, **kwargs):
        # 对输入张量 x 进行归一化处理
        x = self.norm(x)

        if exists(self.norm_context):
            context = kwargs.pop('context')
            context = self.norm_context(context)
            kwargs.update(context = context)

        return self.fn(x, **kwargs)

class DepthWiseConv2d(nn.Module):
    def __init__(self, dim_in, dim_out, kernel_size, padding = 0, stride = 1, bias = True):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
            nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
        )
    def forward(self, x):
        return self.net(x)

def FeedForward(dim, mult = 4, kernel_size = 3, bn = False):
    padding = kernel_size // 2
    return nn.Sequential(
        nn.Conv2d(dim, dim * mult * 2, 1),
        nn.GLU(dim = 1),
        nn.BatchNorm2d(dim * mult) if bn else nn.Identity(),
        DepthWiseConv2d(dim * mult, dim * mult * 2, kernel_size, padding = padding),
        nn.GLU(dim = 1),
        nn.Conv2d(dim * mult, dim, 1)
    )

# sinusoidal embedding

class FixedPositionalEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        dim //= 2
        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, x):
        h = torch.linspace(-1., 1., x.shape[-2], device = x.device).type_as(self.inv_freq)
        w = torch.linspace(-1., 1., x.shape[-1], device = x.device).type_as(self.inv_freq)
        sinu_inp_h = torch.einsum('i , j -> i j', h, self.inv_freq)
        sinu_inp_w = torch.einsum('i , j -> i j', w, self.inv_freq)
        sinu_inp_h = repeat(sinu_inp_h, 'h c -> () c h w', w = x.shape[-1])
        sinu_inp_w = repeat(sinu_inp_w, 'w c -> () c h w', h = x.shape[-2])
        sinu_inp = torch.cat((sinu_inp_w, sinu_inp_h), dim = 1)
        emb = torch.cat((sinu_inp.sin(), sinu_inp.cos()), dim = 1)
        return emb

# classes

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        fmap_size = None,
        dim_out = None,
        kv_dim = None,
        heads = 8,
        dim_head = 64,
        q_kernel_size = 1,
        kv_kernel_size = 3,
        out_kernel_size = 1,
        q_stride = 1,
        include_self = False,
        downsample = False,
        downsample_kv = 1,
        bn = False,
        latent_dim = None
        ):
        # 调用父类的构造函数
        super().__init__()
        # 创建固定位置嵌入对象
        self.sinu_emb = FixedPositionalEmbedding(dim)

        # 计算内部维度
        inner_dim = dim_head *  heads
        # 设置键值维度,默认为 dim
        kv_dim = default(kv_dim, dim)
        # 设置输出维度,默认为 dim
        dim_out = default(dim_out, dim)

        # 设置头数和缩放因子
        self.heads = heads
        self.scale = dim_head ** -0.5

        # 计算填充值
        q_padding = q_kernel_size // 2
        kv_padding = kv_kernel_size // 2
        out_padding = out_kernel_size // 2

        # 设置查询卷积参数
        q_conv_params = (1, 1, 0)

        # 创建查询卷积层
        self.to_q = nn.Conv2d(dim, inner_dim, *q_conv_params, bias = False)

        # 根据下采样因子设置键值卷积参数
        if downsample_kv == 1:
            kv_conv_params = (3, 1, 1)
        elif downsample_kv == 2:
            kv_conv_params = (3, 2, 1)
        elif downsample_kv == 4:
            kv_conv_params = (7, 4, 3)
        else:
            raise ValueError(f'invalid downsample factor for key / values {downsample_kv}')

        # 创建键卷积层和值卷积层
        self.to_k = nn.Conv2d(kv_dim, inner_dim, *kv_conv_params, bias = False)
        self.to_v = nn.Conv2d(kv_dim, inner_dim, *kv_conv_params, bias = False)

        # 设置是否使用批归一化
        self.bn = bn
        if self.bn:
            self.q_bn = nn.BatchNorm2d(inner_dim) if bn else nn.Identity()
            self.k_bn = nn.BatchNorm2d(inner_dim) if bn else nn.Identity()
            self.v_bn = nn.BatchNorm2d(inner_dim) if bn else nn.Identity()

        # 检查是否存在潜在维度
        self.has_latents = exists(latent_dim)
        if self.has_latents:
            # 创建潜在维度的通道归一化层和潜在维度到查询、键、值的卷积层
            self.latent_norm = ChanNorm(latent_dim)
            self.latents_to_qkv = nn.Conv2d(latent_dim, inner_dim * 3, 1, bias = False)

            # 创建潜在维度到输出的卷积层序列
            self.latents_to_out = nn.Sequential(
                nn.Conv2d(inner_dim, latent_dim * 2, 1),
                nn.GLU(dim = 1),
                nn.BatchNorm2d(latent_dim) if bn else nn.Identity()
            )

        # 设置是否包含自身
        self.include_self = include_self
        if include_self:
            # 创建自身到自身的键卷积层和值卷积层
            self.to_self_k = nn.Conv2d(dim, inner_dim, *kv_conv_params, bias = False)
            self.to_self_v = nn.Conv2d(dim, inner_dim, *kv_conv_params, bias = False)

        # 创建混合头部后的参数
        self.mix_heads_post = nn.Parameter(torch.randn(heads, heads))

        # 根据是否下采样设置输出卷积参数
        out_conv_params = (3, 2, 1) if downsample else q_conv_params

        # 创建输出卷积层序列
        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim_out * 2, *out_conv_params),
            nn.GLU(dim = 1),
            nn.BatchNorm2d(dim_out) if bn else nn.Identity()
        )

        # 设置特征图大小和旋转嵌入
        self.fmap_size = fmap_size
        self.pos_emb = RotaryEmbedding(dim_head, downsample_keys = downsample_kv)
    # 定义前向传播函数,接受输入 x,潜在变量 latents,默认上下文 context,是否包含自身 include_self
    def forward(self, x, latents = None, context = None, include_self = False):
        # 断言检查输入 x 的最后一个维度是否与指定的 fmap_size 相等
        assert not exists(self.fmap_size) or x.shape[-1] == self.fmap_size, 'fmap size must equal the given shape'

        # 获取输入 x 的形状信息
        b, n, _, y, h, device = *x.shape, self.heads, x.device

        # 检查是否存在上下文信息,如果不存在,则使用输入 x 作为上下文
        has_context = exists(context)
        context = default(context, x)

        # 初始化查询、键、值的输入
        q_inp = x
        k_inp = context
        v_inp = context

        # 如果不存在上下文信息,则添加正弦嵌入
        if not has_context:
            sinu_emb = self.sinu_emb(context)
            q_inp += sinu_emb
            k_inp += sinu_emb

        # 将查询、键、值通过对应的线性变换层
        q, k, v = (self.to_q(q_inp), self.to_k(k_inp), self.to_v(v_inp))

        # 如果启用了批归一化,则对查询、键、值进行批归一化
        if self.bn:
            q = self.q_bn(q)
            k = self.k_bn(k)
            v = self.v_bn(v)

        # 获取查询的输出高度和宽度
        out_h, out_w = q.shape[-2:]

        # 定义函数将查询、键、值按头数拆分
        split_head = lambda t: rearrange(t, 'b (h d) x y -> b h (x y) d', h = h)

        # 对查询、键、值按头数拆分
        q, k, v = map(split_head, (q, k, v))

        # 如果不存在上下文信息,则对查询、键添加位置嵌入
        if not has_context:
            q, k = self.pos_emb(q, k)

        # 如果包含自身信息,则将自身信息添加到键和值中
        if self.include_self:
            kx = self.to_self_k(x)
            vx = self.to_self_v(x)
            kx, vx = map(split_head, (kx, vx))

            k = torch.cat((kx, k), dim = -2)
            v = torch.cat((vx, v), dim = -2)

        # 如果存在潜在变量,则将潜在变量信息添加到查询、键、值中
        if self.has_latents:
            assert exists(latents), 'latents must be passed in'
            latents = self.latent_norm(latents)
            lq, lk, lv = self.latents_to_qkv(latents).chunk(3, dim = 1)
            lq, lk, lv = map(split_head, (lq, lk, lv))

            latent_shape = lq.shape
            num_latents = lq.shape[-2]

            q = torch.cat((lq, q), dim = -2)
            k = torch.cat((lk, k), dim = -2)
            v = torch.cat((lv, v), dim = -2)

        # 计算点积注意力得分
        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        # 对注意力得分进行 softmax 操作
        attn = dots.softmax(dim = -1)
        attn = einsum('b h i j, h g -> b g i j', attn, self.mix_heads_post)

        # 计算输出
        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)

        # 如果存在潜在变量,则将潜在变量信息分离出来
        if self.has_latents:
            lout, out = out[..., :num_latents, :], out[..., num_latents:, :]
            lout = rearrange(lout, 'b h (x y) d -> b (h d) x y', h = h, x = latents.shape[-2], y = latents.shape[-1])
            lout = self.latents_to_out(lout)

        # 重组输出形状
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', h = h, x = out_h, y = out_w)
        out = self.to_out(out)

        # 如果存在潜在变量,则返回输出和潜在变量输出
        if self.has_latents:
            return out, lout

        # 否则只返回输出
        return out
# dataset

# 将图像转换为指定类型
def convert_image_to(img_type, image):
    # 如果图像模式不是指定类型,则进行转换
    if image.mode != img_type:
        return image.convert(img_type)
    return image

# 定义一个身份函数类
class identity(object):
    def __call__(self, tensor):
        return tensor

# 扩展灰度图像类
class expand_greyscale(object):
    def __init__(self, transparent):
        self.transparent = transparent

    def __call__(self, tensor):
        # 获取图像通道数
        channels = tensor.shape[0]
        num_target_channels = 4 if self.transparent else 3

        # 如果通道数与目标通道数相同,则返回原图像
        if channels == num_target_channels:
            return tensor

        alpha = None
        if channels == 1:
            color = tensor.expand(3, -1, -1)
        elif channels == 2:
            color = tensor[:1].expand(3, -1, -1)
            alpha = tensor[1:]
        else:
            raise Exception(f'image with invalid number of channels given {channels}')

        # 如果不存在 alpha 通道且需要透明度,则创建全白的 alpha 通道
        if not exists(alpha) and self.transparent:
            alpha = torch.ones(1, *tensor.shape[1:], device=tensor.device)

        return color if not self.transparent else torch.cat((color, alpha))

# 调整图像大小至最小尺寸
def resize_to_minimum_size(min_size, image):
    if max(*image.size) < min_size:
        return torchvision.transforms.functional.resize(image, min_size)
    return image

# 图像数据集类
class ImageDataset(Dataset):
    def __init__(
        self,
        folder,
        image_size,
        transparent = False,
        greyscale = False,
        aug_prob = 0.
    ):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        self.paths = [p for ext in EXTS for p in Path(f'{folder}').glob(f'**/*.{ext}')]
        assert len(self.paths) > 0, f'No images were found in {folder} for training'

        if transparent:
            num_channels = 4
            pillow_mode = 'RGBA'
            expand_fn = expand_greyscale(transparent)
        elif greyscale:
            num_channels = 1
            pillow_mode = 'L'
            expand_fn = identity()
        else:
            num_channels = 3
            pillow_mode = 'RGB'
            expand_fn = expand_greyscale(transparent)

        convert_image_fn = partial(convert_image_to, pillow_mode)

        self.transform = transforms.Compose([
            transforms.Lambda(convert_image_fn),
            transforms.Lambda(partial(resize_to_minimum_size, image_size)),
            transforms.Resize(image_size),
            RandomApply(aug_prob, transforms.RandomResizedCrop(image_size, scale=(0.5, 1.0), ratio=(0.98, 1.02)), transforms.CenterCrop(image_size)),
            transforms.ToTensor(),
            transforms.Lambda(expand_fn)
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(path)
        return self.transform(img)

# augmentations

# 随机水平翻转函数
def random_hflip(tensor, prob):
    if prob > random():
        return tensor
    return torch.flip(tensor, dims=(3,))

# 增强包装类
class AugWrapper(nn.Module):
    def __init__(self, D, image_size):
        super().__init__()
        self.D = D

    def forward(self, images, prob = 0., types = [], detach = False, **kwargs):
        context = torch.no_grad if detach else null_context

        with context():
            if random() < prob:
                images = random_hflip(images, prob=0.5)
                images = DiffAugment(images, types=types)

        return self.D(images, **kwargs)

# modifiable global variables

# 上采样函数
def upsample(scale_factor = 2):
    return nn.Upsample(scale_factor = scale_factor)

# activation

# Leaky ReLU 激活函数
def leaky_relu(p = 0.1):
    return nn.LeakyReLU(p)

# rotary positional embedding helpers

# 每两个元素旋转函数
def rotate_every_two(x):
    x = rearrange(x, '... (d j) -> ... d j', j = 2)
    x1, x2 = x.unbind(dim = -1)
    x = torch.stack((-x2, x1), dim = -1)
    return rearrange(x, '... d j -> ... (d j)')

# 获取正弦余弦值函数
def get_sin_cos(seq):
    n = seq.shape[0]
    x_sinu = repeat(seq, 'i d -> i j d', j = n)
    y_sinu = repeat(seq, 'j d -> i j d', i = n)

    sin = torch.cat((x_sinu.sin(), y_sinu.sin()), dim = -1)
    # 将 x_sinu 和 y_sinu 的余弦值按照最后一个维度连接起来
    cos = torch.cat((x_sinu.cos(), y_sinu.cos()), dim = -1)

    # 对 sin 和 cos 进行重排列,将最后两个维度合并到一起
    sin, cos = map(lambda t: rearrange(t, 'i j d -> (i j) d'), (sin, cos))
    # 对 sin 和 cos 进行重复,扩展维度
    sin, cos = map(lambda t: repeat(t, 'n d -> () () n (d j)', j = 2), (sin, cos))
    # 返回重排列后的 sin 和 cos
    return sin, cos
# positional encoding

# 定义旋转嵌入类
class RotaryEmbedding(nn.Module):
    # 初始化函数
    def __init__(self, dim, downsample_keys = 1):
        super().__init__()
        self.dim = dim
        self.downsample_keys = downsample_keys

    # 前向传播函数
    def forward(self, q, k):
        device, dtype, n = q.device, q.dtype, int(sqrt(q.shape[-2]))

        # 生成等间距序列
        seq = torch.linspace(-1., 1., steps = n, device = device)
        seq = seq.unsqueeze(-1)

        # 生成不同尺度的旋转角度
        scales = torch.logspace(0., log(10 / 2) / log(2), self.dim // 4, base = 2, device = device, dtype = dtype)
        scales = scales[(*((None,) * (len(seq.shape) - 1)), Ellipsis]

        seq = seq * scales * pi

        x = seq
        y = seq

        # 对 y 进行降采样
        y = reduce(y, '(j n) c -> j c', 'mean', n = self.downsample_keys)

        # 获取正弦和余弦值
        q_sin, q_cos = get_sin_cos(x)
        k_sin, k_cos = get_sin_cos(y)
        q = (q * q_cos) + (rotate_every_two(q) * q_sin)
        k = (k * k_cos) + (rotate_every_two(k) * k_sin)
        return q, k

# mapping network

# 定义等权重线性变换类
class EqualLinear(nn.Module):
    # 初始化函数
    def __init__(self, in_dim, out_dim, lr_mul = 1, bias = True):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_dim, in_dim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim))

        self.lr_mul = lr_mul

    # 前向传播函数
    def forward(self, input):
        return F.linear(input, self.weight * self.lr_mul, bias=self.bias * self.lr_mul)

# 定义映射网络类
class MappingNetwork(nn.Module):
    # 初始化函数
    def __init__(self, dim, depth, lr_mul = 0.1):
        super().__init__()

        layers = []
        for i in range(depth):
            layers.extend([EqualLinear(dim, dim, lr_mul), leaky_relu()])

        self.net = nn.Sequential(
            *layers,
            nn.Linear(dim, dim * 4)
        )

    # 前向传播函数
    def forward(self, x):
        x = F.normalize(x, dim=1)
        x = self.net(x)
        return rearrange(x, 'b (c h w) -> b c h w', h = 2, w = 2)

# generative adversarial network

# 定义生成器类
class Generator(nn.Module):
    # 初始化函数
    def __init__(
        self,
        *,
        image_size,
        latent_dim = 256,
        fmap_max = 512,
        init_channel = 3,
        mapping_network_depth = 4
    ):
        super().__init__()
        assert is_power_of_two(image_size), 'image size must be a power of 2'
        num_layers = int(log2(image_size)) - 1
        
        self.mapping = MappingNetwork(latent_dim, mapping_network_depth)
        self.initial_block = nn.Parameter(torch.randn((latent_dim, 4, 4)))

        self.layers = nn.ModuleList([])

        fmap_size = 4
        chan = latent_dim
        min_chan = 8

        for ind in range(num_layers):
            is_last = ind == (num_layers - 1)

            downsample_factor = int(2 ** max(log2(fmap_size) - log2(32), 0))
            attn_class = partial(Attention, bn = True, fmap_size = fmap_size, downsample_kv = downsample_factor)

            if not is_last:
                chan_out = max(min_chan, chan // 4)

                upsample = nn.Sequential(
                    attn_class(dim = chan, dim_head = chan, heads = 1, dim_out = chan_out * 4),
                    nn.PixelShuffle(2)
                )

            else:
                upsample = nn.Identity()

            self.layers.append(nn.ModuleList([
                Residual(PreNorm(chan, attn_class(dim = chan, latent_dim = latent_dim))),
                Residual(FeedForward(chan, bn = True, kernel_size = (3 if image_size > 4 else 1))),
                upsample,
            ]))

            chan = chan_out
            fmap_size *= 2

        self.final_attn = Residual(PreNorm(chan, attn_class(chan, latent_dim = latent_dim)))

        self.to_img = nn.Sequential(
            Residual(FeedForward(chan_out, bn = True)),
            nn.Conv2d(chan, init_channel, 1)
        )
    # 定义一个前向传播函数,接受输入 x
    def forward(self, x):
        # 获取输入 x 的 batch 大小
        b = x.shape[0]

        # 将输入 x 映射到潜在空间
        latents = self.mapping(x)

        # 重复初始块的特征图,使其与 batch 大小相匹配
        fmap = repeat(self.initial_block, 'c h w -> b c h w', b = b)

        # 遍历每个层中的注意力机制、特征提取和上采样操作
        for attn, ff, upsample in self.layers:
            # 使用注意力机制处理特征图和潜在空间
            fmap, latents_out = attn(fmap, latents = latents)
            # 更新潜在空间
            latents = latents + latents_out

            # 使用特征提取函数处理特征图
            fmap = ff(fmap)
            # 使用上采样函数对特征图进行上采样

            fmap = upsample(fmap)

        # 最终使用最终的注意力机制处理特征图和潜在空间
        fmap, _ = self.final_attn(fmap, latents = latents)
        # 将处理后的特征图转换为图像
        return self.to_img(fmap)
# 定义一个简单的解码器类,继承自 nn.Module
class SimpleDecoder(nn.Module):
    # 初始化函数,设置输入通道数、输出通道数、上采样次数等参数
    def __init__(
        self,
        *,
        chan_in,
        chan_out = 3,
        num_upsamples = 4,
    ):
        super().__init__()

        # 初始化空的层列表
        layers = nn.ModuleList([])
        # 设置最终输出通道数
        final_chan = chan_out
        # 设置初始通道数
        chans = chan_in

        # 循环创建上采样层
        for ind in range(num_upsamples):
            # 判断是否是最后一层
            last_layer = ind == (num_upsamples - 1)
            # 根据是否是最后一层确定输出通道数
            chan_out = chans if not last_layer else final_chan * 2
            # 创建包含上采样、卷积和 GLU 激活函数的层
            layer = nn.Sequential(
                upsample(),
                nn.Conv2d(chans, chan_out, 3, padding = 1),
                nn.GLU(dim = 1)
            )
            # 将层添加到层列表中
            layers.append(layer)
            # 更新通道数
            chans //= 2

        # 将所有层组合成一个网络
        self.net = nn.Sequential(*layers)

    # 前向传播函数
    def forward(self, x):
        return self.net(x)

# 定义一个鉴别器类,继承自 nn.Module
class Discriminator(nn.Module):
    # 初始化函数,设置图像大小、最大特征图数、初始通道数等参数
    def __init__(
        self,
        *,
        image_size,
        fmap_max = 256,
        init_channel = 3,
    ):
        super().__init__()
        # 断言图像大小为 2 的幂次方
        assert is_power_of_two(image_size), 'image size must be a power of 2'
        # 计算层数
        num_layers = int(log2(image_size)) - 2
        # 设置特征图维度
        fmap_dim = 64

        # 创建卷积嵌入层
        self.conv_embed = nn.Sequential(
            nn.Conv2d(init_channel, 32, kernel_size = 4, stride = 2, padding = 1),
            nn.Conv2d(32, fmap_dim, kernel_size = 3, padding = 1)
        )

        # 更新图像大小
        image_size //= 2
        # 创建横向和纵向位置嵌入参数
        self.ax_pos_emb_h = nn.Parameter(torch.randn(image_size, fmap_dim))
        self.ax_pos_emb_w = nn.Parameter(torch.randn(image_size, fmap_dim))

        # 初始化空的图层列表和特征图维度列表
        self.image_sizes = []
        self.layers = nn.ModuleList([])
        fmap_dims = []

        # 循环创建图层
        for ind in range(num_layers):
            # 更新图像大小
            image_size //= 2
            self.image_sizes.append(image_size)

            # 计算输出特征图维度
            fmap_dim_out = min(fmap_dim * 2, fmap_max)

            # 创建下采样分支
            downsample = SumBranches([
                nn.Conv2d(fmap_dim, fmap_dim_out, 3, 2, 1),
                nn.Sequential(
                    nn.AvgPool2d(2),
                    nn.Conv2d(fmap_dim, fmap_dim_out, 3, padding = 1),
                    leaky_relu()
                )
            ])

            # 计算下采样因子
            downsample_factor = 2 ** max(log2(image_size) - log2(32), 0)
            # 创建注意力类
            attn_class = partial(Attention, fmap_size = image_size, downsample_kv = downsample_factor)

            # 将下采样、残差块和前馈网络块组合成一个图层
            self.layers.append(nn.ModuleList([
                downsample,
                Residual(PreNorm(fmap_dim_out, attn_class(dim = fmap_dim_out))),
                Residual(PreNorm(fmap_dim_out, FeedForward(dim = fmap_dim_out, kernel_size = (3 if image_size > 4 else 1)))
            ]))

            # 更新特征图维度和特征图维度列表
            fmap_dim = fmap_dim_out
            fmap_dims.append(fmap_dim)

        # 创建辅助解码器
        self.aux_decoder = SimpleDecoder(chan_in = fmap_dims[-2], chan_out = init_channel, num_upsamples = num_layers)

        # 创建输出层
        self.to_logits = nn.Sequential(
            Residual(PreNorm(fmap_dim, Attention(dim = fmap_dim, fmap_size = 2))),
            Residual(PreNorm(fmap_dim, FeedForward(dim = fmap_dim, kernel_size = (3 if image_size > 64 else 1)))),
            nn.Conv2d(fmap_dim, 1, 2),
            Rearrange('b () () () -> b')
        )

    # 前向传播函数
    def forward(self, x, calc_aux_loss = False):
        x_ = x
        x = self.conv_embed(x)

        ax_pos_emb = rearrange(self.ax_pos_emb_h, 'h c -> () c h ()') + rearrange(self.ax_pos_emb_w, 'w c -> () c () w')
        x += ax_pos_emb

        fmaps = []

        for (downsample, attn, ff), image_size in zip(self.layers, self.image_sizes):
            x = downsample(x)
            x = attn(x)
            x = ff(x)

            fmaps.append(x)

        x = self.to_logits(x)

        if not calc_aux_loss:
            return x, None

        recon = self.aux_decoder(fmaps[-2])
        recon_loss = F.mse_loss(x_, recon)
        return x, recon_loss

# 定义一个 Transganformer 类,继承自 nn.Module
class Transganformer(nn.Module):
    # 初始化函数,设置潜在维度、图像大小、最大特征图数等参数
    def __init__(
        self,
        *,
        latent_dim,
        image_size,
        fmap_max = 512,
        transparent = False,
        greyscale = False,
        ttur_mult = 1.,
        lr = 2e-4,
        rank = 0,
        ddp = False
    ):
        # 调用父类的构造函数
        super().__init__()
        # 初始化潜在空间维度和图像大小
        self.latent_dim = latent_dim
        self.image_size = image_size

        # 根据是否透明或灰度图像确定初始通道数
        if transparent:
            init_channel = 4
        elif greyscale:
            init_channel = 1
        else:
            init_channel = 3

        # 创建生成器参数字典
        G_kwargs = dict(
            image_size = image_size,
            latent_dim = latent_dim,
            fmap_max = fmap_max,
            init_channel = init_channel
        )

        # 初始化生成器和判别器
        self.G = Generator(**G_kwargs)
        self.D = Discriminator(
            image_size = image_size,
            fmap_max = fmap_max,
            init_channel = init_channel
        )

        # 初始化指数移动平均更新器和生成器EMA
        self.ema_updater = EMA(0.995)
        self.GE = Generator(**G_kwargs)
        set_requires_grad(self.GE, False)

        # 初始化生成器和判别器的优化器
        self.G_opt = Adam(self.G.parameters(), lr = lr, betas=(0.5, 0.9))
        self.D_opt = Adam(self.D.parameters(), lr = lr * ttur_mult, betas=(0.5, 0.9))

        # 初始化权重
        self.apply(self._init_weights)
        self.reset_parameter_averaging()

        # 将模型移至GPU
        self.cuda(rank)
        # 初始化带数据增强的判别器
        self.D_aug = AugWrapper(self.D, image_size)

    # 初始化权重函数
    def _init_weights(self, m):
        if type(m) in {nn.Conv2d, nn.Linear}:
            nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')

    # 更新EMA函数
    def EMA(self):
        def update_moving_average(ma_model, current_model):
            for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
                old_weight, up_weight = ma_params.data, current_params.data
                ma_params.data = self.ema_updater.update_average(old_weight, up_weight)

            for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
                new_buffer_value = self.ema_updater.update_average(ma_buffer, current_buffer)
                ma_buffer.copy_(new_buffer_value)

        update_moving_average(self.GE, self.G)

    # 重置参数平均函数
    def reset_parameter_averaging(self):
        self.GE.load_state_dict(self.G.state_dict())

    # 前向传播函数
    def forward(self, x):
        raise NotImplemented
# 定义 Trainer 类,用于训练模型
class Trainer():
    # 初始化函数,设置各种参数
    def __init__(
        self,
        name = 'default',
        results_dir = 'results',
        models_dir = 'models',
        base_dir = './',
        num_workers = None,
        latent_dim = 256,
        image_size = 128,
        num_image_tiles = 8,
        fmap_max = 512,
        transparent = False,
        greyscale = False,
        batch_size = 4,
        gp_weight = 10,
        gradient_accumulate_every = 1,
        lr = 2e-4,
        lr_mlp = 1.,
        ttur_mult = 1.,
        save_every = 1000,
        evaluate_every = 1000,
        aug_prob = None,
        aug_types = ['translation', 'cutout'],
        dataset_aug_prob = 0.,
        calculate_fid_every = None,
        calculate_fid_num_images = 12800,
        clear_fid_cache = False,
        is_ddp = False,
        rank = 0,
        world_size = 1,
        log = False,
        amp = False,
        *args,
        **kwargs
    ):
        # 存储传入的参数
        self.GAN_params = [args, kwargs]
        self.GAN = None

        self.name = name

        # 设置路径相关参数
        base_dir = Path(base_dir)
        self.base_dir = base_dir
        self.results_dir = base_dir / results_dir
        self.models_dir = base_dir / models_dir
        self.fid_dir = base_dir / 'fid' / name

        self.config_path = self.models_dir / name / '.config.json'

        # 检查图片大小是否为2的幂次方
        assert is_power_of_two(image_size), 'image size must be a power of 2 (32, 64, 128, 256, 512, 1024)'

        # 设置图片相关参数
        self.image_size = image_size
        self.num_image_tiles = num_image_tiles

        # 设置潜在空间维度、特征图最大值、透明度、灰度等参数
        self.latent_dim = latent_dim
        self.fmap_max = fmap_max
        self.transparent = transparent
        self.greyscale = greyscale

        # 检查透明度和灰度是否只设置了一个
        assert (int(self.transparent) + int(self.greyscale)) < 2, 'you can only set either transparency or greyscale'

        # 设置数据增强相关参数
        self.aug_prob = aug_prob
        self.aug_types = aug_types

        # 设置学习率、工作进程数、TTUR倍数、批量大小、梯度积累等参数
        self.lr = lr
        self.num_workers = num_workers
        self.ttur_mult = ttur_mult
        self.batch_size = batch_size
        self.gradient_accumulate_every = gradient_accumulate_every

        # 设置梯度惩罚权重
        self.gp_weight = gp_weight

        # 设置评估和保存模型的频率
        self.evaluate_every = evaluate_every
        self.save_every = save_every
        self.steps = 0

        # 初始化损失值
        self.d_loss = 0
        self.g_loss = 0
        self.last_gp_loss = None
        self.last_recon_loss = None
        self.last_fid = None

        # 初始化文件夹
        self.init_folders()

        self.loader = None
        self.dataset_aug_prob = dataset_aug_prob

        # 设置计算 FID 的频率和数量
        self.calculate_fid_every = calculate_fid_every
        self.calculate_fid_num_images = calculate_fid_num_images
        self.clear_fid_cache = clear_fid_cache

        # 设置是否使用分布式训练
        self.is_ddp = is_ddp
        self.is_main = rank == 0
        self.rank = rank
        self.world_size = world_size

        # 设置混合精度训练
        self.amp = amp
        self.G_scaler = GradScaler(enabled = self.amp)
        self.D_scaler = GradScaler(enabled = self.amp)

    # 返回图片扩展名
    @property
    def image_extension(self):
        return 'jpg' if not self.transparent else 'png'

    # 返回检查点编号
    @property
    def checkpoint_num(self):
        return floor(self.steps // self.save_every)
        
    # 初始化 GAN 模型
    def init_GAN(self):
        args, kwargs = self.GAN_params

        # 实例化 GAN 模型
        self.GAN = Transganformer(
            lr = self.lr,
            latent_dim = self.latent_dim,
            image_size = self.image_size,
            ttur_mult = self.ttur_mult,
            fmap_max = self.fmap_max,
            transparent = self.transparent,
            greyscale = self.greyscale,
            rank = self.rank,
            *args,
            **kwargs
        )

        # 如果使用分布式训练,设置相关参数
        if self.is_ddp:
            ddp_kwargs = {'device_ids': [self.rank], 'output_device': self.rank, 'find_unused_parameters': True}

            self.G_ddp = DDP(self.GAN.G, **ddp_kwargs)
            self.D_ddp = DDP(self.GAN.D, **ddp_kwargs)
            self.D_aug_ddp = DDP(self.GAN.D_aug, **ddp_kwargs)

    # 写入配置文件
    def write_config(self):
        self.config_path.write_text(json.dumps(self.config()))
    # 加载配置信息,如果配置文件不存在则使用默认配置,否则读取配置文件内容
    def load_config(self):
        config = self.config() if not self.config_path.exists() else json.loads(self.config_path.read_text())
        # 设置图像大小和透明度
        self.image_size = config['image_size']
        self.transparent = config['transparent']
        # 设置是否为灰度图像,并移除配置中的灰度信息
        self.greyscale = config.pop('greyscale', False)
        # 移除配置中的 fmap_max 信息
        self.fmap_max = config.pop('fmap_max', 512)
        # 删除 GAN 属性
        del self.GAN
        # 初始化 GAN
        self.init_GAN()

    # 返回配置信息
    def config(self):
        return {
            'image_size': self.image_size,
            'transparent': self.transparent,
            'greyscale': self.greyscale
        }

    # 设置数据源文件夹
    def set_data_src(self, folder):
        # 计算默认的工作线程数
        num_workers = default(self.num_workers, math.ceil(NUM_CORES / self.world_size))
        # 创建图像数据集
        self.dataset = ImageDataset(folder, self.image_size, transparent=self.transparent, greyscale=self.greyscale, aug_prob=self.dataset_aug_prob)
        # 创建分布式采样器
        sampler = DistributedSampler(self.dataset, rank=self.rank, num_replicas=self.world_size, shuffle=True) if self.is_ddp else None
        # 创建数据加载器
        dataloader = DataLoader(self.dataset, num_workers=num_workers, batch_size=math.ceil(self.batch_size / self.world_size), sampler=sampler, shuffle=not self.is_ddp, drop_last=True, pin_memory=True)
        # 创建数据加载器的循环迭代器
        self.loader = cycle(dataloader)

        # 如果数据集较小,自动设置数据增强概率
        num_samples = len(self.dataset)
        if not exists(self.aug_prob) and num_samples < 1e5:
            self.aug_prob = min(0.5, (1e5 - num_samples) * 3e-6)
            print(f'autosetting augmentation probability to {round(self.aug_prob * 100)}%')

    # 评估生成器的效果
    @torch.no_grad()
    def evaluate(self, num=0, num_image_tiles=4):
        self.GAN.eval()

        ext = self.image_extension
        num_rows = num_image_tiles

        latent_dim = self.GAN.latent_dim
        image_size = self.GAN.image_size

        # 生成潜在向量
        latents = torch.randn((num_rows ** 2, latent_dim)).cuda(self.rank)

        # 生成普通图像
        generated_images = self.generate_(self.GAN.G, latents)
        torchvision.utils.save_image(generated_images, str(self.results_dir / self.name / f'{str(num)}.{ext}'), nrow=num_rows)
        
        # 生成移动平均图像
        generated_images = self.generate_(self.GAN.GE, latents)
        torchvision.utils.save_image(generated_images, str(self.results_dir / self.name / f'{str(num)}-ema.{ext}'), nrow=num_rows)

    # 生成图像
    @torch.no_grad()
    def generate(self, num=0, num_image_tiles=4, checkpoint=None, types=['default', 'ema']):
        self.GAN.eval()

        latent_dim = self.GAN.latent_dim
        dir_name = self.name + str('-generated-') + str(checkpoint)
        dir_full = Path().absolute() / self.results_dir / dir_name
        ext = self.image_extension

        # 如果目录不存在,则创建目录
        if not dir_full.exists():
            os.mkdir(dir_full)

        # 生成普通图像
        if 'default' in types:
            for i in tqdm(range(num_image_tiles), desc='Saving generated default images'):
                latents = torch.randn((1, latent_dim)).cuda(self.rank)
                generated_image = self.generate_(self.GAN.G, latents)
                path = str(self.results_dir / dir_name / f'{str(num)}-{str(i)}.{ext}')
                torchvision.utils.save_image(generated_image[0], path, nrow=1)

        # 生成移动平均图像
        if 'ema' in types:
            for i in tqdm(range(num_image_tiles), desc='Saving generated EMA images'):
                latents = torch.randn((1, latent_dim)).cuda(self.rank)
                generated_image = self.generate_(self.GAN.GE, latents)
                path = str(self.results_dir / dir_name / f'{str(num)}-{str(i)}-ema.{ext}')
                torchvision.utils.save_image(generated_image[0], path, nrow=1)

        return dir_full

    @torch.no_grad()
    # 显示训练进度的方法,生成进度图像
    def show_progress(self, num_images=4, types=['default', 'ema']):
        # 获取所有检查点
        checkpoints = self.get_checkpoints()
        # 确保存在检查点以创建训练进度视频
        assert exists(checkpoints), 'cannot find any checkpoints to create a training progress video for'

        # 创建目录名
        dir_name = self.name + str('-progress')
        # 获取完整目录路径
        dir_full = Path().absolute() / self.results_dir / dir_name
        # 获取图像扩展名
        ext = self.image_extension
        # 初始化潜在向量
        latents = None

        # 计算检查点数的位数
        zfill_length = math.ceil(math.log10(len(checkpoints)))

        # 如果目录不存在,则创建目录
        if not dir_full.exists():
            os.mkdir(dir_full)

        # 遍历检查点,生成进度图像
        for checkpoint in tqdm(checkpoints, desc='Generating progress images'):
            # 加载检查点
            self.load(checkpoint, print_version=False)
            self.GAN.eval()

            # 初始化潜在向量
            if checkpoint == 0:
                latents = torch.randn((num_images, self.GAN.latent_dim)).cuda(self.rank)

            # 生成正常图像
            if 'default' in types:
                generated_image = self.generate_(self.GAN.G, latents)
                path = str(self.results_dir / dir_name / f'{str(checkpoint).zfill(zfill_length)}.{ext}')
                torchvision.utils.save_image(generated_image, path, nrow=num_images)

            # 生成移动平均图像
            if 'ema' in types:
                generated_image = self.generate_(self.GAN.GE, latents)
                path = str(self.results_dir / dir_name / f'{str(checkpoint).zfill(zfill_length)}-ema.{ext}')
                torchvision.utils.save_image(generated_image, path, nrow=num_images)

    # 计算 FID 分数的方法
    @torch.no_grad()
    def calculate_fid(self, num_batches):
        # 导入 FID 分数计算模块
        from pytorch_fid import fid_score
        # 清空 GPU 缓存
        torch.cuda.empty_cache()

        # 真实图像路径和生成图像路径
        real_path = self.fid_dir / 'real'
        fake_path = self.fid_dir / 'fake'

        # 删除用于 FID 计算的现有文件并重新创建目录
        if not real_path.exists() or self.clear_fid_cache:
            rmtree(real_path, ignore_errors=True)
            os.makedirs(real_path)

            # 保存真实图像
            for batch_num in tqdm(range(num_batches), desc='calculating FID - saving reals'):
                real_batch = next(self.loader)
                for k, image in enumerate(real_batch.unbind(0)):
                    ind = k + batch_num * self.batch_size
                    torchvision.utils.save_image(image, real_path / f'{ind}.png')

        # 删除生成图像目录并重新创建
        rmtree(fake_path, ignore_errors=True)
        os.makedirs(fake_path)

        # 设置生成器为评估模式
        self.GAN.eval()
        ext = self.image_extension

        # 获取潜在向量维度和图像尺寸
        latent_dim = self.GAN.latent_dim
        image_size = self.GAN.image_size

        # 生成假图像
        for batch_num in tqdm(range(num_batches), desc='calculating FID - saving generated'):
            # 生成潜在向量
            latents = torch.randn(self.batch_size, latent_dim).cuda(self.rank)

            # 生成移动平均图像
            generated_images = self.generate_(self.GAN.GE, latents)

            for j, image in enumerate(generated_images.unbind(0)):
                ind = j + batch_num * self.batch_size
                torchvision.utils.save_image(image, str(fake_path / f'{str(ind)}-ema.{ext}'))

        # 返回 FID 分数
        return fid_score.calculate_fid_given_paths([str(real_path), str(fake_path)], 256, latents.device, 2048)

    # 生成图像的方法
    @torch.no_grad()
    def generate_(self, G, style, num_image_tiles = 8):
        # 分块评估生成图像
        generated_images = evaluate_in_chunks(self.batch_size, G, style)
        return generated_images.clamp_(0., 1.)

    @torch.no_grad()
    # 生成插值图像序列
    def generate_interpolation(self, num = 0, num_image_tiles = 8, num_steps = 100, save_frames = False):
        # 将 GAN 设置为评估模式
        self.GAN.eval()
        # 获取图像文件扩展名
        ext = self.image_extension
        # 设置图像行数
        num_rows = num_image_tiles

        # 获取潜在空间维度和图像尺寸
        latent_dim = self.GAN.latent_dim
        image_size = self.GAN.image_size

        # 生成低和高潜在向量
        latents_low = torch.randn(num_rows ** 2, latent_dim).cuda(self.rank)
        latents_high = torch.randn(num_rows ** 2, latent_dim).cuda(self.rank)

        # 生成插值比例
        ratios = torch.linspace(0., 8., num_steps)

        frames = []
        # 对每个比例进行插值
        for ratio in tqdm(ratios):
            # 使用球面线性插值生成插值潜在向量
            interp_latents = slerp(ratio, latents_low, latents_high)
            # 生成图像
            generated_images = self.generate_(self.GAN.GE, interp_latents)
            # 将生成的图像排列成网格
            images_grid = torchvision.utils.make_grid(generated_images, nrow = num_rows)
            # 将图像网格转换为 PIL 图像
            pil_image = transforms.ToPILImage()(images_grid.cpu())
            
            # 如果需要透明背景
            if self.transparent:
                background = Image.new('RGBA', pil_image.size, (255, 255, 255))
                pil_image = Image.alpha_composite(background, pil_image)
                
            # 将当前帧添加到帧列表中
            frames.append(pil_image)

        # 保存插值图像序列为 GIF
        frames[0].save(str(self.results_dir / self.name / f'{str(num)}.gif'), save_all=True, append_images=frames[1:], duration=80, loop=0, optimize=True)

        # 如果需要保存每一帧
        if save_frames:
            folder_path = (self.results_dir / self.name / f'{str(num)}')
            folder_path.mkdir(parents=True, exist_ok=True)
            for ind, frame in enumerate(frames):
                frame.save(str(folder_path / f'{str(ind)}.{ext}')

    # 打印日志信息
    def print_log(self):
        # 定义日志数据
        data = [
            ('G', self.g_loss),
            ('D', self.d_loss),
            ('GP', self.last_gp_loss),
            ('SS', self.last_recon_loss),
            ('FID', self.last_fid)
        ]

        # 过滤掉空值
        data = [d for d in data if exists(d[1])]
        # 将日志数据格式化为字符串
        log = ' | '.join(map(lambda n: f'{n[0]}: {n[1]:.2f}', data))
        # 打印日志
        print(log)

    # 返回模型���件名
    def model_name(self, num):
        return str(self.models_dir / self.name / f'model_{num}.pt')

    # 初始化文件夹
    def init_folders(self):
        # 创建结果文件夹和模型文件夹
        (self.results_dir / self.name).mkdir(parents=True, exist_ok=True)
        (self.models_dir / self.name).mkdir(parents=True, exist_ok=True)

    # 清空文件夹
    def clear(self):
        # 删除模型文件夹、结果文件夹、FID 文件夹和配置文件夹
        rmtree(str(self.models_dir / self.name), True)
        rmtree(str(self.results_dir / self.name), True)
        rmtree(str(self.fid_dir), True)
        rmtree(str(self.config_path), True)
        # 初始化文件夹
        self.init_folders()

    # 保存模型
    def save(self, num):
        # 保存模型数据
        save_data = {
            'GAN': self.GAN.state_dict(),
            'version': __version__,
            'G_scaler': self.G_scaler.state_dict(),
            'D_scaler': self.D_scaler.state_dict()
        }

        # 将数据保存到文件
        torch.save(save_data, self.model_name(num))
        # 写入配置文件
        self.write_config()

    # 加载模型
    def load(self, num=-1, print_version=True):
        # 加载配置文件
        self.load_config()

        name = num
        if num == -1:
            checkpoints = self.get_checkpoints()

            if not exists(checkpoints):
                return

            name = checkpoints[-1]
            print(f'continuing from previous epoch - {name}')

        self.steps = name * self.save_every

        load_data = torch.load(self.model_name(name))

        if print_version and 'version' in load_data and self.is_main:
            print(f"loading from version {load_data['version']}")

        try:
            self.GAN.load_state_dict(load_data['GAN'])
        except Exception as e:
            print('unable to load save model. please try downgrading the package to the version specified by the saved model')
            raise e

        if 'G_scaler' in load_data:
            self.G_scaler.load_state_dict(load_data['G_scaler'])
        if 'D_scaler' in load_data:
            self.D_scaler.load_state_dict(load_data['D_scaler'])
    # 获取所有检查点文件的路径列表
    def get_checkpoints(self):
        # 使用列表推导式获取所有以'model_'开头的文件路径
        file_paths = [p for p in Path(self.models_dir / self.name).glob('model_*.pt')]
        # 使用map函数和lambda表达式将文件路径转换为对应的数字编号,并按编号排序
        saved_nums = sorted(map(lambda x: int(x.stem.split('_')[1]), file_paths))

        # 如果没有找到任何检查点文件,则返回None
        if len(saved_nums) == 0:
            return None

        # 返回排序后的检查点编号列表
        return saved_nums

.\lucidrains\transganformer\transganformer\version.py

# 定义当前代码的版本号为 '0.0.17'
__version__ = '0.0.17'

.\lucidrains\transganformer\transganformer\__init__.py

# 从 transganformer.transganformer 模块中导入 Transganformer, Generator, Discriminator, Trainer, NanException 类
from transganformer.transganformer import Transganformer, Generator, Discriminator, Trainer, NanException

Triangle Multiplicative Module - Pytorch

Implementation of the Triangle Multiplicative module, used in Alphafold2 as an efficient way to mix rows or columns of a 2d feature map, as a standalone package for Pytorch

Install

$ pip install triangle-multiplicative-module

Usage

import torch
from triangle_multiplicative_module import TriangleMultiplicativeModule

model = TriangleMultiplicativeModule(
    dim = 64,            # feature map dimension
    hidden_dim = 128,    # intermediate dimension size
    mix = 'outgoing'     # either 'ingoing' or 'outgoing'
)

fmap = torch.randn(1, 256, 256, 64)
mask = torch.ones(1, 256, 256).bool()

model(fmap, mask = mask) # (1, 256, 256, 64)

Citations

@Article{AlphaFold2021,
    author  = {Jumper, John and Evans, Richard and Pritzel, Alexander and Green, Tim and Figurnov, Michael and Ronneberger, Olaf and Tunyasuvunakool, Kathryn and Bates, Russ and {\v{Z}}{\'\i}dek, Augustin and Potapenko, Anna and Bridgland, Alex and Meyer, Clemens and Kohl, Simon A A and Ballard, Andrew J and Cowie, Andrew and Romera-Paredes, Bernardino and Nikolov, Stanislav and Jain, Rishub and Adler, Jonas and Back, Trevor and Petersen, Stig and Reiman, David and Clancy, Ellen and Zielinski, Michal and Steinegger, Martin and Pacholska, Michalina and Berghammer, Tamas and Bodenstein, Sebastian and Silver, David and Vinyals, Oriol and Senior, Andrew W and Kavukcuoglu, Koray and Kohli, Pushmeet and Hassabis, Demis},
    journal = {Nature},
    title   = {Highly accurate protein structure prediction with {AlphaFold}},
    year    = {2021},
    doi     = {10.1038/s41586-021-03819-2},
    note    = {(Accelerated article preview)},
}

.\lucidrains\triangle-multiplicative-module\setup.py

# 导入设置工具和查找包的函数
from setuptools import setup, find_packages

# 设置包的元数据
setup(
  name = 'triangle-multiplicative-module',  # 包名
  packages = find_packages(),  # 查找所有包
  version = '0.0.3',  # 版本号
  license='MIT',  # 许可证
  description = 'Triangle Multiplicative Module',  # 描述
  author = 'Phil Wang',  # 作者
  author_email = 'lucidrains@gmail.com',  # 作者邮箱
  url = 'https://github.com/lucidrains/triangle-multiplicative-module',  # 项目链接
  keywords = [  # 关键词列表
    'artificial intelligence',
    'deep learning',
    'protein folding'
  ],
  install_requires=[  # 安装依赖
    'einops>=0.3',
    'torch>=1.7'
  ],
  setup_requires=[  # 设置依赖
    'pytest-runner',
  ],
  tests_require=[  # 测试依赖
    'pytest'
  ],
  classifiers=[  # 分类器
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\triangle-multiplicative-module\triangle_multiplicative_module\triangle_multiplicative_module.py

# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块和 einsum 函数
from torch import nn, einsum
# 从 torch 库中导入 nn.functional 模块,并重命名为 F
import torch.nn.functional as F
# 从 einops 库中导入 rearrange 函数
from einops import rearrange

# 定义辅助函数

# 判断值是否存在
def exists(val):
    return val is not None

# 如果值存在则返回该值,否则返回默认值
def default(val, d):
    return val if exists(val) else d

# 定义类

# 三角形乘法模块类
class TriangleMultiplicativeModule(nn.Module):
    def __init__(
        self,
        *,
        dim,
        hidden_dim = None,
        mix = 'ingoing'
    ):
        super().__init__()
        # 断言 mix 参数只能为 'ingoing' 或 'outgoing'
        assert mix in {'ingoing', 'outgoing'}, 'mix must be either ingoing or outgoing'

        # 如果 hidden_dim 不存在,则设为 dim
        hidden_dim = default(hidden_dim, dim)
        # 对输入进行 LayerNorm 归一化
        self.norm = nn.LayerNorm(dim)

        # 左投影层
        self.left_proj = nn.Linear(dim, hidden_dim)
        # 右投影层
        self.right_proj = nn.Linear(dim, hidden_dim)

        # 左门控层
        self.left_gate = nn.Linear(dim, hidden_dim)
        # 右门控层
        self.right_gate = nn.Linear(dim, hidden_dim)
        # 输出门控层
        self.out_gate = nn.Linear(dim, hidden_dim)

        # 初始化所有门控层的权重为 0,偏置为 1
        for gate in (self.left_gate, self.right_gate, self.out_gate):
            nn.init.constant_(gate.weight, 0.)
            nn.init.constant_(gate.bias, 1.)

        # 根据 mix 参数确定 einsum 公式
        if mix == 'outgoing':
            self.mix_einsum_eq = '... i k d, ... j k d -> ... i j d'
        elif mix == 'ingoing':
            self.mix_einsum_eq = '... k j d, ... k i d -> ... i j d'

        # 输出层归一化
        self.to_out_norm = nn.LayerNorm(hidden_dim)
        # 输出层线性变换
        self.to_out = nn.Linear(hidden_dim, dim)

    def forward(self, x, mask = None):
        # 断言输入特征图必须是对称的
        assert x.shape[1] == x.shape[2], 'feature map must be symmetrical'
        # 如果 mask 存在,则重组 mask 的维度
        if exists(mask):
            mask = rearrange(mask, 'b i j -> b i j ()')

        # 对输入进行归一化
        x = self.norm(x)

        # 左投影
        left = self.left_proj(x)
        # 右投影
        right = self.right_proj(x)

        # 如果 mask 存在,则对左右投影进行 mask 处理
        if exists(mask):
            left = left * mask
            right = right * mask

        # 计算左门控值
        left_gate = self.left_gate(x).sigmoid()
        # 计算右门控值
        right_gate = self.right_gate(x).sigmoid()
        # 计算输出门控值
        out_gate = self.out_gate(x).sigmoid()

        # 左投影乘以左门控值
        left = left * left_gate
        # 右投影乘以右门控值
        right = right * right_gate

        # 执行 einsum 运算,根据 mix_einsum_eq 公式计算输出
        out = einsum(self.mix_einsum_eq, left, right)

        # 对输出进行归一化
        out = self.to_out_norm(out)
        # 输出乘以输出门控值
        out = out * out_gate
        # 返回输出结果
        return self.to_out(out)

.\lucidrains\triangle-multiplicative-module\triangle_multiplicative_module\__init__.py

# 从triangle_multiplicative_module.triangle_multiplicative_module模块中导入TriangleMultiplicativeModule类
from triangle_multiplicative_module.triangle_multiplicative_module import TriangleMultiplicativeModule

.\lucidrains\triton-transformer\assert.py

# 导入 PyTorch 库
import torch
# 从 triton_transformer 模块中导入 Transformer 类
from triton_transformer import Transformer

# 检查是否有可用的 CUDA 设备
assert torch.cuda.is_available()

# 实例化模型和数据

# 创建 Transformer 模型对象,设置参数:标记数量为 256,最大序列长度为 1024,维度为 512,深度为 6,头数为 8,头维度为 64,使用因果性,不使用 Triton
model = Transformer(
    num_tokens = 256,
    max_seq_len = 1024,
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 64,
    causal = True,
    use_triton = False
).cuda()

# 生成一个大小为 (1, 1024) 的张量,填充随机整数,放在 CUDA 设备上
x = torch.randint(0, 256, (1, 1024)).cuda()
# 生成一个大小为 (1, 1024) 的张量,填充随机整数,放在 CUDA 设备上
labels = torch.randint(0, 256, (1, 1024)).cuda()

# 无 Triton 的前向传播和反向传播

# 计算模型输出和损失
loss = model(x, labels = labels)
# 反向传播计算梯度
loss.backward()

# 复制损失值
loss = loss.clone()
# 复制 token embeddings 的梯度
emb_grad = model.token_emb.weight.grad.clone()
# 复制 LayerNorm 层的权重梯度
ln_weight_grad = model.norm.weight.grad.clone()
# 复制 LayerNorm 层的偏置梯度
ln_bias_grad = model.norm.bias.grad.clone()

# 清零所有梯度
model.zero_grad()

# Triton 的前向传播和反向传播

# 使用 Triton 进行前向传播和反向传播
triton_loss = model(x, labels = labels, use_triton = True)
# Triton 反向传播计算梯度
triton_loss.backward()

# 复制 Triton 下的 token embeddings 的梯度
triton_emb_grad = model.token_emb.weight.grad.clone()
# 复制 Triton 下的 LayerNorm 层的权重梯度
triton_ln_weight_grad = model.norm.weight.grad.clone()
# 复制 Triton 下的 LayerNorm 层的偏置梯度
triton_ln_bias_grad = model.norm.bias.grad.clone()

# 应该相等,对输出和 token embeddings 的梯度进行检查

# 检查输出是否相等
assert torch.allclose(loss.cpu(), triton_loss.cpu(), atol=1e-6), 'output is the same'
# 检查 token embeddings 的梯度是否相等
assert torch.allclose(emb_grad.cpu(), triton_emb_grad.cpu(), atol=2e-6), 'grad is the same'
# 检查 LayerNorm 层的权重梯度是否相等
assert torch.allclose(ln_weight_grad.cpu(), triton_ln_weight_grad.cpu(), atol=2e-6), 'layernorm weight grad is the same'
# 检查 LayerNorm 层的偏置梯度是否相等
assert torch.allclose(ln_bias_grad.cpu(), triton_ln_bias_grad.cpu(), atol=2e-6), 'layernorm bias grad is the same'

# 打印成功信息
print('succeeded')

Data source

The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/

Transformer in Triton (wip)

Implementation of a Transformer, but completely in Triton. I'm completely new to lower-level neural net code, so this repository will mostly be a learning experience, with the end-goal being a vanilla transformer that is faster and more efficient to train.

Results

Layernorm forward

Layernorm forwards and backwards

Softmax forwards and backwards

Install

$ pip install triton-transformer

Usage

import torch
from triton_transformer import Transformer

model = Transformer(
    num_tokens = 256,       # vocab size
    max_seq_len = 1024,     # maximum sequence length
    dim = 512,              # dimension
    depth = 6,              # depth
    heads = 8,              # number of heads
    dim_head = 64,          # dimension per head
    causal = True,          # autoregressive or not
    attn_dropout = 0.1,     # attention dropout
    ff_dropout = 0.1,       # feedforward dropout
    use_triton = True       # use this to turn on / off triton
).cuda()

x = torch.randint(0, 256, (1, 1024)).cuda()
logits = model(x) # (1, 1024, 256)

To train, just pass in the labels with the keyword labels on forward, and the cross entropy loss will be returned for backprop.

ex. BERT

import torch
from triton_transformer import Transformer

model = Transformer(
    num_tokens = 20000,
    max_seq_len = 512,
    dim = 512,
    depth = 12,
    heads = 8,
    dim_head = 64,
    use_triton = True
).cuda()

x = torch.randint(0, 20000, (1, 512)).cuda()
labels = torch.randint(0, 20000, (1, 512)).cuda()
mask = torch.ones(1, 512).bool().cuda()

loss = model(x, mask = mask, labels = labels)
loss.backward()

Test - GPT training

$ python train.py

Todo

Citations

@article{Tillet2019TritonAI,
    title   = {Triton: an intermediate language and compiler for tiled neural network computations},
    author  = {Philippe Tillet and H. Kung and D. Cox},
    journal = {Proceedings of the 3rd ACM SIGPLAN International Workshop on Machine Learning and Programming Languages},
    year    = {2019}
}
@misc{vaswani2017attention,
    title   = {Attention Is All You Need}, 
    author  = {Ashish Vaswani and Noam Shazeer and Niki Parmar and Jakob Uszkoreit and Llion Jones and Aidan N. Gomez and Lukasz Kaiser and Illia Polosukhin},
    year    = {2017},
    eprint  = {1706.03762},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@misc{so2021primer,
    title   = {Primer: Searching for Efficient Transformers for Language Modeling},
    author  = {David R. So and Wojciech Mańke and Hanxiao Liu and Zihang Dai and Noam Shazeer and Quoc V. Le},
    year    = {2021},
    eprint  = {2109.08668},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@article{chowdhery2022PaLM,
  title   = {PaLM: Scaling Language Modeling with Pathways},
  author  = {Chowdhery, Aakanksha et al},
  year    = {2022}
}

.\lucidrains\triton-transformer\setup.py

# 导入设置和查找包的函数
from setuptools import setup, find_packages

# 设置包的元数据
setup(
  name = 'triton-transformer', # 包的名称
  packages = find_packages(), # 查找所有包
  version = '0.1.1', # 版本号
  license='MIT', # 许可证
  description = 'Transformer in Triton', # 描述
  author = 'Phil Wang', # 作者
  author_email = 'lucidrains@gmail.com', # 作者邮箱
  url = 'https://github.com/lucidrains/triton-transformer', # 项目链接
  keywords = [
    'artificial intelligence', # 关键词
    'attention mechanism', # 关键词
    'transformers' # 关键词
  ],
  install_requires=[
    'einops', # 安装所需的依赖包
    'torch>=1.6', # 安装所需的依赖包
    'triton==1.0.1.dev20210924' # 安装所需的依赖包
  ],
  classifiers=[
    'Development Status :: 4 - Beta', # 分类器
    'Intended Audience :: Developers', # 分类器
    'Topic :: Scientific/Engineering :: Artificial Intelligence', # 分类器
    'License :: OSI Approved :: MIT License', # 分类器
    'Programming Language :: Python :: 3.6', # 分类器
  ],
)

.\lucidrains\triton-transformer\train.py

# 导入所需的库
from triton_transformer import Transformer
from triton_transformer.autoregressive_wrapper import AutoregressiveWrapper

import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

# 定义常量
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 2e-4
VALIDATE_EVERY  = 100
GENERATE_EVERY  = 500
GENERATE_LENGTH = 512
SEQ_LEN = 512

# 定义辅助函数

# 从 token 解码为字符
def decode_token(token):
    return str(chr(max(32, token)))

# 从 tokens 解码为字符串
def decode_tokens(tokens):
    return ''.join(list(map(decode_token, tokens)))

# 实例化类似 GPT 的解码器模型
model = Transformer(
    num_tokens = 256,
    dim = 512,
    max_seq_len = SEQ_LEN,
    depth = 8,
    heads = 8,
    causal = True,
    use_triton = True,
    attn_dropout = 0.1,
    ff_dropout = 0.1,
)

model = AutoregressiveWrapper(model)
model.cuda()

# 准备 enwik8 数据
with gzip.open('./data/enwik8.gz') as file:
    X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)

# 定义文本采样数据集类
class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

# 创建训练集和验证集的数据加载器
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset   = TextSamplerDataset(data_val, SEQ_LEN)
train_loader  = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader    = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

# 定义优化器
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# 训练模型
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        loss.backward()

    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader))
            print(f'validation loss: {loss.item()}')

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        print(f'%s \n\n %s', (prime, '*' * 100))

        sample = model.generate(inp[None, ...], GENERATE_LENGTH)
        output_str = decode_tokens(sample[0])
        print(output_str)

.\lucidrains\triton-transformer\triton_transformer\autoregressive_wrapper.py

import torch
from torch import nn
import torch.nn.functional as F

# helper function

# 检查值是否存在的辅助函数
def exists(val):
    return val is not None

# 评估装饰器函数
def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        # 保存模型原始训练状态
        was_training = model.training
        # 将模型设置为评估模式
        model.eval()
        # 调用传入的函数
        out = fn(model, *args, **kwargs)
        # 恢复模型原始训练状态
        model.train(was_training)
        return out
    return inner

# top k filtering

# 根据阈值过滤 logits 中的 top k 值
def top_k(logits, thres = 0.9):
    # 计算 top k 的数量
    k = int((1 - thres) * logits.shape[-1])
    # 获取 top k 的值和索引
    val, ind = torch.topk(logits, k)
    # 创建与 logits 相同形状的全为负无穷的张量
    probs = torch.full_like(logits, float('-inf'))
    # 根据索引将 top k 的值填充到 probs 中
    probs.scatter_(1, ind, val)
    return probs

# 自回归包装器类
class AutoregressiveWrapper(nn.Module):
    def __init__(self, net, pad_value = 0):
        super().__init__()
        self.pad_value = pad_value
        self.net = net
        self.max_seq_len = net.max_seq_len

    # 生成序列的方法
    @torch.no_grad()
    @eval_decorator
    def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_thres = 0.9, **kwargs):
        # 获取起始 tokens 的形状和设备信息
        b, t, device = *start_tokens.shape, start_tokens.device

        out = start_tokens

        for _ in range(seq_len):
            # 获取最后 self.max_seq_len 个 token
            x = out[:, -self.max_seq_len:]

            # 获取模型预测的 logits
            logits = self.net(x, **kwargs)[:, -1, :]

            # 过滤 logits 中的 top k 值
            filtered_logits = top_k(logits, thres = filter_thres)
            # 计算 softmax 温度调节后的概率
            probs = F.softmax(filtered_logits / temperature, dim=-1)

            # 从概率分布中采样一个 token
            sample = torch.multinomial(probs, 1)

            # 将采样的 token 添加到输出序列中
            out = torch.cat((out, sample), dim=-1)

            if exists(eos_token):
                # 检查是否存在 eos_token
                is_eos_token = (out == eos_token)

                if is_eos_token.any(dim = -1).all():
                    # 如果所有序列中都存在 eos_token,则停止生成
                    # 创建一个向右移动一位�� eos_token mask
                    shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
                    # 创建一个 mask,标记 eos_token 后的所有位置
                    mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
                    # 将 mask 标记的位置填充为 pad_value
                    out = out.masked_fill(mask, self.pad_value)
                    break

        # 去除起始 tokens,返回生成的序列
        out = out[:, t:]
        return out

    # 前向传播方法
    def forward(self, x, **kwargs):
        # 将输入拆分为输入和标签
        x_inp, x_labels = x[:, :-1], x[:, 1:]
        return self.net(x_inp, labels = x_labels, **kwargs)

.\lucidrains\triton-transformer\triton_transformer\bmm.py

# 导入 torch 库
import torch
# 从 torch 库中导入 autograd 模块
from torch import autograd
# 从 torch.nn.functional 模块中导入 F 函数
import torch.nn.functional as F

# 从 triton_transformer.utils 模块中导入 calc_num_warps 和 exists 函数
from triton_transformer.utils import calc_num_warps, exists

# 导入 triton 库
import triton
# 从 triton.language 模块中导入 tl
import triton.language as tl

# 使用 triton.autotune 装饰器,配置自动调优参数
@triton.autotune(
    configs=[
        # 配置不同的参数组合
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64 , 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64 , 'BLOCK_SIZE_N': 32 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 32 , 'BLOCK_SIZE_N': 64 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
    ],
    key=['M', 'N', 'K'],
)
# 使用 triton.jit 装饰器,编译 bmm_kernel 函数
@triton.jit
def bmm_kernel(
    x_ptr, y_ptr, o_ptr,
    M, N, K,
    stride_al, stride_am, stride_ak,
    stride_bl, stride_bk, stride_bn,
    stride_ol, stride_om, stride_on,
    **meta,
):
    # 定义常量
    BLOCK_SIZE_M = meta['BLOCK_SIZE_M']
    BLOCK_SIZE_N = meta['BLOCK_SIZE_N']
    BLOCK_SIZE_K = meta['BLOCK_SIZE_K']
    GROUP_SIZE_M = 8

    # 计算程序 ID
    pid_batch = tl.program_id(0)
    pid = tl.program_id(1)

    # 计算分组数量
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # 计算偏移量
    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    x_ptrs = x_ptr + (offs_am[:, None]*stride_am + offs_k [None, :]*stride_ak + pid_batch*stride_al)
    y_ptrs = y_ptr + (offs_k [:, None]*stride_bk + offs_bn[None, :]*stride_bn + pid_batch*stride_bl)

    # 初始化输出矩阵 o
    o = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    # 循环计算矩阵乘法
    for k in range(0, K, BLOCK_SIZE_K):
        x = tl.load(x_ptrs)
        y = tl.load(y_ptrs)
        o += tl.dot(x, y)

        x_ptrs += BLOCK_SIZE_K * stride_ak
        y_ptrs += BLOCK_SIZE_K * stride_bk

    # 如果存在激活函数,则应用激活函数
    if exists(meta['ACTIVATION']):
        o = meta['ACTIVATION'](o)

    # 计算偏移量
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    # 创建掩码
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)

    # 计算输出指针
    o_ptrs = o_ptr + stride_om * offs_m[:, None] + stride_on * offs_n[None, :] + stride_ol * pid_batch
    # 存储结果到输出指针
    tl.store(o_ptrs, o, mask=mask)

# 定义 triton_bmm 函数
def triton_bmm(x, y, activation = None):
    # 获取 x 的形状信息
    B, M, K = x.shape

    # 如果 y 的维度为 2,则扩展维度
    if y.ndim == 2:
        y = y.unsqueeze(0).expand(B, -1, -1)

    # 获取 y 的形状信息
    _, K, N = y.shape
    # 断言 K 必须能被 32 整除
    assert (K % 32 == 0), "K must be divisible by 32"

    # 创建输出张量 o
    o = torch.empty((B, M, N), device = x.device, dtype = x.dtype)

    # 定义 grid 函数
    grid = lambda META: (
        B, triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
    )

    # 调用 bmm_kernel 函数
    bmm_kernel[grid](
        x, y, o,
        M, N, K,
        x.stride(0), x.stride(1), x.stride(2),
        y.stride(0), y.stride(1), y.stride(2),
        o.stride(0), o.stride(1), o.stride(2),
        ACTIVATION = activation
    )
    # 返回结果张量 o
    return o

# 使用 triton.jit 装饰器,编译 relu_squared_activation 函数
@triton.jit
def relu_squared_activation(x):
    return tl.where(x > 0, x * x, 0.)

# 定义 _relu_squared 类
class _relu_squared(autograd.Function):
    # 前向传播函数
    @classmethod
    def forward(self, ctx, x, w):
        # 调用 triton_bmm 函数,应用 relu_squared_activation 激活函数
        o = triton_bmm(x, w, activation = relu_squared_activation)
        # 如果 x 需要梯度,则保存相关信息
        if x.requires_grad:
            ctx.save_for_backward(x, w, o)
        return o

    @classmethod
    # 反向传播函数,接收上下文和梯度作为输入
    def backward(self, ctx, dy):
        # 从上下文中获取保存的张量 x, w, o
        x, w, o = ctx.saved_tensors
        # 计算 dy 乘以 o 的平方根乘以 2,得到新的梯度 dy
        dy = torch.sqrt(o) * 2 * dy
        # 计算 dy 与权重 w 的转置的矩阵乘积,得到输入 x 的梯度 dx
        dx = triton_bmm(dy, w.t())
        # 计算输入 x 的转置与梯度 dy 的矩阵乘积,得到权重 w 的梯度 dw
        dw = triton_bmm(x.transpose(-1, -2), dy)
        # 返回输入 x 和权重 w 的梯度
        return dx, dw
# 将 _relu_squared.apply 赋值给 triton_relu_squared,用于后续调用
triton_relu_squared = _relu_squared.apply

# 定义一个融合了 ReLU 和平方操作的函数
def fused_relu_squared(x, w, use_triton = False):
    # 如果 use_triton 为 True,则调用 triton_relu_squared 函数
    if use_triton:
        return triton_relu_squared(x, w)

    # 如果 use_triton 为 False,则计算 x @ w 的矩阵乘法结果,然后对结果进行 ReLU 和平方操作
    return F.relu(x @ w) ** 2
posted @ 2024-06-28 14:13  绝不原创的飞龙  阅读(9)  评论(0编辑  收藏  举报