Lucidrains Series Source Code Walkthrough - Part 23

Lucidrains Series Source Code Walkthrough (Part 23)

.\lucidrains\lion-pytorch\setup.py

# import the setup and find_packages functions
from setuptools import setup, find_packages

# define the package metadata
setup(
  name = 'lion-pytorch',  # package name
  packages = find_packages(exclude=[]),  # find all packages
  version = '0.1.2',  # version number
  license='MIT',  # license
  description = 'Lion Optimizer - Pytorch',  # description
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  long_description_content_type = 'text/markdown',  # long description content type
  url = 'https://github.com/lucidrains/lion-pytorch',  # URL
  keywords = [
    'artificial intelligence',  # keyword: artificial intelligence
    'deep learning',  # keyword: deep learning
    'optimizers'  # keyword: optimizers
  ],
  install_requires=[
    'torch>=1.6'  # required dependency
  ],
  classifiers=[
    'Development Status :: 4 - Beta',  # classifier: development status is Beta
    'Intended Audience :: Developers',  # classifier: intended audience is developers
    'Topic :: Scientific/Engineering :: Artificial Intelligence',  # classifier: scientific/engineering - AI
    'License :: OSI Approved :: MIT License',  # classifier: MIT license
    'Programming Language :: Python :: 3.6',  # classifier: Python 3.6
  ],
)

Liquid Conway Game of Life

Try online: https://lucidrains.github.io/liquid-conway/

Based on: http://www.jgallant.com/2d-liquid-simulator-with-cellular-automaton-in-unity/

$ yarn
$ npm start

In a new terminal

$ open index.html

.\lucidrains\liquid-conway\src\app.js

// import the RxJS library
import Rx from 'rxjs';
// import the helpers module
import helpers from './helpers';
// import the stylesheet
require('./app.sass');

// destructure the range, initArray and cacheFn helpers
const {
  range,
  initArray,
  cacheFn
} = helpers;

// grab the canvas element
const c = document.getElementById('canvas');

// constants
const FPS = 30;
const WIDTH = 150;
const HEIGHT = 75;
const CELL_SIZE = 7;
const CANVAS_WIDTH = WIDTH * CELL_SIZE;
const CANVAS_HEIGHT = HEIGHT * CELL_SIZE;
const CELL_FILL_STYLE = 'rgb(22, 109, 175)';
const BACKGROUND_COLOR = 'rgba(255, 255, 255, 0.5)';
const NEIGHBOR_COORS_CACHE = {};

// the eight neighbor offsets: every [x, y] pair in -1..1, excluding [0, 0]
const DIR = range(-1, 1)
  .reduce((acc, x) => acc.concat(range(-1, 1).map(y => [x, y])), [])
  .filter(([x, y]) => !(x === 0 && y === 0));

// set the canvas width and height
c.setAttribute('width', CANVAS_WIDTH.toString());
c.setAttribute('height', CANVAS_HEIGHT.toString());
c.style.display = 'block';
// get the 2D drawing context
const ctx = c.getContext('2d');

// initialize a grid as a 2D array of x columns by y rows
function initGrid(x, y, init) {
  return initArray(x, init).map(() => initArray(y, init));
}

// initialize the grid and a double buffer
let [
  grid,
  buffer
] = [
  initGrid(WIDTH, HEIGHT, 0),
  initGrid(WIDTH, HEIGHT, 0)
];

// precompute every grid coordinate
const GRID_COORS = grid.reduce((acc, row, x) => {
  acc = acc.concat(row.map((_, y) => [x, y]));
  return acc;
}, []);

// randomly seed the initial state of the grid
GRID_COORS.forEach(([x, y]) => {
  grid[x][y] = Math.round(Math.random());
});

// RxJS observable: while the mouse button is held down, paint live cells under the cursor
Rx.Observable
  .fromEvent(c, 'mousedown')
  .flatMap((md) => {
    md.preventDefault();
    let ev = md;

    return Rx.Observable.merge(
        Rx.Observable.interval(10).map(() => null),
        Rx.Observable.fromEvent(c, 'mousemove')
      )
      .map((mm) => {
        ev = mm || ev;
        const { left, top } = ev.target.getBoundingClientRect();
        const x = ev.clientX - left;
        const y = ev.clientY - top;
        const [coorX, coorY] = [x, y].map(el => Math.floor(el / CELL_SIZE));
        return [coorX, coorY];
      })
      .takeUntil(Rx.Observable.fromEvent(c, 'mouseup'));
  })
  .throttleTime(10)
  .subscribe(([x, y]) => {
    grid[x][y] = 1;
  });

// check whether a coordinate lies inside the grid
function withinBounds(grid, x, y) {
  return x >= 0 && x < grid.length && y >= 0 && y < grid[0].length;
}

// collect the in-bounds neighbor coordinates of a cell
function getNeighborCoors(grid, x, y) {
  return DIR.reduce((acc, [dx, dy]) => {
    const [nx, ny] = [dx + x, dy + y];
    if (withinBounds(grid, nx, ny)) {
      acc.push([nx, ny]);
    }
    return acc;
  }, []);
}

// memoized version of getNeighborCoors, keyed by coordinate
const getCacheNeighborCoors = cacheFn(
  getNeighborCoors,
  NEIGHBOR_COORS_CACHE,
  (_, x, y) => `${x}:${y}`
);

// count how many of a cell's neighbors are alive
function countNeighborsAlive(grid, x, y) {
  const neighbors = getCacheNeighborCoors(grid, x, y);

  return neighbors.reduce((acc, [nx, ny]) => {
    if (grid[nx][ny] === 1) {
      acc += 1;
    }
    return acc;
  }, 0);
}

// Conway's rule: a live cell with two neighbors stays alive, and any cell with exactly three neighbors is alive
function computeNextState(curr, neighbors) {
  return ((curr === 1 && neighbors === 2) || neighbors === 3) ? 1 : 0;
}

// compute the next generation into the buffer
function nextState(grid, buffer) {
  GRID_COORS.forEach(([x, y]) => {
    const cell = grid[x][y];
    const count = countNeighborsAlive(grid, x, y);
    buffer[x][y] = computeNextState(cell, count);
  });
}

// render the grid onto the canvas
function render(ctx, grid) {
  ctx.fillStyle = BACKGROUND_COLOR;
  ctx.fillRect(0, 0, CANVAS_WIDTH, CANVAS_HEIGHT);

  GRID_COORS.forEach(([x, y]) => {
    const cell = grid[x][y];
    if (cell === 1) {
      ctx.fillStyle = CELL_FILL_STYLE;
      ctx.fillRect(
        (x * CELL_SIZE) + 1,
        (y * CELL_SIZE) + 1,
        CELL_SIZE - 1,
        CELL_SIZE - 1
      );
    }
  });
}

// animation loop, throttled to roughly FPS frames per second
let start;
const throttleDiff = (1000 / FPS);

function step() {
  const now = +new Date();
  start = start || now;
  const diff = now - start;
  start = now;

  render(ctx, grid);

  const callNextFrame = window.requestAnimationFrame.bind(null, step);
  if (diff > throttleDiff) {
    callNextFrame();
  } else {
    setTimeout(callNextFrame, throttleDiff - diff);
  }
}

// start the animation
step();

// advance the simulation every 80ms, swapping grid and buffer
setInterval(() => {
  nextState(grid, buffer);
  [buffer, grid] = [grid, buffer];
}, 80);

.\lucidrains\liquid-conway\src\helpers.js

// import the clone function from the lodash.clone module
import clone from 'lodash.clone';

// initialize an array of num elements, each a clone of init
function initArray(num, init) {
    return Array.from(Array(num)).map(() => clone(init));
}

// build an inclusive range array from low to high, with an optional step (default 1)
function range(low, high, step = 1) {
    const arr = [];
    for (let i = low; i <= high; i += step) {
        arr.push(i);
    }
    return arr;
}

// memoize a function's results in cacheObj, using deriveKeyFn (or the JSON of the args) to build the cache key
function cacheFn(fn, cacheObj, deriveKeyFn) {
    return (...args) => {
        let key;
        if (!deriveKeyFn) {
            key = JSON.stringify(args);
        } else {
            key = deriveKeyFn(...args);
        }

        if (cacheObj[key] !== undefined) {
            return cacheObj[key];
        }

        const ret = fn(...args);
        cacheObj[key] = ret;
        return ret;
    };
}

// return a random integer between 0 and num (inclusive)
function randInt(num) {
    return Math.floor(Math.random() * (num + 1));
}

// export the cacheFn, range, initArray and randInt helpers
export default {
    cacheFn,
    range,
    initArray,
    randInt
};

.\lucidrains\liquid-conway\src\liquid.js

// import the Rx module
import Rx from 'rxjs';
// import the helpers module
import helpers from './helpers';

// import the stylesheet
require('./app.sass');

// pull the initArray helper out of the helpers module
const { initArray } = helpers;

// grab the canvas element
const c = document.getElementById('canvas');

// prevent the right-click context menu from appearing
c.oncontextmenu = (e) => {
  e.preventDefault();
};

// constants
const FPS = 30;
const WIDTH = 80;
const HEIGHT = 60;
const CELL_SIZE = 10;
const CANVAS_WIDTH = WIDTH * CELL_SIZE;
const CANVAS_HEIGHT = HEIGHT * CELL_SIZE;

const CELL_COLOR_LIGHTEST = 'rgb(0, 204, 255)';
const CELL_COLOR_LIGHT = 'rgb(0, 153, 255)';
const CELL_COLOR = 'rgb(0, 102, 255)';
const CELL_COLOR_DARK = 'rgb(51, 102, 255)';
const CELL_COLOR_DARKEST = 'rgb(51, 51, 204)';

const BACKGROUND_COLOR = 'rgb(255, 255, 255)';

// initialize a grid as a 2D array of x columns by y rows
function initGrid(x, y, init) {
  return initArray(x, init).map(() => initArray(y, init));
}

// create the grid, each cell holding a liquid amount `val` and a pending change `diff`
const GRID = initGrid(WIDTH, HEIGHT, { val: 0, diff: 0 });

// precompute every grid coordinate
const GRID_COORS = GRID.reduce((acc, row, x) =>
  acc.concat(row.map((_, y) => [x, y]))
, []);

// check whether a coordinate lies inside the grid
function withinBounds(grid, x, y) {
  return x >= 0 && x < grid.length && y >= 0 && y < grid[0].length;
}

// check whether a cell is inside the grid and not a wall
function isEmptyCell(grid, x, y) {
  return withinBounds(grid, x, y) && !grid[x][y].wall;
}

// set the canvas width and height
c.setAttribute('width', CANVAS_WIDTH.toString());
c.setAttribute('height', CANVAS_HEIGHT.toString());
c.style.display = 'block';

// get the 2D context
const ctx = c.getContext('2d');

// merge the mouse and touch event streams into one observable
Rx.Observable.merge(
    Rx.Observable.fromEvent(c, 'mousedown'),
    Rx.Observable.fromEvent(c, 'touchstart')
  )
  .flatMap((md) => {
    md.preventDefault();
    let ev = md;

    return Rx.Observable.merge(
        Rx.Observable.interval(10).map(() => null),
        Rx.Observable.fromEvent(c, 'mousemove'),
        Rx.Observable.fromEvent(c, 'touchmove')
      )
      .map((mm) => {
        ev = mm || ev;
        return { ev, which: md.which };
      })
      .takeUntil(Rx.Observable.merge(
        Rx.Observable.fromEvent(c, 'mouseup'),
        Rx.Observable.fromEvent(c, 'mouseout'),
        Rx.Observable.fromEvent(c, 'touchend')
      ));
  })
  .throttleTime(10)
  .subscribe(({ ev, which }) => {
    const { target, touches, type } = ev;
    const isTouch = type === 'touchmove' || type === 'touchstart';

    const { left, top } = target.getBoundingClientRect();
    const { clientX, clientY } = isTouch ? touches[0] : ev;

    const x = clientX - left;
    const y = clientY - top;
    const [cx, cy] = [x, y].map(el => Math.floor(el / CELL_SIZE));

    if (!withinBounds(GRID, cx, cy)) {
      return;
    }

    const cell = GRID[cx][cy];

    // left click (or touch) adds liquid, right click places a wall
    if (which === 1 || isTouch) {
      delete cell.wall;
      cell.val += 100;
    } else if (which === 3) {
      cell.wall = true;
      cell.val = 0;
    }
  });

// compute the next state of the liquid simulation
function nextState(grid) {
  const withinGrid = withinBounds.bind(null, grid);

  GRID_COORS.forEach(([x, y]) => {
    const cell = grid[x][y];
    const val = cell.val;

    // walls and empty cells do nothing
    if (cell.wall || val < 0) {
      return;
    }

    // if the cell below is in bounds and not yet full, let everything fall straight down
    if (withinGrid(x, y + 1) && grid[x][y + 1].val < 100) {
      cell.diff -= val;
      grid[x][y + 1].diff += val;
      return;
    }

    let volume = val;

    // otherwise flow sideways, towards whichever neighbors hold less liquid
    const flowCoors = [[1, 0], [-1, 0]]
      .filter(([dx, dy]) => {
        const [nx, ny] = [x + dx, y + dy];
        return withinGrid(nx, ny) && val > grid[nx][ny].val;
      });

    const diffs = flowCoors.map(([dx, dy]) => {
      const [nx, ny] = [x + dx, y + dy];
      const diff = val - grid[nx][ny].val;
      return diff;
    });

    const totalDiff = diffs.reduce((acc, diff) => {
      acc += diff;
      return acc;
    }, 0);

    const finalDiff = Math.min(volume, totalDiff);

    // distribute the outflow to each side, weighted by how much lower the neighbor is
    diffs.forEach((diff, i) => {
      const [dx, dy] = flowCoors[i];
      const weightedDiff = Math.floor(finalDiff * (diff / totalDiff)) / 2;

      grid[x][y].diff -= weightedDiff;
      grid[x + dx][y + dy].diff += weightedDiff;
      volume -= weightedDiff;
    });

    if (volume < 0) {
      return;
    }
    // if the cell above is in bounds, holds less than this cell, and this cell is over capacity (> 100)
    if (withinGrid(x, y - 1) && grid[x][y - 1].val < cell.val && cell.val > 100) {
      // push a fraction of the difference upward (pressure)
      const diff = Math.floor((val - grid[x][y - 1].val) / 20);
      grid[x][y - 1].diff += diff;
      // subtract what was given away
      cell.diff -= diff;
      // update the remaining volume
      volume -= diff;
    }

    // if the cell below is in bounds and holds less than this cell
    if (withinGrid(x, y + 1) && grid[x][y + 1].val < cell.val) {
      // push a fraction of the difference downward
      const diff = Math.floor((val - grid[x][y + 1].val) / 10);
      grid[x][y + 1].diff += diff;
      // subtract what was given away
      cell.diff -= diff;
      // update the remaining volume
      volume -= diff;
    }
  });

  // apply the accumulated diffs
  GRID_COORS.forEach(([x, y]) => {
    // grab the current cell
    const cell = grid[x][y];
    // update the cell's value and reset its pending diff to 0
    cell.val += cell.diff;
    cell.diff = 0;
  });
}
// render function: draws the grid onto the given context
function render(context, grid) {
  // fill the whole canvas with the background color
  context.fillStyle = BACKGROUND_COLOR;
  context.fillRect(0, 0, CANVAS_WIDTH, CANVAS_HEIGHT);

  // iterate over every grid coordinate
  GRID_COORS.forEach(([x, y]) => {
    // grab the cell at this coordinate
    const cell = grid[x][y];

    // if the cell is a wall
    if (cell.wall) {
      // paint it black
      context.fillStyle = 'black';
      context.fillRect(
        (x * CELL_SIZE) + 1,
        (y * CELL_SIZE) + 1,
        CELL_SIZE,
        CELL_SIZE
      );
    } else {
      // otherwise it is a liquid cell
      const val = cell.val;

      // skip cells holding no liquid
      if (val <= 0) {
        return;
      }

      let fillStyle = CELL_COLOR;
      let cellHeight = CELL_SIZE - 1;
      let cellY = (y * CELL_SIZE) + 1;

      // check whether there is liquid (or a wall) below, and nothing above
      const hasBottomNeighbor = (!isEmptyCell(grid, x, y + 1) || grid[x][y + 1].val > 0);
      const hasNoTopNeighbor = (!isEmptyCell(grid, x, y - 1) || grid[x][y - 1].val <= 0);

      // partially filled surface cells are drawn shorter, anchored to the bottom of the cell
      if (val < 100 && hasBottomNeighbor && hasNoTopNeighbor) {
        cellHeight *= parseFloat(val) / 100;
        cellY += (CELL_SIZE - cellHeight);
      }

      // pick a shade based on how much liquid the cell holds
      if (val < 50) {
        fillStyle = CELL_COLOR_LIGHTEST;
      } else if (val < 80) {
        fillStyle = CELL_COLOR_LIGHT;
      } else if (val > 150) {
        fillStyle = CELL_COLOR_DARKEST;
      } else if (val > 120) {
        fillStyle = CELL_COLOR_DARK;
      }

      // fill the cell
      context.fillStyle = fillStyle;
      context.fillRect(
        (x * CELL_SIZE) + 1,
        cellY,
        CELL_SIZE - 1,
        cellHeight
      );
    }
  });
}

// timing variables for the throttled animation loop
let start;
const throttleDiff = (1000 / FPS);

// per-frame handler
function step() {
  const now = +new Date();
  start = start || now;
  const diff = now - start;
  start = now;

  // draw the current grid
  render(ctx, GRID);

  // schedule the next frame, throttled to roughly FPS frames per second
  const callNextFrame = window.requestAnimationFrame.bind(null, step);
  if (diff > throttleDiff) {
    callNextFrame();
  } else {
    setTimeout(callNextFrame, throttleDiff - diff);
  }
}

// start the animation
step();

// advance the simulation every 50 milliseconds
setInterval(() => {
  nextState(GRID);
}, 50);

.\lucidrains\liquid-conway\webpack.config.js

// require Node's path module and the ExtractTextPlugin
const path = require('path');
const ExtractTextPlugin = require('extract-text-webpack-plugin');

// absolute paths of the source and output directories
const src = path.resolve(__dirname, 'src');
const dist = path.resolve(__dirname, 'dist');

// configuration object: context, entries, output, module rules and plugins
const config = {
  // use the source directory as the context
  context: src,
  // two entry points: regular (the Game of Life) and liquid
  entry: {
    regular: './app.js',
    liquid: './liquid.js'
  },
  // output path and filename pattern
  output: {
    path: dist,
    filename: '[name].js'
  },
  // module rules for js, css and sass/scss files
  module: {
    rules: [{
      test: /\.js$/,
      include: src,
      use: [{
        loader: 'babel-loader',
        options: {
          presets: [
            ['es2015', { modules: false }]
          ]
        }
      }]
    }, {
      test: /\.css$/,
      use: ExtractTextPlugin.extract({
        fallback: 'style-loader',
        use: ['css-loader']
      })
    },
    {
      test: /\.*(sass|scss)$/,
      use: ExtractTextPlugin.extract({
        fallback: 'style-loader',
        use: ['css-loader', 'sass-loader']
      })
    }]
  },
  // emit the extracted styles into styles.css
  plugins: [
    new ExtractTextPlugin('styles.css')
  ]
};

// export the configuration
module.exports = config;

.\lucidrains\llama-qrlhf\llama_qrlhf\llama.py

import torch  # import the PyTorch library
from torch.nn import Module, ModuleList  # import Module and ModuleList from PyTorch
from torch import nn, einsum, Tensor  # import nn, einsum and Tensor from PyTorch
import torch.nn.functional as F  # import nn.functional under the alias F

from einops import rearrange, reduce  # import the rearrange and reduce functions from einops
from einops.layers.torch import Rearrange  # import the torch Rearrange layer from einops

# helpers

def exists(v):  # helper that checks whether a variable is not None
    return v is not None  # return True when v exists

# norm

class RMSNorm(Module):  # RMSNorm layer, subclassing Module
    def __init__(self, dim):  # constructor, takes the feature dimension
        super().__init__()  # call the parent constructor
        self.scale = dim ** 0.5  # scaling factor
        self.gamma = nn.Parameter(torch.ones(dim))  # learnable gain parameter gamma

    def forward(self, x):  # forward pass
        return F.normalize(x, dim=-1) * self.scale * self.gamma  # l2-normalize x, then apply the scale and gamma

# rotary

class RotaryEmbedding(Module):  # rotary positional embedding, subclassing Module
    def __init__(self, dim, theta=10000):  # constructor, takes the head dimension and theta (default 10000)
        super().__init__()  # call the parent constructor
        inv_freq = theta ** -(torch.arange(0, dim, 2).float() / dim)  # inverse frequencies
        self.register_buffer('inv_freq', inv_freq)  # register the inverse frequencies as a buffer

    def forward(self, seq_len, device):  # forward pass, takes the sequence length and device
        t = torch.arange(seq_len, device=device).type_as(self.inv_freq)  # position indices
        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)  # outer product of positions and inverse frequencies
        return torch.cat((freqs, freqs), dim=-1)  # duplicate along the last dimension and return

def rotate_half(x):  # split the last dimension of x into two halves and rotate them
    x1, x2 = x.chunk(2, dim=-1)  # split x into two halves along the last dimension
    return torch.cat((-x2, x1), dim=-1)  # concatenate (-x2, x1) and return

def apply_rotary_pos_emb(pos, t):  # apply the rotary positional embedding to a tensor
    return t * pos.cos() + rotate_half(t) * pos.sin()  # standard rotary formula

# feedforward

class GEGLU(Module):  # GEGLU activation, subclassing Module
    def forward(self, x):  # forward pass
        x, gate = x.chunk(2, dim=-1)  # split x into value and gate halves along the last dimension
        return F.gelu(gate) * x  # gate the value branch with GELU(gate)

def FeedForward(dim, mult=4):  # build a feedforward block
    dim_hidden = int(dim * mult * 2 / 3)  # hidden dimension
    return nn.Sequential(  # return a sequential module
        RMSNorm(dim),  # pre-normalization
        nn.Linear(dim, dim_hidden * 2),  # project up (doubled for the GEGLU gate)
        GEGLU(),  # gated GELU
        nn.Linear(dim_hidden, dim)  # project back down
    )
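
# note: the 2/3 factor on the hidden dimension is the usual adjustment for gated (GEGLU / SwiGLU style)
# feedforwards, keeping the parameter count roughly in line with a plain 4x-expansion MLP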

# attention

class Attention(Module):  # attention block, subclassing Module
    def __init__(  # constructor, takes the model dimension and keyword arguments
        self,
        dim,
        *,
        dim_head=64,
        heads=8
    ):
        super().__init__()  # call the parent constructor
        self.scale = dim_head ** -0.5  # attention scaling factor
        dim_hidden = dim_head * heads  # inner dimension

        self.to_qkv = nn.Sequential(  # project the input to queries, keys and values
            RMSNorm(dim),  # pre-normalization
            nn.Linear(dim, dim_hidden * 3, bias=False),  # fused qkv projection
            Rearrange('b n (qkv h d) -> qkv b h n d', h=heads, qkv=3)  # split into q, k, v and separate the heads
        )

        self.to_out = nn.Sequential(  # merge the heads and project back out
            Rearrange('b h n d -> b n (h d)'),  # merge the heads
            nn.Linear(dim_hidden, dim, bias=False)  # output projection
        )

    def forward(self, x, rotary_emb=None):  # forward pass, takes the input and an optional rotary embedding
        q, k, v = self.to_qkv(x)  # project the input to queries, keys and values

        if exists(rotary_emb):  # if a rotary embedding was passed in
            q, k = map(lambda t: apply_rotary_pos_emb(rotary_emb, t), (q, k))  # apply it to the queries and keys

        q = q * self.scale  # scale the queries
        sim = einsum('b h i d, b h j d -> b h i j', q, k)  # query-key similarities

        i, j = sim.shape[-2:]  # similarity matrix dimensions
        causal_mask = torch.ones((i, j), device=x.device, dtype=torch.bool).triu(j - i + 1)  # causal mask
        sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)  # mask out future positions

        attn = sim.softmax(dim=-1)  # attention weights

        out = einsum('b h i j, b h j d -> b h i d', attn, v)  # weighted sum of the values

        return self.to_out(out)  # merge heads and project out

# Q head

class DuelingHead(Module):  # dueling Q head, subclassing Module
    def __init__(  # constructor, takes keyword arguments
        self,
        *,
        dim,
        num_tokens,
        expansion_factor=2,
    ):
        super().__init__()  # call the parent constructor
        dim_hidden = int(dim * expansion_factor)  # hidden dimension

        self.stem = nn.Sequential(  # shared stem
            nn.Linear(dim, dim_hidden),  # linear projection
            nn.SiLU()  # SiLU activation
        )

        self.to_values = nn.Sequential(  # state-value branch
            nn.Linear(dim_hidden, 1)  # linear layer to a single value
        )

        self.to_advantages = nn.Sequential(  # advantage branch
            nn.Linear(dim_hidden, num_tokens)  # linear layer to one advantage per token (action)
        )

    def forward(self, x):  # forward pass
        x = self.stem(x)  # run the shared stem

        advantages = self.to_advantages(x)  # per-action advantages
        advantages = advantages - reduce(advantages, '... a -> ... 1', 'mean')  # center the advantages by subtracting their mean

        values = self.to_values(x)  # state values

        q_values = values + advantages  # combine into Q values
        return q_values  # return the Q values
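
# (editor's note) the dueling decomposition above computes Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)),
# so the state value and the per-action advantages are modeled by separate linear branches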

# llama

class Llama(Module):  # Llama transformer, subclassing Module
    def __init__(  # constructor, takes keyword arguments
        self,
        *,
        num_tokens,
        dim,
        depth,
        dim_head=64,
        heads=8,
        ff_mult=4,
        dueling_q_head=False,
        dueling_q_head_expansion_factor=2
    ):
        super().__init__()

        # token embedding, mapping input tokens to vectors of the given dimension
        self.token_emb = nn.Embedding(num_tokens, dim)
        # rotary embedding, used inside the attention layers
        self.rotary_emb = RotaryEmbedding(dim_head)

        # stack of transformer layers
        self.layers = ModuleList([])

        # build `depth` layers
        for _ in range(depth):
            # each layer is an attention block followed by a feedforward block
            self.layers.append(ModuleList([
                Attention(dim = dim, dim_head = dim_head, heads = heads),
                FeedForward(dim = dim, mult = ff_mult)
            ]))

        # final normalization
        self.final_norm = RMSNorm(dim)

        # output head, mapping embeddings to token logits
        self.to_logits = nn.Linear(dim, num_tokens)

        # if requested, use a dueling head for the Q values
        if dueling_q_head:
            self.to_q = DuelingHead(num_tokens = num_tokens, dim = dim, expansion_factor = dueling_q_head_expansion_factor)
        else:
            # otherwise a plain linear layer
            self.to_q = nn.Linear(dim, num_tokens)

    # forward pass
    def forward(
        self,
        x,
        return_q_values = False
    ):
        # sequence length and device of the input
        seq_len, device = x.shape[-1], x.device

        # embed the input tokens
        x = self.token_emb(x)

        # build the rotary embedding
        rotary_emb = self.rotary_emb(seq_len, device = device)

        # run through every transformer layer
        for attn, ff in self.layers:
            # attention and feedforward, each with a residual connection
            x = attn(x, rotary_emb = rotary_emb) + x
            x = ff(x) + x

        # final normalization
        embed = self.final_norm(x)
        # project the normalized embeddings to token logits
        logits = self.to_logits(embed)

        # return just the logits unless Q values were requested
        if not return_q_values:
            return logits

        return logits, self.to_q(embed)

.\lucidrains\llama-qrlhf\llama_qrlhf\llama_qrlhf.py

import torch
from torch.nn import Module
from torch.utils.data import Dataset
from torch import nn, einsum, Tensor
import torch.nn.functional as F

from einops import rearrange, repeat, pack

from ema_pytorch import EMA

from beartype import beartype
from beartype.typing import Optional

from torchtyping import TensorType

from accelerate import Accelerator

# helper functions

# check whether a variable is not None
def exists(v):
    return v is not None

# tensor helpers

# gather the values at the given indices along the last dimension
def batch_select_indices(t, indices):
    indices = rearrange(indices, '... -> ... 1')
    selected = t.gather(-1, indices)
    return rearrange(selected, '... 1 -> ...')
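
# illustrative note (not part of the file): with t of shape (batch, seq, num_actions) and indices of
# shape (batch, seq) holding the chosen token / action ids, this returns the selected values, shape (batch, seq)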

# Q functions

# autoregressive Q-learning
def autoregressive_q_learn(
    model:          Module,
    ema_model:      Module,
    states:         TensorType['b', 'n', int],     # the full sequence, prompt plus generated tokens
    prompt_len:     TensorType['b', int],          # length of the leading prompt
    next_states:    TensorType['b', int],          # the chosen action that becomes the next state
    rewards:        TensorType['b', 'n', float],   # rewards can be given at the end, or at intermediate steps
    eos_id:         Optional[int] = None,          # derive done flags from the <eos> token id
    discount_gamma: float = 0.998                  # reward discount factor, encourages brevity of the generated answer
) -> TensorType[()]:
    """
    einops

    b - batch
    n - sequence len
    """
    seq_len, device = states.shape[-1], states.device

    # because unicode greek letters look nice

    γ = discount_gamma

    # get the predicted Q values for every action

    q_pred_all_actions = model(states)
    q_pred = batch_select_indices(q_pred_all_actions, actions)

    # append the next state to the current states, to build the target Q input

    q_target_input = pack([states[:, 1:], next_states], 'b *')

    # get the target Q values

    q_target_all_actions = ema_model(q_target_input)
    q_target = q_target_all_actions.max(dim = -1).values

    # anything after the first done flag is treated as terminal

    if exists(eos_id):
        dones = states == eos_id
        dones = dones.cumsum(dim = -1) > 0
        dones = F.pad(dones, (1, -1), value = False)

        not_terminal = (~dones).float()

        # rewards should not be given on or after the terminal step

        rewards = rewards * not_terminal
        q_target = q_target.masked_fill(dones, 0.)

    # the main contribution of the paper is the following logic
    # section 4.1 - equation 1

    # for timesteps with no reward given, the Q prediction at time t is just max(Q target) at t + 1

    losses_without_rewards = F.mse_loss(q_pred, q_target, reduction = 'none')

    # handle the timesteps where a reward is given - the classic Bellman equation

    q_target_with_rewards = rewards + γ * q_target

    losses_with_rewards = F.mse_loss(q_pred, q_target_with_rewards, reduction = 'none')

    # final loss

    losses = torch.where(
        rewards > 0.,
        losses_with_rewards,
        losses_without_rewards
    )

    # take a masked mean
    # only the 'q logits' from the last token of the prompt onward count as 'actions'

    is_action_mask = torch.arange(seq_len, device = device) > rearrange(prompt_len - 1, 'b -> b 1')
    losses = losses[is_action_mask]

    return losses.mean()
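
# (editor's note) summarizing the target construction above:
#   steps without a reward:  Q(s_t, a_t) <- max_a' Q_ema(s_{t+1}, a')
#   steps with a reward:     Q(s_t, a_t) <- r_t + γ * max_a' Q_ema(s_{t+1}, a')
# the loss is the MSE between q_pred and whichever target applies, averaged over the action positions only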

# conservative regularization loss
def conservative_regularization_loss(
    q_values:           TensorType['b', 'n', 'a', float],
    states_and_actions: TensorType['b', 'n', int],
    action_mask:        TensorType['b', 'n', bool],
    reward_min:         float = 0.
) -> TensorType[()]:
    batch, seq_len, num_actions, device = *q_values.shape, q_values.device
    non_dataset_actions = torch.arange(num_actions, device = device) == rearrange(states_and_actions, '... -> ... 1')

    q_values = q_values[~non_dataset_actions]
    q_values = rearrange(q_values, '(b n a) -> b n a', b = batch, n = seq_len)
    # keep only the Q values at the action positions given by the mask
    q_values = q_values[action_mask]

    # build the regression target from the minimum reward, scaled by the sequence length
    reward_min = torch.full((), reward_min, device=device) * seq_len

    # mse loss between the off-dataset Q values and the minimum-reward target
    return F.mse_loss(q_values, reward_min)
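
# (editor's note) this acts as a conservative regularizer: Q values for actions that never appear in the
# dataset are regressed toward the minimum attainable return, discouraging overestimation on
# out-of-distribution actions
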
# main class

# QRLHF class, subclassing Module
class QRLHF(Module):
    # constructor: takes the model, the dataset, accelerate kwargs and EMA kwargs
    @beartype
    def __init__(
        self,
        model:   Module,  # the language model
        dataset: Dataset,  # the dataset
        accelerate_kwargs: dict = dict(),  # kwargs passed to accelerate, empty by default
        ema_kwargs: dict = dict(  # kwargs for the exponential moving average, beta = 0.99 by default
            beta = 0.99
        )
    ):
        # call the parent constructor
        super().__init__()

        # keep a reference to the language model
        self.lm = model
        # wrap the model in an EMA copy, to serve as the target network
        self.lm_target = EMA(model, **ema_kwargs)

    # forward pass is not implemented yet (work in progress)
    def forward(self):
        raise NotImplementedError

.\lucidrains\llama-qrlhf\llama_qrlhf\__init__.py

# import the QRLHF class from the llama_qrlhf module
from llama_qrlhf.llama_qrlhf import QRLHF

Llama - QRLHF (wip)

Implementation of the Llama (or any language model) architecture with RLHF + Q-learning.

This is experimental / independent open research, built off nothing but speculation. But I'll throw some of my brain cycles at the problem in the coming month, just in case the rumors have any basis. Anything you PhD students can get working is up for grabs.

Will start off by adapting the autoregressive discrete Q-learning formulation in the cited paper below and run a few experiments on arithmetic, using a symbolic solver as reward generator.
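
Below is a minimal usage sketch of the Llama module defined in llama_qrlhf/llama.py above. The import path and hyperparameters are assumptions for illustration (the package __init__ only exports QRLHF):

import torch
from llama_qrlhf.llama import Llama  # assumed import path

model = Llama(
    num_tokens = 256,
    dim = 512,
    depth = 8,
    dueling_q_head = True        # optional dueling Q head (see the Wang et al. citation below)
)

tokens = torch.randint(0, 256, (1, 1024))
logits, q_values = model(tokens, return_q_values = True)  # both of shape (1, 1024, 256)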

Yannic Kilcher's educational Q-learning video

Citations

@inproceedings{qtransformer,
    title   = {Q-Transformer: Scalable Offline Reinforcement Learning via Autoregressive Q-Functions},
    author  = {Yevgen Chebotar and Quan Vuong and Alex Irpan and Karol Hausman and Fei Xia and Yao Lu and Aviral Kumar and Tianhe Yu and Alexander Herzog and Karl Pertsch and Keerthana Gopalakrishnan and Julian Ibarz and Ofir Nachum and Sumedh Sontakke and Grecia Salazar and Huong T Tran and Jodilyn Peralta and Clayton Tan and Deeksha Manjunath and Jaspiar Singht and Brianna Zitkovich and Tomas Jackson and Kanishka Rao and Chelsea Finn and Sergey Levine},
    booktitle = {7th Annual Conference on Robot Learning},
    year   = {2023}
}
@inproceedings{Wang2015DuelingNA,
    title   = {Dueling Network Architectures for Deep Reinforcement Learning},
    author  = {Ziyun Wang and Tom Schaul and Matteo Hessel and H. V. Hasselt and Marc Lanctot and Nando de Freitas},
    booktitle = {International Conference on Machine Learning},
    year    = {2015},
    url     = {https://api.semanticscholar.org/CorpusID:5389801}
}

.\lucidrains\llama-qrlhf\setup.py

# import the setup and find_packages functions
from setuptools import setup, find_packages

# define the package metadata
setup(
  name = 'llama-qrlhf', # package name
  packages = find_packages(exclude=[]), # find all packages
  version = '0.0.1', # version number
  license='MIT', # license
  description = 'Experimental Q-RLHF applied to Language Modeling. Made compatible with Llama of course', # description
  author = 'Phil Wang', # author
  author_email = 'lucidrains@gmail.com', # author email
  long_description_content_type = 'text/markdown', # long description content type
  url = 'https://github.com/lucidrains/llama-qrlhf', # project URL
  keywords = [
    'artificial intelligence',
    'deep learning',
    'reinforcement learning with human feedback',
    'q learning',
  ], # keywords
  install_requires = [
    'accelerate',
    'beartype',
    'ema-pytorch',
    'einops>=0.7.0',
    'torch>=2.0'
  ], # install dependencies
  classifiers=[
    'Development Status :: 4 - Beta', # development status
    'Intended Audience :: Developers', # intended audience
    'Topic :: Scientific/Engineering :: Artificial Intelligence', # topic
    'License :: OSI Approved :: MIT License', # license
    'Programming Language :: Python :: 3.6', # programming language
  ],
)

Llama2 - Nim (wip)

Basically a transcription of Andrej Karpathy's Llama.c to Nim. Just to gain more experience with Nim.

Data source

The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/

.\lucidrains\local-attention\local_attention\local_attention.py

# import the math library
import math

# import torch
import torch
from torch import nn, einsum
import torch.nn.functional as F

# import functions from einops
from einops import rearrange, repeat, pack, unpack

# import functions from the rotary module
from local_attention.rotary import SinusoidalEmbeddings, apply_rotary_pos_emb

# constants
TOKEN_SELF_ATTN_VALUE = -5e4

# helper functions

# check whether a value is not None
def exists(val):
    return val is not None

# return the value if it exists, otherwise the default
def default(value, d):
    return d if not exists(value) else value

# return a tensor's device and dtype as keyword arguments
def to(t):
    return {'device': t.device, 'dtype': t.dtype}

# return the largest negative value representable by a tensor's dtype
def max_neg_value(tensor):
    return -torch.finfo(tensor.dtype).max

# l2-normalize a tensor along its last dimension, preserving its dtype
def l2norm(tensor):
    dtype = tensor.dtype
    normed = F.normalize(tensor, dim = -1)
    return normed.type(dtype)

# pad a tensor along a dimension so its length becomes a multiple of `multiple`
def pad_to_multiple(tensor, multiple, dim=-1, value=0):
    seqlen = tensor.shape[dim]
    m = seqlen / multiple
    if m.is_integer():
        return False, tensor
    remainder = math.ceil(m) * multiple - seqlen
    pad_offset = (0,) * (-1 - dim) * 2
    return True, F.pad(tensor, (*pad_offset, 0, remainder), value = value)

# gather each window together with its `backward` previous and `forward` following windows, padding at the boundaries
def look_around(x, backward = 1, forward = 0, pad_value = -1, dim = 2):
    t = x.shape[1]
    dims = (len(x.shape) - dim) * (0, 0)
    padded_x = F.pad(x, (*dims, backward, forward), value = pad_value)
    tensors = [padded_x[:, ind:(ind + t), ...] for ind in range(forward + backward + 1)]
    return torch.cat(tensors, dim = dim)

# main class

class LocalAttention(nn.Module):
    def __init__(
        self,
        window_size,
        causal = False,
        look_backward = 1,
        look_forward = None,
        dropout = 0.,
        shared_qk = False,
        rel_pos_emb_config = None,
        dim = None,
        autopad = False,
        exact_windowsize = False,
        scale = None,
        use_rotary_pos_emb = True,
        use_xpos = False,
        xpos_scale_base = None
    ):
        super().__init__()
        look_forward = default(look_forward, 0 if causal else 1)
        assert not (causal and look_forward > 0), 'you cannot look forward if causal'

        self.scale = scale

        self.window_size = window_size
        self.autopad = autopad
        self.exact_windowsize = exact_windowsize

        self.causal = causal

        self.look_backward = look_backward
        self.look_forward = look_forward

        self.dropout = nn.Dropout(dropout)

        self.shared_qk = shared_qk

        # relative positions

        self.rel_pos = None
        self.use_xpos = use_xpos

        if use_rotary_pos_emb and (exists(rel_pos_emb_config) or exists(dim)):  # backwards compatible with the old `rel_pos_emb_config` argument
            if exists(rel_pos_emb_config):
                dim = rel_pos_emb_config[0]

            self.rel_pos = SinusoidalEmbeddings(
                dim,
                use_xpos = use_xpos,
                scale_base = default(xpos_scale_base, window_size // 2)
            )

    def forward(
        self,
        q, k, v,
        mask = None,
        input_mask = None,
        attn_bias = None,
        window_size = None

.\lucidrains\local-attention\local_attention\rotary.py

# import torch
import torch
# import the nn module and einsum from torch
from torch import nn, einsum

# import the rearrange function from einops
from einops import rearrange

# check whether a value is not None
def exists(val):
    return val is not None

# sinusoidal / rotary embedding module, subclassing nn.Module
class SinusoidalEmbeddings(nn.Module):
    # constructor
    def __init__(
        self,
        dim,
        scale_base = None,
        use_xpos = False
    ):
        super().__init__()
        # inverse frequencies
        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        # register the inverse frequencies as a buffer
        self.register_buffer('inv_freq', inv_freq)

        # xpos related

        # whether to use xpos
        self.use_xpos = use_xpos
        # scale base
        self.scale_base = scale_base

        # a scale base must be defined whenever xpos is used
        assert not (use_xpos and not exists(scale_base)), 'scale base must be defined if using xpos'

        # compute the xpos scales
        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
        # register the scales as a non-persistent buffer
        self.register_buffer('scale', scale, persistent = False)

    # forward pass
    def forward(self, x):
        # sequence length and device
        seq_len, device = x.shape[-2], x.device

        # position indices
        t = torch.arange(seq_len, device = x.device).type_as(self.inv_freq)
        # frequencies
        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
        freqs = torch.cat((freqs, freqs), dim = -1)

        # if xpos is not used, return the frequencies and a scale of ones
        if not self.use_xpos:
            return freqs, torch.ones(1, device = device)

        # powers, centered around the middle of the sequence
        power = (t - (seq_len // 2)) / self.scale_base
        # xpos scales
        scale = self.scale ** rearrange(power, 'n -> n 1')
        scale = torch.cat((scale, scale), dim = -1)

        return freqs, scale

# split the last dimension into two halves and rotate them (as used by rotary embeddings)
def rotate_half(x):
    x = rearrange(x, 'b ... (r d) -> b ... r d', r = 2)
    x1, x2 = x.unbind(dim = -2)
    return torch.cat((-x2, x1), dim = -1)

# apply rotary positional embeddings to the queries and keys
def apply_rotary_pos_emb(q, k, freqs, scale = 1):
    # length of the queries
    q_len = q.shape[-2]
    # frequencies for the query positions
    q_freqs = freqs[..., -q_len:, :]

    # inverse of the scale (used for the keys)
    inv_scale = scale ** -1

    # if the scale is two dimensional, keep only the query positions
    if scale.ndim == 2:
        scale = scale[-q_len:, :]

    # apply the rotary embedding to the queries and keys
    q = (q * q_freqs.cos() * scale) + (rotate_half(q) * q_freqs.sin() * scale)
    k = (k * freqs.cos() * inv_scale) + (rotate_half(k) * freqs.sin() * inv_scale)
    return q, k

.\lucidrains\local-attention\local_attention\transformer.py

# import torch
import torch
# import the nn module from torch
from torch import nn
# import the functional module from torch.nn
import torch.nn.functional as F

# import the rearrange function from einops
from einops import rearrange

# import the LocalAttention class from the local_attention package
from local_attention.local_attention import LocalAttention

# helper functions

# check whether a value is not None
def exists(val):
    return val is not None

# return the value if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# l2-normalize a tensor along its last dimension
def l2norm(t):
    return F.normalize(t, dim = -1)

# decorator that runs the wrapped call in eval mode, then restores the model's previous training mode
def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner

# sampling functions

# keep only the top-k logits (k derived from the threshold), setting the rest to -inf
def top_k(logits, thres = 0.9):
    k = int((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs

# local windowed multi-head attention

class LocalMHA(nn.Module):
    def __init__(
        self,
        *,
        dim,
        window_size,
        dim_head = 64,
        heads = 8,
        dropout = 0.,
        causal = False,
        prenorm = False,
        qk_rmsnorm = False,
        qk_scale = 8,
        use_xpos = False,
        xpos_scale_base = None,
        exact_windowsize = None,
        gate_values_per_head = False,
        **kwargs
    ):
        super().__init__()        
        inner_dim = dim_head * heads

        # optionally layer-normalize the input (pre-norm)
        self.norm = nn.LayerNorm(dim) if prenorm else None

        self.heads = heads
        # project the input to queries, keys and values
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.qk_rmsnorm = qk_rmsnorm

        if qk_rmsnorm:
            self.q_scale = nn.Parameter(torch.ones(dim_head))
            self.k_scale = nn.Parameter(torch.ones(dim_head))

        # LocalAttention performs the windowed attention computation
        self.attn_fn = LocalAttention(
            dim = dim_head,
            window_size = window_size,
            causal = causal,
            autopad = True,
            scale = (qk_scale if qk_rmsnorm else None),
            exact_windowsize = default(exact_windowsize, True),
            use_xpos = use_xpos,
            xpos_scale_base = xpos_scale_base,
            **kwargs
        )

        self.to_v_gate = None

        if gate_values_per_head:
            self.to_v_gate = nn.Sequential(
                nn.Linear(dim, heads)
            )

        # project the output back to the model dimension
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x, mask = None, attn_bias = None):
        if exists(self.norm):
            x = self.norm(x)

        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v)) 

        if self.qk_rmsnorm:
            q, k = map(l2norm, (q, k))
            q = q * self.q_scale
            k = k * self.k_scale

        out = self.attn_fn(q, k, v, mask = mask, attn_bias = attn_bias)

        if exists(self.to_v_gate):
            gates = self.to_v_gate(x)
            gates = rearrange(gates, 'b n h -> b h n 1')
            out = out * gates.sigmoid()

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# feedforward

class GEGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim = -1)
        return x * F.gelu(gate)

# build a feedforward block
def FeedForward(dim, mult = 4, dropout = 0.):
    inner_dim = int(dim * mult * 2 / 3)

    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim * 2, bias = False),
        GEGLU(),
        nn.Dropout(dropout),
        nn.Linear(inner_dim, dim, bias = False)
    )

# dynamic position bias

class DynamicPositionBias(nn.Module):
    def __init__(
        self,
        dim,
        heads
    ):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(1, dim),
            nn.SiLU(),
            nn.Linear(dim, dim),
            nn.SiLU(),
            nn.Linear(dim, heads)
        )

    @property
    def device(self):
        return next(self.parameters()).device
    # forward pass, taking the query length i and the key length j
    def forward(self, i, j):
        # device of the module's parameters
        device = self.device
        # j must be at least i
        assert j >= i

        # relative distances 0 .. j - 1, as floats on the right device
        rel_dist = torch.arange(j, dtype=torch.float, device=device)
        # run the MLP over the (reshaped) relative distances to get one bias per head
        bias = self.mlp(rearrange(rel_dist, '... -> ... 1'))

        # query positions j - i .. j - 1
        i_seq = torch.arange(j - i, j, device=device)
        # key positions 0 .. j - 1
        j_seq = torch.arange(j, device=device)

        # absolute relative-distance index between every query and key position
        rel_dist_indices = (rearrange(i_seq, 'i -> i 1') - rearrange(j_seq, 'j -> 1 j')).abs()

        # index into the bias and rearrange to (heads, i, j)
        bias = rearrange(bias[rel_dist_indices], 'i j h -> h i j')
        # return the bias
        return bias
# main transformer class

class LocalTransformer(nn.Module):
    def __init__(
        self,
        *,
        num_tokens,  # number of tokens
        max_seq_len,  # maximum sequence length
        dim,  # model dimension
        depth,  # number of layers
        causal = True,  # whether to use causal attention
        local_attn_window_size = 512,  # local attention window size
        dim_head = 64,  # dimension per head
        heads = 8,  # number of heads
        ff_mult = 4,  # feedforward expansion multiplier
        attn_dropout = 0.,  # attention dropout
        ff_dropout = 0.,  # feedforward dropout
        ignore_index = -1,  # index to ignore in the loss
        use_xpos = False,  # whether to use xpos
        xpos_scale_base = None,  # xpos scale base
        use_dynamic_pos_bias = False,  # whether to use a dynamic position bias
        **kwargs
    ):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)  # token embedding
        self.pos_emb = nn.Embedding(max_seq_len, dim)  # positional embedding

        self.max_seq_len = max_seq_len  # maximum sequence length
        self.layers = nn.ModuleList([])  # list of layers

        self.local_attn_window_size = local_attn_window_size  # local attention window size
        self.dynamic_pos_bias = None
        if use_dynamic_pos_bias:
            self.dynamic_pos_bias = DynamicPositionBias(dim = dim // 2, heads = heads)  # dynamic position bias

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                LocalMHA(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, causal = causal, window_size = local_attn_window_size, use_xpos = use_xpos, xpos_scale_base = xpos_scale_base, use_rotary_pos_emb = not use_dynamic_pos_bias, prenorm = True, **kwargs),
                FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
            ]))  # each layer: local multi-head attention followed by a feedforward block

        self.ignore_index = ignore_index  # index to ignore in the loss
        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),  # layer normalization
            nn.Linear(dim, num_tokens, bias = False)  # linear projection to logits
        )

    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        prime,  # priming sequence
        seq_len,  # number of tokens to generate
        temperature = 1.,  # sampling temperature
        filter_thres = 0.9,  # top-k filtering threshold
        **kwargs
    ):
        n, device = prime.shape[1], prime.device

        out = prime

        for _ in range(seq_len):
            logits = self.forward(out[:, -self.max_seq_len:], **kwargs)  # forward pass for the logits
            filtered_logits = top_k(logits[:, -1], thres = filter_thres)  # keep only the top-k logits
            probs = F.softmax(filtered_logits / temperature, dim = -1)  # softmax into probabilities
            sampled = torch.multinomial(probs, 1)  # sample the next token
            out = torch.cat((out, sampled), dim = -1)  # append the sample to the running sequence

        return out[:, n:]  # return only the newly generated tokens

    def forward(self, x, mask = None, return_loss = False):
        if return_loss:
            x, labels = x[:, :-1], x[:, 1:]  # split into input and target sequences

        n, device = x.shape[1], x.device
        x = self.token_emb(x)  # token embedding

        assert n <= self.max_seq_len
        x = x + self.pos_emb(torch.arange(n, device = device))  # add the positional embedding

        # dynamic position bias

        attn_bias = None
        if exists(self.dynamic_pos_bias):
            w = self.local_attn_window_size
            attn_bias = self.dynamic_pos_bias(w, w * 2)  # compute the attention bias

        # go through the layers

        for attn, ff in self.layers:
            x = attn(x, mask = mask, attn_bias = attn_bias) + x  # local multi-head attention with residual
            x = ff(x) + x  # feedforward with residual

        logits = self.to_logits(x)  # project to logits

        if not return_loss:
            return logits

        logits = rearrange(logits, 'b n c -> b c n')  # rearrange the logits for cross entropy
        loss = F.cross_entropy(logits, labels, ignore_index = self.ignore_index)  # cross entropy loss
        return loss  # return the loss

.\lucidrains\local-attention\local_attention\__init__.py

# import the LocalAttention class from the local_attention package
from local_attention.local_attention import LocalAttention
# import the LocalTransformer, LocalMHA and DynamicPositionBias classes from the local_attention package
from local_attention.transformer import LocalTransformer, LocalMHA, DynamicPositionBias

Local attention

An implementation of local windowed attention, which sets an incredibly strong baseline for language modeling. It is becoming apparent that a transformer needs local attention in the bottom layers, with the top layers reserved for global attention to integrate the findings of previous layers. This repository makes it easy to immediately employ local window attention.

This code has been battletested in multiple repositories already, alongside different implementations of sparse long-range attention.

Install

$ pip install local-attention

Usage

import torch
from local_attention import LocalAttention

q = torch.randn(2, 8, 2048, 64)
k = torch.randn(2, 8, 2048, 64)
v = torch.randn(2, 8, 2048, 64)

attn = LocalAttention(
    dim = 64,                # dimension of each head (you need to pass this in for relative positional encoding)
    window_size = 512,       # window size. 512 is optimal, but 256 or 128 yields good enough results
    causal = True,           # auto-regressive or not
    look_backward = 1,       # each window looks at the window before
    look_forward = 0,        # for non-auto-regressive case, will default to 1, so each window looks at the window before and after it
    dropout = 0.1,           # post-attention dropout
    exact_windowsize = False # if this is set to true, in the causal setting, each query will see at maximum the number of keys equal to the window size
)

mask = torch.ones(2, 2048).bool()
out = attn(q, k, v, mask = mask) # (2, 8, 2048, 64)

This library also allows for local attention in the setting of shared query/key space (Reformer architecture). The normalization of the keys, as well as the masking of tokens to itself, will be taken care of.

import torch
from local_attention import LocalAttention

qk = torch.randn(2, 8, 2048, 64)
v  = torch.randn(2, 8, 2048, 64)

attn = LocalAttention(
    dim = 64,
    window_size = 512,
    shared_qk = True,
    causal = True
)

mask = torch.ones(2, 2048).bool()
out = attn(qk, qk, v, mask = mask) # (2, 8, 2048, 64)

If you wish for the module to automagically pad your query / key / values as well as the mask, simply set the autopad keyword to True

import torch
from local_attention import LocalAttention

q = torch.randn(8, 2057, 64)
k = torch.randn(8, 2057, 64)
v = torch.randn(8, 2057, 64)

attn = LocalAttention(
    window_size = 512,
    causal = True,
    autopad = True      # auto pads both inputs and mask, then truncates output appropriately
)

mask = torch.ones(1, 2057).bool()
out = attn(q, k, v, mask = mask) # (8, 2057, 64)

Local Attention Transformer

A full local attention transformer

import torch
from local_attention import LocalTransformer

model = LocalTransformer(
    num_tokens = 256,
    dim = 512,
    depth = 6,
    max_seq_len = 8192,
    causal = True,
    local_attn_window_size = 256
).cuda()

x = torch.randint(0, 256, (1, 8192)).cuda()

logits = model(x) # (1, 8192, 256)

Enwik8 at 4096

window size of 256, lookback of 1, total receptive field of 512

$ python train.py

Citation

@inproceedings{rae-razavi-2020-transformers,
    title   = "Do Transformers Need Deep Long-Range Memory?",
    author  = "Rae, Jack  and Razavi, Ali",
    booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics",
    month   = jul,
    year    = "2020",
    address = "Online",
    publisher = "Association for Computational Linguistics",
    url     = "https://www.aclweb.org/anthology/2020.acl-main.672"
}
@misc{roy*2020efficient,
    title   = {Efficient Content-Based Sparse Attention with Routing Transformers},
    author  = {Aurko Roy* and Mohammad Taghi Saffar* and David Grangier and Ashish Vaswani},
    year    = {2020},
    url     = {https://arxiv.org/pdf/2003.05997.pdf}
}
@misc{beltagy2020longformer,
    title   = {Longformer: The Long-Document Transformer},
    author  = {Iz Beltagy and Matthew E. Peters and Arman Cohan},
    year    = {2020},
    eprint  = {2004.05150},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@inproceedings{Sun2022ALT,
    title     = {A Length-Extrapolatable Transformer},
    author    = {Yutao Sun and Li Dong and Barun Patra and Shuming Ma and Shaohan Huang and Alon Benhaim and Vishrav Chaudhary and Xia Song and Furu Wei},
    year      = {2022}
}
@article{Bondarenko2023QuantizableTR,
    title   = {Quantizable Transformers: Removing Outliers by Helping Attention Heads Do Nothing},
    author  = {Yelysei Bondarenko and Markus Nagel and Tijmen Blankevoort},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2306.12929},
    url     = {https://api.semanticscholar.org/CorpusID:259224568}
}

.\lucidrains\local-attention\setup.py

# import the setup and find_packages functions
from setuptools import setup, find_packages

# define the package metadata
setup(
  name = 'local-attention',  # package name
  packages = find_packages(),  # find and include all packages
  version = '1.9.0',  # version number
  license='MIT',  # license
  description = 'Local attention, window with lookback, for language modeling',  # description
  long_description_content_type = 'text/markdown',  # long description content type
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  url = 'https://github.com/lucidrains/local-attention',  # project URL
  keywords = [
    'transformers',  # keyword: transformers
    'attention',  # keyword: attention
    'artificial intelligence'  # keyword: artificial intelligence
  ],
  install_requires=[
    'einops>=0.6.0',  # required dependency: einops>=0.6.0
    'torch'  # required dependency: torch
  ],
  classifiers=[
    'Development Status :: 4 - Beta',  # classifier: development status is Beta
    'Intended Audience :: Developers',  # classifier: intended audience is developers
    'Topic :: Scientific/Engineering :: Artificial Intelligence',  # classifier: scientific/engineering - AI
    'License :: OSI Approved :: MIT License',  # classifier: MIT license
    'Programming Language :: Python :: 3.6',  # classifier: Python 3.6
  ],
)

.\lucidrains\local-attention\train.py

# import the required libraries
import random
import tqdm
import gzip
import numpy as np

import torch
from torch.optim import Adam
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

from local_attention import LocalTransformer

# constants
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 2e-4
VALIDATE_EVERY  = 100
GENERATE_EVERY  = 500
GENERATE_LENGTH = 2048
SEQ_LEN = 2048

# helper functions

# decode a single token id into a character
def decode_token(token):
    return str(chr(max(32, token)))

# decode a sequence of tokens into a string
def decode_tokens(tokens):
    return ''.join(list(map(decode_token, tokens)))

# cycle a dataloader indefinitely (helper needed by the loaders created below)
def cycle(loader):
    while True:
        for data in loader:
            yield data

# instantiate the GPT-like decoder model
model = LocalTransformer(
    num_tokens = 256,
    dim = 512,
    depth = 6,
    causal = True,
    local_attn_window_size = 256,
    max_seq_len = SEQ_LEN,
    use_dynamic_pos_bias = True
).cuda()

# prepare the enwik8 data

with gzip.open('./data/enwik8.gz') as file:
    X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)

# dataset class
class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

# create the training and validation datasets and dataloaders
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset   = TextSamplerDataset(data_val, SEQ_LEN)
train_loader  = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader    = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

# optimizer
optim = Adam(model.parameters(), lr=LEARNING_RATE)

# train the model
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader), return_loss = True)
        loss.backward()

    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader), return_loss = True)
            print(f'validation loss: {loss.item()}')

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        print(f'%s \n\n %s', (prime, '*' * 100))

        sample = model.generate(inp[None, ...], GENERATE_LENGTH)
        output_str = decode_tokens(sample[0])
        print(output_str)

.\lucidrains\local-attention-flax\local_attention_flax\local_attention_flax.py

# import the required libraries
import flax.linen as nn
from jax import numpy as np
from einops import rearrange

# constant used for masking out attention scores
ATTN_MASK_VALUE = -1e10

# LocalAttention module, subclassing nn.Module
class LocalAttention(nn.Module):
    # module fields: dim (output dimension), window_size, heads (default 8) and dim_head (default 64)
    dim: int
    window_size: int
    heads: int = 8
    dim_head: int = 64

    # __call__ implements the module's forward computation
    @nn.compact
    def __call__(self, x):
        # unpack the sequence length, number of heads, head dimension and window size
        n, h, dim_head, wsz = x.shape[0], self.heads, self.dim_head, self.window_size
        # the sequence length must be divisible by the window size
        assert (n % wsz) == 0, 'sequence length must be divisible by the window size'
        # attention scaling factor
        scale = dim_head ** -0.5
        # number of windows
        window = n // wsz

        # project the input to q, k, v with a single dense layer
        qkv = nn.Dense(features = 3 * h * dim_head, use_bias = False)(x)
        # split into q, k, v
        q, k, v = np.split(qkv, 3, axis = -1)
        # rearrange q, k, v into (heads, windows, window length, head dim)
        q, k, v = map(lambda t: rearrange(t, '(w n) (h d) -> h w n d', w = window, h = h), (q, k, v))

        # pad k and v with one empty window at the front
        k, v = map(lambda t: np.pad(t, ((0, 0), (1, 0), (0, 0), (0, 0)), constant_values = 0.), (k ,v))
        # concatenate each window with the previous one, so keys and values span two windows
        k, v = map(lambda t: np.concatenate((t[:, :-1], t[:, 1:]), axis = 2), (k, v))

        # attention scores
        sim = np.einsum('h w i d, h w j d -> h w i j', q, k) * scale

        # causal mask over the current and previous window
        mask = np.tril(np.ones((wsz, wsz * 2)), wsz)
        # apply the mask to the attention scores
        sim = np.where(mask, sim, ATTN_MASK_VALUE)

        # attention weights
        attn = nn.softmax(sim, axis = -1)
        # weighted sum of the values
        out = np.einsum('h w i j, h w j d -> h w i d', attn, v)
        # merge the windows and heads back into a single sequence dimension
        out = rearrange(out, 'h w n d -> (w n) (h d)')
        # final dense projection to the output dimension
        out = nn.Dense(features = self.dim)(out)
        # return the output
        return out

.\lucidrains\local-attention-flax\local_attention_flax\__init__.py

# import the LocalAttention class from the local_attention_flax module
from local_attention_flax.local_attention_flax import LocalAttention

Local Attention - Flax

Autoregressive Local Attention - Flax module for Jax

Install

$ pip install local-attention-flax

Usage

from jax import random
from local_attention_flax import LocalAttention

attn = LocalAttention(
    dim = 256,
    dim_head = 64,
    heads = 8,
    window_size = 128
)

key = random.PRNGKey(0)
x = random.normal(key, (512, 256))

params = attn.init(key, x)
out = attn.apply(params, x)  # (512, 256)

.\lucidrains\local-attention-flax\setup.py

# import the setup and find_packages functions
from setuptools import setup, find_packages

# define the package metadata
setup(
    name="local-attention-flax",  # package name
    packages=find_packages(),  # find and include all packages
    version="0.0.2",  # version number
    license="MIT",  # license
    description="Local Attention - Flax Module in Jax",  # description
    author="Phil Wang",  # author
    author_email="",  # author email
    url="https://github.com/lucidrains/local-attention-flax",  # project URL
    keywords=[  # keywords
        "artificial intelligence",
        "deep learning",
        "attention mechanism",
        "jax"
    ],
    install_requires=[  # install dependencies
        "einops>=0.3",
        "flax",
        "jax",
        "jaxlib"
    ],
    classifiers=[  # classifiers
        "Development Status :: 4 - Beta",
        "Intended Audience :: Developers",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "License :: OSI Approved :: MIT License",
        "Programming Language :: Python :: 3.6",
    ],
)

.\lucidrains\logavgexp-torch\logavgexp_pytorch\logavgexp_pytorch.py

import math
from functools import partial

import torch
from torch import nn
import torch.nn.functional as F

from einops import rearrange
from unfoldNd import unfoldNd

# helper functions

# 检查变量是否存在
def exists(t):
    return t is not None

# 对张量取对数
def log(t, eps = 1e-20):
    return torch.log(t + eps)

# 将输入转换为元组
def cast_tuple(t, length = 1):
    return t if isinstance(t, tuple) else ((t,) * length)

# 计算卷积输出形状
def calc_conv_output(shape, kernel_size, padding, stride):
    return tuple(map(lambda x: int((x[0] - x[1] + 2 * x[2]) / x[3] + 1), zip(shape, kernel_size, padding, stride))

# main function

# 对输入张量进行 logavgexp 操作
def logavgexp(
    t,
    mask = None,
    dim = -1,
    eps = 1e-20,
    temp = 0.01,
    keepdim = False
):
    if exists(mask):
        mask_value = -torch.finfo(t.dtype).max
        t = t.masked_fill(~mask, mask_value)
        n = mask.sum(dim = dim)
        norm = torch.log(n)
    else:
        n = t.shape[dim]
        norm = math.log(n)

    t = t / temp
    max_t = t.amax(dim = dim).detach()
    t_exp = (t - max_t.unsqueeze(dim)).exp()
    avg_exp = t_exp.sum(dim = dim).clamp(min = eps) / n
    out = log(avg_exp, eps = eps) + max_t - norm
    out = out * temp

    out = out.unsqueeze(dim) if keepdim else out
    return out
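
# 数值直觉(非原文内容,仅作说明):logavgexp 可以看作按长度归一化的平滑 max
#   x = torch.tensor([1., 2., 3.])
#   logavgexp(x, dim = 0, temp = 0.01)   # 接近 max(x) = 3,温度越低越接近最大值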

# learned temperature - logavgexp class

# LogAvgExp 类,用于 logavgexp 操作
class LogAvgExp(nn.Module):
    def __init__(
        self,
        dim = -1,
        eps = 1e-20,
        temp = 0.01,
        keepdim = False,
        learned_temp = False
    ):
        super().__init__()
        assert temp >= 0 and temp <= 1., 'temperature must be between 0 and 1'

        self.learned_temp = learned_temp

        if learned_temp:
            self.temp = nn.Parameter(torch.ones((1,)) * math.log(temp))
        else:
            self.temp = temp

        self.dim = dim
        self.keepdim = keepdim

    def forward(self, x, mask = None, eps = 1e-8):
        if not self.learned_temp:
            temp = self.temp
        else:
            temp = self.temp.exp().clamp(min = eps)

        return logavgexp(
            x,
            mask = mask,
            dim = self.dim,
            temp = temp,
            keepdim = self.keepdim
        )

# logavgexp 2d

# LogAvgExp2D 类,用于 2D logavgexp 操作
class LogAvgExp2D(nn.Module):
    def __init__(
        self,
        kernel_size,
        *,
        padding = 0,
        stride = 1,
        temp = 0.01,
        learned_temp = True,
        eps = 1e-20,
        **kwargs
    ):
        super().__init__()
        self.padding = cast_tuple(padding, 2)
        self.stride = cast_tuple(stride, 2)
        self.kernel_size = cast_tuple(kernel_size, 2)

        self.unfold = nn.Unfold(self.kernel_size, padding = self.padding, stride = self.stride)
        self.logavgexp = LogAvgExp(dim = -1, eps = eps, learned_temp = learned_temp, temp = temp)

    def forward(self, x):
        """
        b - batch
        c - channels
        h - height
        w - width
        j - reducing dimension
        """

        b, c, h, w = x.shape
        out_h, out_w = calc_conv_output((h, w), self.kernel_size, self.padding, self.stride)

        # calculate mask for padding, if needed

        mask = None
        if any([i > 0 for i in self.padding]):
            mask = torch.ones((b, 1, h, w), device = x.device)
            mask = self.unfold(mask)
            mask = rearrange(mask, 'b j (h w) -> b 1 h w j', h = out_h, w = out_w)
            mask = mask == 1.

        x = self.unfold(x)
        x = rearrange(x, 'b (c j) (h w) -> b c h w j', h = out_h, w = out_w, c = c)
        return self.logavgexp(x, mask = mask)

# logavgexp 3d

# LogAvgExp3D 类,用于 3D logavgexp 操作
class LogAvgExp3D(nn.Module):
    def __init__(
        self,
        kernel_size,
        *,
        padding = 0,
        stride = 1,
        temp = 0.01,
        learned_temp = True,
        eps = 1e-20,
        **kwargs
    ):
        super().__init__()
        # 将填充、步幅和卷积核大小转换为元组形式
        self.padding = cast_tuple(padding, 3)
        self.stride = cast_tuple(stride, 3)
        self.kernel_size = cast_tuple(kernel_size, 3)

        # 部分应用 unfoldNd 函数,设置卷积核大小、填充和步幅
        self.unfold = partial(unfoldNd, kernel_size = self.kernel_size, padding = self.padding, stride = self.stride)
        # 初始化 LogAvgExp 函数
        self.logavgexp = LogAvgExp(dim = -1, eps = eps, learned_temp = learned_temp, temp = temp)

    # 前向传播函数
    def forward(self, x):
        """
        b - batch
        c - channels
        f - depth
        h - height
        w - width
        j - reducing dimension
        """

        # 获取输入张量的形状
        b, c, f, h, w = x.shape
        # 计算卷积输出的深度、高度和宽度
        out_f, out_h, out_w = calc_conv_output((f, h, w), self.kernel_size, self.padding, self.stride)

        # 计算是否需要填充的掩码

        mask = None
        if any([i > 0 for i in self.padding]):
            mask = torch.ones((b, 1, f, h, w), device = x.device)
            mask = self.unfold(mask)
            mask = rearrange(mask, 'b j (f h w) -> b 1 f h w j', f = out_f, h = out_h, w = out_w)
            mask = mask == 1.

        # 对输入张量进行展开操作
        x = self.unfold(x)
        x = rearrange(x, 'b (c j) (f h w) -> b c f h w j', f = out_f, h = out_h, w = out_w, c = c)
        # 调用 logavgexp 函数进行计算,传入掩码
        return self.logavgexp(x, mask = mask)

.\lucidrains\logavgexp-torch\logavgexp_pytorch\__init__.py

# 从logavgexp_pytorch.logavgexp_pytorch模块中导入logavgexp、LogAvgExp、LogAvgExp2D、LogAvgExp3D类和函数
from logavgexp_pytorch.logavgexp_pytorch import logavgexp, LogAvgExp, LogAvgExp2D, LogAvgExp3D

LogAvgExp - Pytorch

Implementation of LogAvgExp for Pytorch

Install

$ pip install logavgexp-pytorch

Usage

import torch
from logavgexp_pytorch import logavgexp

# basically it is an improved logsumexp (differentiable max)
# normalized for length

x = torch.arange(1000)
y = logavgexp(x, dim = 0, temp = 0.01) # ~998.8

# more than 1 dimension

x = torch.randn(1, 2048, 5)
y = logavgexp(x, dim = 1, temp = 0.2) # (1, 5)

# keep dimension

x = torch.randn(1, 2048, 5)
y = logavgexp(x, dim = 1, temp = 0.2, keepdim = True) # (1, 1, 5)

# masking (False for mask out with large negative value)

x = torch.randn(1, 2048, 5)
m = torch.randint(0, 2, (1, 2048, 1)).bool()

y = logavgexp(x, mask = m, dim = 1, temp = 0.2, keepdim = True) # (1, 1, 5)

With learned temperature

# learned temperature
import torch
from torch import nn
from logavgexp_pytorch import logavgexp

learned_temp = nn.Parameter(torch.ones(1) * -5).exp().clamp(min = 1e-8) # make sure temperature can't hit 0

x = torch.randn(1, 2048, 5)
y = logavgexp(x, temp = learned_temp, dim = 1) # (1, 5)

Or you can use the LogAvgExp class to handle the learned temperature parameter

import torch
from logavgexp_pytorch import LogAvgExp

logavgexp = LogAvgExp(
    temp = 0.01,
    dim = 1,
    learned_temp = True
)

x = torch.randn(1, 2048, 5)
y = logavgexp(x) # (1, 5)

LogAvgExp2D

import torch
from logavgexp_pytorch import LogAvgExp2D

logavgexp_pool = LogAvgExp2D((2, 2), stride = 2) # (2 x 2) pooling

img = torch.randn(1, 16, 64, 64)
out = logavgexp_pool(img) # (1, 16, 32, 32)
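
When the pooling window needs padding, LogAvgExp2D builds a mask internally so the padded positions are excluded from the average. A hedged variation on the example above (parameters are illustrative):

logavgexp_pool = LogAvgExp2D((3, 3), stride = 2, padding = 1)

img = torch.randn(1, 16, 64, 64)
out = logavgexp_pool(img) # (1, 16, 32, 32)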

Todo

Citations

@misc{lowe2021logavgexp,
    title   = {LogAvgExp Provides a Principled and Performant Global Pooling Operator}, 
    author  = {Scott C. Lowe and Thomas Trappenberg and Sageev Oore},
    year    = {2021},
    eprint  = {2111.01742},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}

.\lucidrains\logavgexp-torch\setup.py

# 导入设置和查找包的函数
from setuptools import setup, find_packages

# 设置包的元数据
setup(
  name = 'logavgexp-pytorch', # 包的名称
  packages = find_packages(exclude=[]), # 查找所有包
  version = '0.0.6', # 版本号
  license='MIT', # 许可证
  description = 'LogAvgExp - Pytorch', # 描述
  author = 'Phil Wang', # 作者
  author_email = 'lucidrains@gmail.com', # 作者邮箱
  url = 'https://github.com/lucidrains/logavgexp-pytorch', # 项目链接
  keywords = [ # 关键词列表
    'artificial intelligence',
    'deep learning',
    'pytorch',
    'logsumexp'
  ],
  install_requires=[ # 安装依赖
    'einops>=0.4.1',
    'torch>=1.6',
    'unfoldNd'
  ],
  classifiers=[ # 分类器列表
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

Data source

The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/

.\lucidrains\long-short-transformer\long_short_transformer\autoregressive_wrapper.py

import torch
from torch import nn
import torch.nn.functional as F

# 定义一个装饰器函数,用于在模型评估时切换模型状态
def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner

# 定义一个函数用于对 logits 进行 top-k 过滤
def top_k(logits, thres = 0.9):
    k = int((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs
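
# 简单示例(非原文内容,仅作说明):thres = 0.5 时只保留前 50% 的最大 logits,其余置为 -inf
#   top_k(torch.tensor([[1., 3., 2., 0.]]), thres = 0.5)
#   # -> tensor([[-inf, 3., 2., -inf]])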

# 定义一个包装类,用于自回归模型
class AutoregressiveWrapper(nn.Module):
    def __init__(self, net, ignore_index = -100, pad_value = 0):
        super().__init__()
        self.pad_value = pad_value
        self.ignore_index = ignore_index

        self.net = net
        self.max_seq_len = net.max_seq_len

    # 生成序列的方法,支持自定义起始标记、序列长度、结束标记、温度等参数
    @torch.no_grad()
    @eval_decorator
    def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_logits_fn = top_k, filter_thres = 0.9, **kwargs):
        device = start_tokens.device
        num_dims = len(start_tokens.shape)

        if num_dims == 1:
            start_tokens = start_tokens[None, :]

        b, t = start_tokens.shape

        out = start_tokens
        mask = kwargs.pop('mask', None)

        if mask is None:
            mask = torch.full_like(out, True, dtype=torch.bool, device=out.device)

        for _ in range(seq_len):
            x = out[:, -self.max_seq_len:]
            mask = mask[:, -self.max_seq_len:]

            logits = self.net(x, mask=mask, **kwargs)[:, -1, :]

            filtered_logits = top_k(logits, thres = filter_thres)
            probs = F.softmax(filtered_logits / temperature, dim=-1)

            sample = torch.multinomial(probs, 1)

            out = torch.cat((out, sample), dim=-1)
            mask = F.pad(mask, (0, 1), value=True)

            if eos_token is not None and (sample == eos_token).all():
                break

        out = out[:, t:]

        if num_dims == 1:
            out = out.squeeze(0)

        return out

    # 前向传播方法,计算损失函数
    def forward(self, x, **kwargs):
        xi = x[:, :-1]
        xo = x[:, 1:]

        # 解决自回归模型中输入掩码的常见混淆问题
        mask = kwargs.get('mask', None)
        if mask is not None and mask.shape[1] == x.shape[1]:
            mask = mask[:, :-1]
            kwargs.update(mask = mask)

        out = self.net(xi, **kwargs)
        loss = F.cross_entropy(out.transpose(1, 2), xo, ignore_index = self.ignore_index)
        return loss
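
# 用法示意(非原文内容,模型配置均为假设值):
#   from long_short_transformer import LongShortTransformer
#   net = LongShortTransformer(num_tokens = 256, dim = 512, depth = 2, heads = 8,
#                              causal = True, max_seq_len = 1024, window_size = 128)
#   model = AutoregressiveWrapper(net)
#   loss = model(torch.randint(0, 256, (1, 1024)))                    # 训练:返回交叉熵损失
#   sampled = model.generate(torch.randint(0, 256, (1, 128)), 256)    # 推理:自回归生成 256 个 token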

.\lucidrains\long-short-transformer\long_short_transformer\long_short_transformer.py

# 从 math 模块中导入 gcd(最大公约数)和 ceil(向上取整)函数
from math import gcd, ceil
# 导入 functools 模块
import functools

# 导入 torch 模块
import torch
# 从 torch 模块中导入 nn(神经网络)和 einsum(张量乘法)模块
from torch import nn, einsum
# 从 torch.nn 模块中导入 functional 模块
import torch.nn.functional as F

# 导入 rotary_embedding_torch 模块中的 RotaryEmbedding 和 apply_rotary_emb 函数
from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb

# 导入 einops 模块中的 rearrange 和 repeat 函数
from einops import rearrange, repeat

# 定义函数 exists,判断值是否存在
def exists(val):
    return val is not None

# 定义函数 default,如果值存在则返回该值,否则返回默认值
def default(val, d):
    return val if exists(val) else d

# 定义函数 lcm,计算多个数的最小公倍数
def lcm(*numbers):
    return int(functools.reduce(lambda x, y: int((x * y) / gcd(x, y)), numbers, 1))

# 定义函数 pad_to_multiple,将张量的长度填充到指定的倍数
def pad_to_multiple(tensor, multiple, dim = -1, value = 0):
    seqlen = tensor.shape[dim]
    m = seqlen / multiple

    if m.is_integer():
        return tensor

    remainder = ceil(m) * multiple - seqlen
    pad_offset = (0,) * (-1 - dim) * 2
    return F.pad(tensor, (*pad_offset, 0, remainder), value=value)
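
# 简单示例(非原文内容,仅作说明):把长度 1000 的序列在最后一维填充到 128 的整数倍
#   pad_to_multiple(torch.randn(1, 1000), 128).shape   # -> torch.Size([1, 1024])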

# 定义函数 look_around,根据给定的向前和向后偏移量,在张量周围填充指定值
def look_around(x, backward = 1, forward = 0, pad_value = -1, dim = 2):
    t = x.shape[1]
    dims = (len(x.shape) - dim) * (0, 0)
    padded_x = F.pad(x, (*dims, backward, forward), value= pad_value)
    tensors = [padded_x[:, ind:(ind + t), ...] for ind in range(forward + backward + 1)]
    return torch.cat(tensors, dim=dim)
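
# 简单示例(非原文内容,仅作说明):backward = 1 时,每个窗口会拼接上前一个窗口的内容
#   x = torch.arange(6).reshape(1, 3, 2)   # (batch, 窗口数, 窗口大小)
#   look_around(x, backward = 1, forward = 0, pad_value = -1)
#   # -> tensor([[[-1, -1,  0,  1],
#   #             [ 0,  1,  2,  3],
#   #             [ 2,  3,  4,  5]]])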

# 定义类 PreNorm,实现预层归一化
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)

# 定义类 FeedForward,实现前馈神经网络
class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim)
        )

    def forward(self, x):
        return self.net(x)

# 定义类 LongShortAttention,实现长短注意力机制
class LongShortAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        heads = 8,
        dim_head = 64,
        causal = True,
        window_size = 128,
        pos_emb = None,
        segment_size = 16,
        r = 1,
        dropout = 0.
    ):
        super().__init__()
        assert not (causal and r >= segment_size), 'r should be less than segment size, if autoregressive'

        inner_dim = heads * dim_head
        self.scale = dim_head ** -0.5

        self.heads = heads
        self.causal = causal

        self.window_size = window_size
        self.segment_size = segment_size
        self.pad_to_multiple = window_size if not causal else lcm(window_size, segment_size)

        self.to_dynamic_proj = nn.Linear(dim_head, r, bias = False)
        self.local_norm = nn.LayerNorm(dim_head)
        self.global_norm = nn.LayerNorm(dim_head)

        self.pos_emb = default(pos_emb, RotaryEmbedding(dim_head))

        self.attn_dropout = nn.Dropout(dropout)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

# 定义主类 LongShortTransformer,实现长短变换器
class LongShortTransformer(nn.Module):
    def __init__(
        self,
        *,
        num_tokens,
        dim,
        depth,
        max_seq_len,
        window_size = 128,
        causal = True,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        segment_size = None,
        r = None,
        ff_dropout = 0.,
        attn_dropout = 0.
    ):  
        # 调用父类的构造函数
        super().__init__()
        # 设置最大序列长度
        self.max_seq_len = max_seq_len

        # 创建 token embedding 层
        self.token_emb = nn.Embedding(num_tokens, dim)
        # 创建旋转嵌入层
        pos_emb = RotaryEmbedding(dim_head)

        # 处理自回归默认变量的方式不同
        # 具体来说,segments 仅在自回归情况下使用
        # r 在非自回归情况下是投影的 r << n,在自回归情况下是每个段的投影 r
        # 是的,这很令人困惑,我知道

        # 设置 segment_size 默认值
        segment_size = default(segment_size, 16 if causal else None)
        # 设置 r 默认值
        r = default(r, 1 if causal else 128)

        # 创建多层神经网络
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            # 每层包含一个注意力机制和一个前馈神经网络
            self.layers.append(nn.ModuleList([
                PreNorm(dim, LongShortAttention(dim = dim, heads = heads, dim_head = dim_head, window_size = window_size, causal = causal, pos_emb = pos_emb, segment_size = segment_size, r = r, dropout = attn_dropout)),
                PreNorm(dim, FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout))
            ]))

        # 创建输出层
        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_tokens)
        )

    def forward(self, x, mask = None):
        # 对输入进行 token embedding
        x = self.token_emb(x)

        # 遍历每一层的注意力机制和前馈神经网络
        for attn, ff in self.layers:
            # 注意力机制
            x = attn(x, mask = mask) + x
            # 前馈神经网络
            x = ff(x) + x

        # 输出结果
        return self.to_logits(x)

.\lucidrains\long-short-transformer\long_short_transformer\__init__.py

# 从 long_short_transformer.long_short_transformer 模块中导入 LongShortTransformer 和 LongShortAttention 类
from long_short_transformer.long_short_transformer import LongShortTransformer, LongShortAttention

Long-Short Transformer

Implementation of Long-Short Transformer, combining local and global inductive biases for attention over long sequences, in Pytorch

Install

$ pip install long-short-transformer

Usage

import torch
from long_short_transformer import LongShortTransformer

model = LongShortTransformer(
    num_tokens = 20000,
    dim = 512,
    depth = 6,             # how deep
    heads = 8,             # number of heads
    dim_head = 64,         # dimension per head
    max_seq_len = 1024,    # maximum sequence length
    window_size = 128,     # local attention window size
    r = 256                # like linformer, the sequence length is projected down to this value to avoid the quadratic, where r << n (seq len)
)

x = torch.randint(0, 20000, (1, 1024))
mask = torch.ones(1, 1024).bool()

logits = model(x, mask = mask) # (1, 1024, 20000)

For the autoregressive case, you will have to also supply the segment_size and set causal to True

import torch
from long_short_transformer import LongShortTransformer

model = LongShortTransformer(
    num_tokens = 20000,
    dim = 512,
    depth = 6,             # how deep
    heads = 8,             # number of heads
    dim_head = 64,         # dimension per head
    causal = True,         # autoregressive or not
    max_seq_len = 1024,    # maximum sequence length
    window_size = 128,     # local attention window size
    segment_size = 16,     # sequence is divided into segments of this size, to be projected down to r
    r = 1                  # paper claimed best results with segment to r of 16:1
)

x = torch.randint(0, 20000, (1, 1024))
mask = torch.ones(1, 1024).bool()

logits = model(x, mask = mask) # (1, 1024, 20000)

You can test the autoregressive on enwik8 with

$ python train.py

Citations

@misc{zhu2021longshort,
    title   = {Long-Short Transformer: Efficient Transformers for Language and Vision}, 
    author  = {Chen Zhu and Wei Ping and Chaowei Xiao and Mohammad Shoeybi and Tom Goldstein and Anima Anandkumar and Bryan Catanzaro},
    year    = {2021},
    eprint  = {2107.02192},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}

.\lucidrains\long-short-transformer\setup.py

# 导入设置工具和查找包的函数
from setuptools import setup, find_packages

# 设置包的元数据
setup(
  name = 'long-short-transformer',  # 包的名称
  packages = find_packages(),  # 查找所有包
  version = '0.0.5',  # 版本号
  license='MIT',  # 许可证
  description = 'Long Short Transformer - Pytorch',  # 描述
  author = 'Phil Wang',  # 作者
  author_email = 'lucidrains@gmail.com',  # 作者邮箱
  url = 'https://github.com/lucidrains/long-short-transformer',  # 项目链接
  keywords = [  # 关键词列表
    'artificial intelligence',
    'deep learning',
    'transformers',
    'efficient attention'
  ],
  install_requires=[  # 安装依赖
    'einops>=0.3',
    'rotary-embedding-torch',
    'torch>=1.6'
  ],
  classifiers=[  # 分类器列表
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\long-short-transformer\train.py

# 导入所需的模块和类
from long_short_transformer import LongShortTransformer
from long_short_transformer.autoregressive_wrapper import AutoregressiveWrapper

import random
import tqdm
import gzip
import numpy as np

import torch
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

# 定义常量
NUM_BATCHES = int(1e6)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 3e-4
VALIDATE_EVERY  = 100
GENERATE_EVERY  = 500
GENERATE_LENGTH = 512
SEQ_LEN = 1024

# 定义辅助函数

# 将 token 解码为字符
def decode_token(token):
    return str(chr(max(32, token)))

# 将 tokens 解码为字符串
def decode_tokens(tokens):
    return ''.join(list(map(decode_token, tokens)))

# 无限循环地产出 DataLoader 中的批次数据
def cycle(loader):
    while True:
        for data in loader:
            yield data

# 实例化类 GPT-like decoder model

model = LongShortTransformer(
    num_tokens = 256,
    dim = 512,
    depth = 8,
    heads = 8,
    max_seq_len = SEQ_LEN,
    causal = True,
    window_size = 128
)

model = AutoregressiveWrapper(model)
model.cuda()

# 准备 enwik8 数据

with gzip.open('./data/enwik8.gz') as file:
    data = np.frombuffer(file.read(int(95e6)), dtype = np.uint8).copy()
    data_train, data_val = map(torch.from_numpy, np.split(data, [int(90e6)]))

# 定义 Dataset 类用于采样文本数据
class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

# 创建训练集和验证集的 DataLoader
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset   = TextSamplerDataset(data_val, SEQ_LEN)
train_loader  = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader    = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

# 定义优化器
optim = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)

# 训练模型
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
    model.train()

    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        loss.backward()

    print(f'training loss: {loss.item()}')

    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader))
            print(f'validation loss: {loss.item()}')

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        print(f'{prime} \n\n {"*" * 100}')

        sample = model.generate(inp, GENERATE_LENGTH)
        output_str = decode_tokens(sample)
        print(output_str)

.\lucidrains\lumiere-pytorch\lumiere_pytorch\lumiere.py

"""
einstein notation
b - batch
t - time
c - channels
h - height
w - width
"""

from copy import deepcopy
from functools import wraps

import torch
from torch import nn, einsum, Tensor, is_tensor
from torch.nn import Module, ModuleList
import torch.nn.functional as F

from beartype import beartype
from beartype.typing import List, Tuple, Optional, Type

from einops import rearrange, pack, unpack, repeat

from optree import tree_flatten, tree_unflatten

from x_transformers.x_transformers import (
    Attention,
    RMSNorm
)

# helpers

# 检查变量是否存在
def exists(v):
    return v is not None

# 如果变量存在则返回变量,否则返回默认值
def default(v, d):
    return v if exists(v) else d

# 将单个张量按照指定模式打包
def pack_one(t, pattern):
    return pack([t], pattern)

# 将单个张量按照指定模式解包
def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

# 判断一个数是否可以被另一个数整除
def divisible_by(num, den):
    return (num % den) == 0

# 判断一个数是否为奇数
def is_odd(n):
    return not divisible_by(n, 2)

# 压缩字典中存在值的键值对
def compact_values(d: dict):
    return {k: v for k, v in d.items() if exists(v)}

# extract dimensions using hooks

# 使用钩子函数提取模块的输出形状
@beartype
def extract_output_shapes(
    modules: List[Module],
    model: Module,
    model_input,
    model_kwargs: dict = dict()
):
    shapes = []
    hooks = []

    def hook_fn(_, input, output):
        return shapes.append(output.shape)

    for module in modules:
        hook = module.register_forward_hook(hook_fn)
        hooks.append(hook)

    with torch.no_grad():
        model(model_input, **model_kwargs)

    for hook in hooks:
        hook.remove()

    return shapes

# freezing text-to-image, and only learning temporal parameters

# 冻结所有层,只学习时间参数
@beartype
def set_module_requires_grad_(
    module: Module,
    requires_grad: bool
):
    for param in module.parameters():
        param.requires_grad = requires_grad

def freeze_all_layers_(module):
    set_module_requires_grad_(module, False)

# function that takes in the entire text-to-video network, and sets the time dimension

# 设置时间维度
def set_time_dim_(
    klasses: Tuple[Type[Module]],
    model: Module,
    time_dim: int
):
    for model in model.modules():
        if isinstance(model, klasses):
            model.time_dim = time_dim

# decorator for residual

# 用于添加残差的装饰器
def residualize(fn):
    @wraps(fn)
    def inner(
        self,
        x,
        *args,
        **kwargs
    ):
        residual = x
        out = fn(self, x, *args, **kwargs)
        return out + residual

    return inner

# decorator for converting an input tensor from either image or video format to 1d time

# 将输入张量从图像或视频格式转换为1维时间的装饰器
def image_or_video_to_time(fn):

    @wraps(fn)
    def inner(
        self,
        x,
        batch_size = None,
        **kwargs
    ):

        is_video = x.ndim == 5

        if is_video:
            batch_size = x.shape[0]
            x = rearrange(x, 'b c t h w -> b h w c t')
        else:
            assert exists(batch_size) or exists(self.time_dim)
            rearrange_kwargs = dict(b = batch_size, t = self.time_dim)
            x = rearrange(x, '(b t) c h w -> b h w c t', **compact_values(rearrange_kwargs))

        x, ps = pack_one(x, '* c t')

        x = fn(self, x, **kwargs)

        x = unpack_one(x, ps, '* c t')

        if is_video:
            x = rearrange(x, 'b h w c t -> b c t h w')
        else:
            x = rearrange(x, 'b h w c t -> (b t) c h w')

        return x

    return inner
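
# 形状变化示意(非原文内容,仅作说明):对于视频输入,被装饰的函数只看到一批一维时间信号
#   (b, c, t, h, w) --rearrange--> (b, h, w, c, t) --pack '* c t'--> (b*h*w, c, t)
# 函数返回后再按相同路径 unpack / rearrange 还原成原始布局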

# handle channel last

# 处理通道在最后的情况
def handle_maybe_channel_last(fn):

    @wraps(fn)
    def inner(
        self,
        x,
        *args,
        **kwargs
    ):

        if self.channel_last:
            x = rearrange(x, 'b c ... -> b ... c')

        out = fn(self, x, *args, **kwargs)

        if self.channel_last:
            out = rearrange(out, 'b c ... -> b ... c')

        return out

    return inner

# helpers

# 创建一个序列模块,过滤掉不存在的模块
def Sequential(*modules):
    modules = list(filter(exists, modules))
    return nn.Sequential(*modules)

# 定义一个带有残差连接的模块
class Residual(Module):
    def __init__(self, fn: Module):
        super().__init__()
        self.fn = fn

    def forward(self, t, *args, **kwargs):
        return self.fn(t, *args, **kwargs) + t

# temporal down and upsample
# 初始化一维双线性插值卷积核
def init_bilinear_kernel_1d_(conv: Module):
    # 初始化卷积核权重为零
    nn.init.zeros_(conv.weight)
    # 如果存在偏置项,初始化为零
    if exists(conv.bias):
        nn.init.zeros_(conv.bias)

    # 获取卷积核的通道数
    channels = conv.weight.shape[0]
    # 创建双线性插值核
    bilinear_kernel = Tensor([0.5, 1., 0.5])
    # 创建对角线掩码
    diag_mask = torch.eye(channels).bool()
    # 将双线性插值核应用到卷积核的对角线位置
    conv.weight.data[diag_mask] = bilinear_kernel
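
# 简单示例(非原文内容,假设 dim = 2):初始化后,卷积核只在对角线通道上是双线性核 [0.5, 1.0, 0.5]
#   conv = nn.Conv1d(2, 2, kernel_size = 3, padding = 1)
#   init_bilinear_kernel_1d_(conv)
#   conv.weight.data
#   # tensor([[[0.5, 1.0, 0.5], [0.0, 0.0, 0.0]],
#   #         [[0.0, 0.0, 0.0], [0.5, 1.0, 0.5]]])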

# 时间下采样模块
class TemporalDownsample(Module):
    def __init__(
        self,
        dim,
        channel_last = False,
        time_dim = None
    ):
        super().__init__()
        self.time_dim = time_dim
        self.channel_last = channel_last

        # 创建一维卷积层,用于时间下采样
        self.conv = nn.Conv1d(dim, dim, kernel_size = 3, stride = 2, padding = 1)
        # 初始化卷积核为双线性插值核
        init_bilinear_kernel_1d_(self.conv)

    # 前向传播函数
    @handle_maybe_channel_last
    @image_or_video_to_time
    def forward(
        self,
        x
    ):
        # 断言时间维度大于1,以便进行压缩
        assert x.shape[-1] > 1, 'time dimension must be greater than 1 to be compressed'

        return self.conv(x)

# 时间上采样模块
class TemporalUpsample(Module):
    def __init__(
        self,
        dim,
        channel_last = False,
        time_dim = None
    ):
        super().__init__()
        self.time_dim = time_dim
        self.channel_last = channel_last

        # 创建一维转置卷积层,用于时间上采样
        self.conv = nn.ConvTranspose1d(dim, dim, kernel_size = 3, stride = 2, padding = 1, output_padding = 1)
        # 初始化卷积核为双线性插值核
        init_bilinear_kernel_1d_(self.conv)

    # 前向传播函数
    @handle_maybe_channel_last
    @image_or_video_to_time
    def forward(
        self,
        x
    ):
        return self.conv(x)
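
# 用法示意(非原文内容,形状均为假设值):两个模块只在时间维度上做 2 倍缩放
#   down = TemporalDownsample(dim = 64)
#   up   = TemporalUpsample(dim = 64)
#   video = torch.randn(1, 64, 8, 32, 32)   # (b, c, t, h, w)
#   down(video).shape                       # -> (1, 64, 4, 32, 32)
#   up(down(video)).shape                   # -> (1, 64, 8, 32, 32)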

# 卷积膨胀块
class ConvolutionInflationBlock(Module):
    def __init__(
        self,
        *,
        dim,
        conv2d_kernel_size = 3,
        conv1d_kernel_size = 3,
        groups = 8,
        channel_last = False,
        time_dim = None
    ):
        super().__init__()
        assert is_odd(conv2d_kernel_size)
        assert is_odd(conv1d_kernel_size)

        self.time_dim = time_dim
        self.channel_last = channel_last

        # 空间卷积层
        self.spatial_conv = nn.Sequential(
            nn.Conv2d(dim, dim, conv2d_kernel_size, padding = conv2d_kernel_size // 2),
            nn.GroupNorm(groups, num_channels = dim),
            nn.SiLU()
        )

        # 时间卷积层
        self.temporal_conv = nn.Sequential(
            nn.Conv1d(dim, dim, conv1d_kernel_size, padding = conv1d_kernel_size // 2),
            nn.GroupNorm(groups, num_channels = dim),
            nn.SiLU()
        )

        # 投影输出层
        self.proj_out = nn.Conv1d(dim, dim, 1)

        # 初始化投影输出层的权重和偏置为零
        nn.init.zeros_(self.proj_out.weight)
        nn.init.zeros_(self.proj_out.bias)

    # 前向传播函数
    @residualize
    @handle_maybe_channel_last
    def forward(
        self,
        x,
        batch_size = None
    ):
        is_video = x.ndim == 5

        if is_video:
            batch_size = x.shape[0]
            x = rearrange(x, 'b c t h w -> (b t) c h w')

        x = self.spatial_conv(x)

        rearrange_kwargs = compact_values(dict(b = batch_size, t = self.time_dim))

        assert len(rearrange_kwargs) > 0, 'either batch_size is passed in on forward, or time_dim is set on init'
        x = rearrange(x, '(b t) c h w -> b h w c t', **rearrange_kwargs)

        x, ps = pack_one(x, '* c t')

        x = self.temporal_conv(x)
        x = self.proj_out(x)

        x = unpack_one(x, ps, '* c t')

        if is_video:
            x = rearrange(x, 'b h w c t -> b c t h w')
        else:
            x = rearrange(x, 'b h w c t -> (b t) c h w')

        return x

# 注意力膨胀块
class AttentionInflationBlock(Module):
    def __init__(
        self,
        *,
        dim,
        depth = 1,
        prenorm = True,
        residual_attn = True,
        time_dim = None,
        channel_last = False,
        **attn_kwargs
    ):
        # 调用父类的初始化方法
        super().__init__()

        # 设置时间维度和是否通道在最后的标志
        self.time_dim = time_dim
        self.channel_last = channel_last

        # 初始化时间注意力模块列表
        self.temporal_attns = ModuleList([])

        # 根据深度循环创建注意力模块
        for _ in range(depth):
            # 创建注意力模块序列
            attn = Sequential(
                RMSNorm(dim) if prenorm else None,
                Attention(
                    dim = dim,
                    **attn_kwargs
                )
            )

            # 如果开启残差连接,则将注意力模块包装成残差模块
            if residual_attn:
                attn = Residual(attn)

            # 将创建的注意力模块添加到时间注意力模块列表中
            self.temporal_attns.append(attn)

        # 创建输出投影层
        self.proj_out = nn.Linear(dim, dim)

        # 初始化输出投影层的权重和偏置为零
        nn.init.zeros_(self.proj_out.weight)
        nn.init.zeros_(self.proj_out.bias)

    # 前向传播函数,添加了装饰器
    @residualize
    @handle_maybe_channel_last
    def forward(
        self,
        x,
        batch_size = None
    ):
        # 判断输入是否为视频数据
        is_video = x.ndim == 5
        # 断言判断输入数据维度是否符合要求
        assert is_video ^ (exists(batch_size) or exists(self.time_dim)), 'either a tensor of shape (batch, channels, time, height, width) is passed in, or (batch * time, channels, height, width) along with `batch_size`'

        # 如果通道在最后,则重新排列输入数据
        if self.channel_last:
            x = rearrange(x, 'b ... c -> b c ...')

        # 如果是视频数据,则重新排列输入数据
        if is_video:
            batch_size = x.shape[0]
            x = rearrange(x, 'b c t h w -> b h w t c')
        else:
            assert exists(batch_size) or exists(self.time_dim)

            rearrange_kwargs = dict(b = batch_size, t = self.time_dim)
            x = rearrange(x, '(b t) c h w -> b h w t c', **compact_values(rearrange_kwargs))

        # 打包输入数据
        x, ps = pack_one(x, '* t c')

        # 遍历时间注意力模块列表,对输入数据进行注意力操作
        for attn in self.temporal_attns:
            x = attn(x)

        # 输出投影层处理数据
        x = self.proj_out(x)

        # 解包数据
        x = unpack_one(x, ps, '* t c')

        # 根据是否为视频数据重新排列输出数据
        if is_video:
            x = rearrange(x, 'b h w t c -> b c t h w')
        else:
            x = rearrange(x, 'b h w t c -> (b t) c h w')

        # 如果通道在最后,则重新排列输出数据
        if self.channel_last:
            x = rearrange(x, 'b c ... -> b ... c')

        # 返回处理后的输出数据
        return x
# 定义一个包装器类,用于在模块后添加钩子
class PostModuleHookWrapper(Module):
    def __init__(self, temporal_module: Module):
        super().__init__()
        self.temporal_module = temporal_module

    # 在前向传播过程中,对输出进行处理并返回
    def forward(self, _, input, output):
        output = self.temporal_module(output)
        return output

# 将临时模块插入到模块列表中
def insert_temporal_modules_(modules: List[Module], temporal_modules: ModuleList):
    assert len(modules) == len(temporal_modules)

    # 遍历模块列表和临时模块列表,为每个模块注册一个后向钩子
    for module, temporal_module in zip(modules, temporal_modules):
        module.register_forward_hook(PostModuleHookWrapper(temporal_module))

# 主要的文本到图像模型包装器
class Lumiere(Module):

    # 初始化函数
    @beartype
    def __init__(
        self,
        model: Module,
        *,
        image_size: int,
        unet_time_kwarg: str,
        conv_module_names: List[str],
        attn_module_names: List[str] = [],
        downsample_module_names: List[str] = [],
        upsample_module_names: List[str] = [],
        channels: int = 3,
        conv_inflation_kwargs: dict = dict(),
        attn_inflation_kwargs: dict = dict(),
        downsample_kwargs: dict = dict(),
        upsample_kwargs: dict = dict(),
        conv_klass = ConvolutionInflationBlock,
        attn_klass = AttentionInflationBlock,
        downsample_klass = TemporalDownsample,
        upsample_klass = TemporalUpsample
    @property
    def downsample_factor(self):
        return 2 ** len(self.downsamples)

    # 返回模型的参数
    def parameters(self):
        return [
            *self.convs.parameters(),
            *self.attns.parameters(),
            *self.downsamples.parameters(),
            *self.upsamples.parameters(),
        ]

    # 前向传播函数
    @beartype
    def forward(
        self,
        video: Tensor,
        *args,
        **kwargs
    ) -> Tensor:

        assert video.ndim == 5
        batch, channels, time, height, width = video.shape

        assert channels == self.channels
        assert (height, width) == (self.image_size, self.image_size)

        assert divisible_by(time, self.downsample_factor)

        # 将视频转换为一堆图像
        images = rearrange(video, 'b c t h w -> (b t) c h w')

        # 为所有时间层设置正确的时间维度
        set_time_dim_(self.temporal_klasses, self, time)

        # 将所有图像传入文本到图像模型
        images = self.model(images, *args, **kwargs)

        # 将结果重塑回去成去噪视频
        return rearrange(images, '(b t) c h w -> b c t h w', b = batch)

.\lucidrains\lumiere-pytorch\lumiere_pytorch\mp_lumiere.py

# 导入所需的库
from math import sqrt
from functools import partial
import torch
from torch import nn, einsum, Tensor
from torch.nn import Module, ModuleList
import torch.nn.functional as F
# 导入 beartype 库,用于类型注解
from beartype import beartype
from beartype.typing import List, Tuple, Optional
# 导入 einops 库,用于操作张量
from einops import rearrange, pack, unpack, repeat
# 导入 lumiere 库中的函数
from lumiere_pytorch.lumiere import (
    image_or_video_to_time,
    handle_maybe_channel_last,
    Lumiere
)

# 定义一些辅助函数

# 判断变量是否存在
def exists(v):
    return v is not None

# 如果变量存在则返回变量,否则返回默认值
def default(v, d):
    return v if exists(v) else d

# 将张量打包成指定模式的形状
def pack_one(t, pattern):
    return pack([t], pattern)

# 将打包后的张量解包成指定模式的形状
def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

# 压缩字典中值不存在的键值对
def compact_values(d: dict):
    return {k: v for k, v in d.items() if exists(v)}

# 计算 L2 范数
def l2norm(t, dim = -1, eps = 1e-12):
    return F.normalize(t, dim = dim, eps = eps)

# 对权重进行归一化处理
def normalize_weight(weight, eps = 1e-4):
    weight, ps = pack_one(weight, 'o *')
    normed_weight = l2norm(weight, eps = eps)
    normed_weight = normed_weight * sqrt(weight.numel() / weight.shape[0])
    return unpack_one(normed_weight, ps, 'o *')

# 在一维上进行插值
def interpolate_1d(x, length, mode = 'bilinear'):
    x = rearrange(x, 'b c t -> b c t 1')
    x = F.interpolate(x, (length, 1), mode = mode)
    return rearrange(x, 'b c t 1 -> b c t')

# MP 激活函数
class MPSiLU(Module):
    def forward(self, x):
        return F.silu(x) / 0.596

# 增益 - 层缩放
class Gain(Module):
    def __init__(self):
        super().__init__()
        self.gain = nn.Parameter(torch.tensor(0.))

    def forward(self, x):
        return x * self.gain

# MP 线性层
class Linear(Module):
    def __init__(self, dim_in, dim_out, eps = 1e-4):
        super().__init__()
        weight = torch.randn(dim_out, dim_in)
        self.weight = nn.Parameter(weight)
        self.eps = eps
        self.fan_in = dim_in

    def forward(self, x):
        if self.training:
            with torch.no_grad():
                normed_weight = normalize_weight(self.weight, eps = self.eps)
                self.weight.copy_(normed_weight)

        weight = normalize_weight(self.weight, eps = self.eps) / sqrt(self.fan_in)
        return F.linear(x, weight)

# 强制权重归一化的卷积层和线性层
class Conv2d(Module):
    def __init__(
        self,
        dim_in,
        dim_out,
        kernel_size,
        eps = 1e-4
    ):
        super().__init__()
        weight = torch.randn(dim_out, dim_in, kernel_size, kernel_size)
        self.weight = nn.Parameter(weight)

        self.eps = eps
        self.fan_in = dim_in * kernel_size ** 2

    def forward(self, x):
        if self.training:
            with torch.no_grad():
                normed_weight = normalize_weight(self.weight, eps = self.eps)
                self.weight.copy_(normed_weight)

        weight = normalize_weight(self.weight, eps = self.eps) / sqrt(self.fan_in)
        return F.conv2d(x, weight, padding = 'same')

class Conv1d(Module):
    def __init__(
        self,
        dim_in,
        dim_out,
        kernel_size,
        eps = 1e-4,
        init_dirac = False
    ):
        super().__init__()
        weight = torch.randn(dim_out, dim_in, kernel_size)
        self.weight = nn.Parameter(weight)

        if init_dirac:
            nn.init.dirac_(self.weight)

        self.eps = eps
        self.fan_in = dim_in * kernel_size

    def forward(self, x):
        if self.training:
            with torch.no_grad():
                normed_weight = normalize_weight(self.weight, eps = self.eps)
                self.weight.copy_(normed_weight)

        weight = normalize_weight(self.weight, eps = self.eps) / sqrt(self.fan_in)
        return F.conv1d(x, weight, padding = 'same')

# 像素归一化层
class PixelNorm(Module):
    # 初始化函数,设置维度和epsilon值
    def __init__(self, dim, eps = 1e-4):
        # 调用父类的初始化函数
        super().__init__()
        # 设置像素规范化的高epsilon值
        self.dim = dim
        self.eps = eps

    # 前向传播函数
    def forward(self, x):
        # 获取维度
        dim = self.dim
        # 返回经过L2范数规范化后的结果乘以维度的平方根
        return l2norm(x, dim = dim, eps = self.eps) * sqrt(x.shape[dim])
# 定义一个类,实现magnitude preserving sum的功能
# t的值根据经验设定为0.3,用于encoder/decoder/attention residuals和embedding
class MPAdd(Module):
    def __init__(self, t):
        super().__init__()
        self.t = t

    # 实现前向传播功能
    def forward(self, x, res):
        a, b, t = x, res, self.t
        num = a * (1. - t) + b * t
        den = sqrt((1 - t) ** 2 + t ** 2)
        return num / den
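
# 数值直觉(非原文内容,假设 a、b 是相互独立的单位方差张量):
#   mp_add = MPAdd(t = 0.3)
#   a, b = torch.randn(100000), torch.randn(100000)
#   mp_add(a, b).std()   # 约 1.0,即相加后的幅度(方差)被保持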

# 定义一个类,实现mp attention的功能
class MPAttention(Module):
    def __init__(
        self,
        dim,
        heads = 4,
        dim_head = 64,
        num_mem_kv = 4,
        mp_add_t = 0.3,
        dropout = 0.
    ):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads

        self.scale = dim_head ** -0.5
        self.pixel_norm = PixelNorm(dim = -1)

        self.dropout = nn.Dropout(dropout)

        self.mem_kv = nn.Parameter(torch.randn(2, heads, num_mem_kv, dim_head))
        self.to_qkv = Linear(dim, hidden_dim * 3)
        self.to_out = Linear(hidden_dim, dim)

        self.mp_add = MPAdd(t = mp_add_t)

    # 实现前向传播功能
    def forward(self, x):
        res, b = x, x.shape[0]

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        mk, mv = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), self.mem_kv)
        k, v = map(partial(torch.cat, dim = -2), ((mk, k), (mv, v)))

        q, k, v = map(self.pixel_norm, (q, k, v))

        q = q * self.scale

        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)

        return self.mp_add(out, res)

# 定义一个类,实现时间维度的下采样
class MPTemporalDownsample(Module):
    def __init__(
        self,
        dim,
        channel_last = False,
        time_dim = None
    ):
        super().__init__()
        self.time_dim = time_dim
        self.channel_last = channel_last
        self.conv = Conv1d(dim, dim, 3, init_dirac = True)

    # 实现前向传播功能
    @handle_maybe_channel_last
    @image_or_video_to_time
    def forward(
        self,
        x
    ):
        t = x.shape[-1]
        assert t > 1, 'time dimension must be greater than 1 to be compressed'

        x = interpolate_1d(x, t // 2)
        return self.conv(x)

# 定义一个类,实现时间维度的上采样
class MPTemporalUpsample(Module):
    def __init__(
        self,
        dim,
        channel_last = False,
        time_dim = None
    ):
        super().__init__()
        self.time_dim = time_dim
        self.channel_last = channel_last
        self.conv = Conv1d(dim, dim, 3, init_dirac = True)

    # 实现前向传播功能
    @handle_maybe_channel_last
    @image_or_video_to_time
    def forward(
        self,
        x
    ):
        t = x.shape[-1]
        x = interpolate_1d(x, t * 2)
        return self.conv(x)

# 定义一个类,实现MP卷积膨胀块的功能
class MPConvolutionInflationBlock(Module):
    def __init__(
        self,
        *,
        dim,
        conv2d_kernel_size = 3,
        conv1d_kernel_size = 3,
        channel_last = False,
        time_dim = None,
        mp_add_t = 0.3,
        dropout = 0.
    ):
        super().__init__()
        self.time_dim = time_dim
        self.channel_last = channel_last

        self.spatial_conv = nn.Sequential(
            Conv2d(dim, dim, conv2d_kernel_size),
            MPSiLU()
        )

        self.temporal_conv = nn.Sequential(
            Conv1d(dim, dim, conv1d_kernel_size),
            MPSiLU(),
            nn.Dropout(dropout)
        )

        self.proj_out = nn.Sequential(
            Conv1d(dim, dim, 1),
            Gain()
        )

        self.residual_mp_add = MPAdd(t = mp_add_t)

    # 实现前向传播功能
    @handle_maybe_channel_last
    def forward(
        self,
        x,
        batch_size = None
        ):
        # 将输入赋值给残差变量
        residual = x

        # 判断输入是否为视频,判断输入的维度是否为5
        is_video = x.ndim == 5

        # 如果是视频
        if is_video:
            # 获取批量大小
            batch_size = x.shape[0]
            # 重新排列输入数据的维度
            x = rearrange(x, 'b c t h w -> (b t) c h w')

        # 对输入进行空间卷积
        x = self.spatial_conv(x)

        # 重新排列参数
        rearrange_kwargs = compact_values(dict(b = batch_size, t = self.time_dim))

        # 断言重新排列参数的长度大于0
        assert len(rearrange_kwargs) > 0, 'either batch_size is passed in on forward, or time_dim is set on init'
        # 重新排列输入数据的维度
        x = rearrange(x, '(b t) c h w -> b h w c t', **rearrange_kwargs)

        # 打包输入数据
        x, ps = pack_one(x, '* c t')

        # 对输入进行时间卷积
        x = self.temporal_conv(x)
        # 对输入进行投影输出
        x = self.proj_out(x)

        # 解包输入数据
        x = unpack_one(x, ps, '* c t')

        # 如果是视频
        if is_video:
            # 重新排列输入数据的维度
            x = rearrange(x, 'b h w c t -> b c t h w')
        else:
            # 重新排列输入数据的维度
            x = rearrange(x, 'b h w c t -> (b t) c h w')

        # 返回残差模块添加后的结果
        return self.residual_mp_add(x, residual)
# 定义一个多头注意力膨胀块类,继承自 Module 类
class MPAttentionInflationBlock(Module):
    # 初始化函数
    def __init__(
        self,
        *,
        dim,  # 维度
        depth = 1,  # 层数,默认为1
        time_dim = None,  # 时间维度,默认为None
        channel_last = False,  # 是否通道在最后,默认为False
        mp_add_t = 0.3,  # MP 添加时间,默认为0.3
        dropout = 0.,  # 丢弃率,默认为0
        **attn_kwargs  # 其他注意力参数
    ):
        super().__init__()

        self.time_dim = time_dim  # 初始化时间维度
        self.channel_last = channel_last  # 初始化通道在最后

        self.temporal_attns = ModuleList([])  # 初始化时间注意力模块列表

        # 循环创建指定层数的多头注意力模块
        for _ in range(depth):
            attn = MPAttention(
                dim = dim,
                dropout = dropout,
                **attn_kwargs
            )

            self.temporal_attns.append(attn)  # 将创建的多头注意力模块添加到列表中

        # 定义输出投影层
        self.proj_out = nn.Sequential(
            Linear(dim, dim),  # 线性层
            Gain()  # 增益层
        )

        # 定义残差 MP 添加层
        self.residual_mp_add = MPAdd(t = mp_add_t)

    # 前向传播函数
    @handle_maybe_channel_last
    def forward(
        self,
        x,  # 输入张量
        batch_size = None  # 批量大小,默认为None
    ):
        is_video = x.ndim == 5  # 判断是否为视频数据
        assert is_video ^ (exists(batch_size) or exists(self.time_dim)), 'either a tensor of shape (batch, channels, time, height, width) is passed in, or (batch * time, channels, height, width) along with `batch_size`'

        if self.channel_last:
            x = rearrange(x, 'b ... c -> b c ...')  # 重新排列张量维度

        if is_video:
            batch_size = x.shape[0]  # 获取批量大小
            x = rearrange(x, 'b c t h w -> b h w t c')  # 重新排列张量维度
        else:
            assert exists(batch_size) or exists(self.time_dim)  # 断言批量大小或时间维度存在

            rearrange_kwargs = dict(b = batch_size, t = self.time_dim)
            x = rearrange(x, '(b t) c h w -> b h w t c', **compact_values(rearrange_kwargs))  # 重新排列张量维度

        x, ps = pack_one(x, '* t c')  # 打包张量

        residual = x  # 保存残差

        # 遍历时间注意力模块列表
        for attn in self.temporal_attns:
            x = attn(x)  # 多头注意力操作

        x = self.proj_out(x)  # 投影输出

        x = self.residual_mp_add(x, residual)  # 残差 MP 添加

        x = unpack_one(x, ps, '* t c')  # 解包张量

        if is_video:
            x = rearrange(x, 'b h w t c -> b c t h w')  # 重新排列张量维度
        else:
            x = rearrange(x, 'b h w t c -> (b t) c h w')  # 重新排列张量维度

        if self.channel_last:
            x = rearrange(x, 'b c ... -> b ... c')  # 重新排列张量维度

        return x  # 返回结果张量

# MPLumiere 是 Lumiere 的一个部分,包含四个 MP 时间模块
MPLumiere = partial(
    Lumiere,
    conv_klass = MPConvolutionInflationBlock,  # 卷积类
    attn_klass = MPAttentionInflationBlock,  # 注意力类
    downsample_klass = MPTemporalDownsample,  # 下采样类
    upsample_klass = MPTemporalUpsample  # 上采样类
)

.\lucidrains\lumiere-pytorch\lumiere_pytorch\__init__.py

# 从lumiere_pytorch.lumiere模块中导入ConvolutionInflationBlock、AttentionInflationBlock、TemporalDownsample、TemporalUpsample、set_time_dim_函数
from lumiere_pytorch.lumiere import (
    ConvolutionInflationBlock,
    AttentionInflationBlock,
    TemporalDownsample,
    TemporalUpsample,
    set_time_dim_
)

# 从lumiere_pytorch.lumiere模块中导入Lumiere类
from lumiere_pytorch.lumiere import Lumiere

# 从lumiere_pytorch.mp_lumiere模块中导入MPLumiere、MPConvolutionInflationBlock、MPAttentionInflationBlock、MPTemporalDownsample、MPTemporalUpsample类
from lumiere_pytorch.mp_lumiere import (
    MPLumiere,
    MPConvolutionInflationBlock,
    MPAttentionInflationBlock,
    MPTemporalDownsample,
    MPTemporalUpsample,
)

Lumiere - Pytorch

Implementation of Lumiere, SOTA text-to-video generation from Google Deepmind, in Pytorch

Yannic's paper review

Since this paper is mostly just a few key ideas on top of a text-to-image model, this repository will take it a step further and extend the new Karras U-net to video.

Appreciation

Install

$ pip install lumiere-pytorch

Usage

import torch
from lumiere_pytorch import MPLumiere

from denoising_diffusion_pytorch import KarrasUnet

karras_unet = KarrasUnet(
    image_size = 256,
    dim = 8,
    channels = 3,
    dim_max = 768
)

lumiere = MPLumiere(
    karras_unet,
    image_size = 256,
    unet_time_kwarg = 'time',
    conv_module_names = [
        'downs.1',
        'ups.1'
    ],
    attn_module_names = [
        'mids.0'
    ],
    upsample_module_names = [
        'ups.1'
    ],
    downsample_module_names = [
        'downs.1'
    ]
)

noised_video = torch.randn(2, 3, 8, 256, 256)
time = torch.ones(2,)

denoised_video = lumiere(noised_video, time = time)

assert noised_video.shape == denoised_video.shape
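
Since `Lumiere.parameters()` (defined in `lumiere.py` above) only yields the parameters of the inserted temporal modules, an optimizer built from it trains just the video layers while leaving the wrapped text-to-image weights alone. A minimal sketch, assuming the `lumiere` instance from the example above:

optimizer = torch.optim.Adam(lumiere.parameters(), lr = 1e-4)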

Todo

Citations

@inproceedings{BarTal2024LumiereAS,
    title   = {Lumiere: A Space-Time Diffusion Model for Video Generation},
    author  = {Omer Bar-Tal and Hila Chefer and Omer Tov and Charles Herrmann and Roni Paiss and Shiran Zada and Ariel Ephrat and Junhwa Hur and Yuanzhen Li and Tomer Michaeli and Oliver Wang and Deqing Sun and Tali Dekel and Inbar Mosseri},
    year    = {2024},
    url     = {https://api.semanticscholar.org/CorpusID:267095113}
}
@article{Karras2023AnalyzingAI,
    title   = {Analyzing and Improving the Training Dynamics of Diffusion Models},
    author  = {Tero Karras and Miika Aittala and Jaakko Lehtinen and Janne Hellsten and Timo Aila and Samuli Laine},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2312.02696},
    url     = {https://api.semanticscholar.org/CorpusID:265659032}
}

.\lucidrains\lumiere-pytorch\setup.py

# 导入设置安装和查找包的函数
from setuptools import setup, find_packages

# 设置包的信息
setup(
  # 包名
  name = 'lumiere-pytorch',
  # 查找所有包,不排除任何包
  packages = find_packages(exclude=[]),
  # 版本号
  version = '0.0.20',
  # 许可证
  license='MIT',
  # 描述
  description = 'Lumiere',
  # 作者
  author = 'Phil Wang',
  # 作者邮箱
  author_email = 'lucidrains@gmail.com',
  # 长描述内容类型
  long_description_content_type = 'text/markdown',
  # 项目链接
  url = 'https://github.com/lucidrains/lumiere-pytorch',
  # 关键词
  keywords = [
    'artificial intelligence',
    'deep learning',
    'text-to-video'
  ],
  # 安装依赖
  install_requires=[
    'beartype',
    'einops>=0.7.0',
    'optree',
    'torch>=2.0',
    'x-transformers'
  ],
  # 分类
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\magvit2-pytorch\magvit2_pytorch\attend.py

# 导入所需模块和库
from functools import partial
from typing import Optional, Tuple

import torch
from torch import nn, einsum, Tensor
import torch.nn.functional as F

from collections import namedtuple
from functools import wraps
from packaging import version
from dataclasses import dataclass

from einops import rearrange, repeat

# 定义一个命名元组EfficientAttentionConfig,包含三个布尔类型的参数
EfficientAttentionConfig = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# 辅助函数

# 判断变量是否存在
def exists(val):
    return val is not None

# 如果变量存在则返回其值,否则返回默认值
def default(val, d):
    return val if exists(val) else d

# 过滤掉列表中的空值
def compact(arr):
    return [*filter(exists, arr)]

# 保证函数只执行一次
def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

# 打印函数的输出,确保只打印一次
print_once = once(print)

# 用于创建因果掩码的函数
# 针对onnx cpu需要特殊处理(不支持.triu)

# 创建因果掩码
def create_causal_mask(i, j, device):
    return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)

# 针对onnx创建因果掩码
def onnx_create_causal_mask(i, j, device):
    r = torch.arange(i, device = device)
    causal_mask = rearrange(r, 'i -> i 1') < rearrange(r, 'j -> 1 j')
    causal_mask = F.pad(causal_mask, (j - i, 0), value = False)
    return causal_mask
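
# 简单示例(非原文内容,仅作说明):两种实现都把"未来"的位置标记为 True(需要被屏蔽)
#   create_causal_mask(2, 3, device = 'cpu')
#   # tensor([[False, False,  True],
#   #         [False, False, False]])
#   onnx_create_causal_mask(2, 3, device = 'cpu')   # 结果与上面一致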

# 主类

class Attend(nn.Module):
    def __init__(
        self,
        *,
        dropout = 0.,
        causal = False,
        heads = None,
        scale = None,
        flash = False,
        onnxable = False,
        sdp_kwargs: dict = dict(
            enable_flash = True,
            enable_math = True,
            enable_mem_efficient = True
        )
    ):
        super().__init__()
        self.scale = scale

        self.causal = causal
        self.create_causal_mask = onnx_create_causal_mask if onnxable else create_causal_mask

        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        # flash attention

        # 检查是否支持flash attention
        self.flash = flash and torch.cuda.is_available()
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        self.sdp_kwargs = sdp_kwargs

    def flash_attn(
        self,
        q, k, v,
        mask = None,
        attn_bias = None
    ):
        # 解包输入张量的形状和其他属性
        batch, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device

        # 使输入张量连续
        q, k, v = map(lambda t: t.contiguous(), (q, k, v))

        # 处理缩放,因为在 sdp 中缩放不可定制,对其进行处理
        if exists(self.scale):
            q = q * self.scale / (q.shape[-1] ** -0.5)

        # 检查是否存在 mask 并扩展到兼容的形状
        causal = self.causal

        # 如果 q_len == 1 且 causal 为真,则将 causal 设置为 False
        if q_len == 1 and causal:
            causal = False

        # 扩展键填充 mask
        if exists(mask):
            assert mask.ndim == 4
            mask = mask.expand(batch, heads, q_len, k_len)

        # 处理 kv 缓存
        if k_len > q_len and causal:
            causal_mask = self.create_causal_mask(q_len, k_len, device=device)
            if not exists(mask):
                mask = ~causal_mask
            else:
                mask = mask & ~causal_mask
            causal = False

        # 手动处理 causal mask,如果给定了另一个 mask
        row_is_entirely_masked = None
        if exists(mask) and causal:
            causal_mask = self.create_causal_mask(q_len, k_len, device=device)
            mask = mask & ~causal_mask

            # 防止整行被屏蔽
            row_is_entirely_masked = ~mask.any(dim=-1)
            mask[..., 0] = mask[..., 0] | row_is_entirely_masked
            causal = False

        # 处理 alibi 位置偏差,将 bool 转换为 float
        if exists(attn_bias):
            attn_bias = rearrange(attn_bias, 'h i j -> 1 h i j').expand(batch, heads, -1, -1)

            mask_value = -torch.finfo(q.dtype).max

            if exists(mask):
                attn_bias = attn_bias.masked_fill(~mask, mask_value // 2)
            elif causal:
                causal_mask = self.create_causal_mask(q_len, k_len, device=device)
                attn_bias = attn_bias.masked_fill(causal_mask, mask_value // 2)
                causal = False

            mask = attn_bias

        # 使用 scaled_dot_product_attention 处理注意力
        with torch.backends.cuda.sdp_kernel(**self.sdp_kwargs):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=mask,
                dropout_p=self.dropout if self.training else 0.,
                is_causal=causal
            )

        # 对于整行被完全屏蔽的情况,将输出的该行标记为 0
        if exists(row_is_entirely_masked):
            out = out.masked_fill(row_is_entirely_masked[..., None], 0.)

        return out

    # 前向传播函数
    def forward(
        self,
        q, k, v,
        mask=None,
        attn_bias=None,
        prev_attn=None
        ):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        # 获取输入张量的形状信息
        n, heads, kv_heads, device = q.shape[-2], q.shape[1], k.shape[1], q.device

        # 计算缩放因子
        scale = default(self.scale, q.shape[-1] ** -0.5)

        # 获取是否为因果注意力的标志
        causal = self.causal

        # 处理缓存的键值对解码
        if n == 1 and causal:
            causal = False

        # 处理零键值对,允许网络关注空内容
        if self.flash:
            assert not exists(prev_attn), 'residual attention not compatible with flash attention'
            return self.flash_attn(q, k, v, mask = mask, attn_bias = attn_bias)

        # 计算点积注意力得分
        dots = einsum(f'b h i d, b h j d -> b h i j', q, k) * scale

        # 如果存在先前的注意力,加上先前的注意力得分
        if exists(prev_attn):
            dots = dots + prev_attn

        # 如果存在注意力偏置,加上注意力偏置
        if exists(attn_bias):
            dots = dots + attn_bias

        # 获取点积张量的形状信息和数据类型
        i, j, dtype = *dots.shape[-2:], dots.dtype

        # 定义掩码值
        mask_value = -torch.finfo(dots.dtype).max

        # 如果存在掩码,用掩码值填充不需要关注的位置
        if exists(mask):
            dots = dots.masked_fill(~mask, mask_value)

        # 如果是因果注意力,创建因果掩码并用掩码值填充
        if causal:
            causal_mask = self.create_causal_mask(i, j, device = device)
            dots = dots.masked_fill(causal_mask, mask_value)

        # 计算注意力权重
        attn = dots.softmax(dim = -1)

        # 对注意力权重进行dropout
        attn = self.attn_dropout(attn)

        # 计算输出
        out = einsum(f'b h i j, b h j d -> b h i d', attn, v)

        return out
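
下面补充一个最小示例(非原始源码,仅为示意):在不使用 mask、dropout 与 causal 的情况下,手动 einsum 注意力与 PyTorch 内置的 scaled_dot_product_attention 应给出一致的结果,分别对应上面 forward 中的非 flash 与 flash 两条路径。

# 行为示意(假设环境为 PyTorch 2.0 及以上)
import torch
import torch.nn.functional as F

# 形状约定与上文一致: (batch, heads, seq, dim_head)
q = torch.randn(2, 8, 16, 64)
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)

# 非 flash 路径: 手动计算 softmax(Q K^T / sqrt(d)) V
scale = q.shape[-1] ** -0.5
dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) * scale
out_manual = torch.einsum('b h i j, b h j d -> b h i d', dots.softmax(dim = -1), v)

# flash 路径: 交给 PyTorch 内置的 scaled_dot_product_attention
out_sdpa = F.scaled_dot_product_attention(q, k, v)

assert torch.allclose(out_manual, out_sdpa, atol = 1e-4)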

.\lucidrains\magvit2-pytorch\magvit2_pytorch\data.py

# 导入必要的库
from pathlib import Path
from functools import partial

import torch
from torch import Tensor
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader as PytorchDataLoader

import cv2
from PIL import Image
from torchvision import transforms as T, utils

from beartype import beartype
from beartype.typing import Tuple, List
from beartype.door import is_bearable

import numpy as np

from einops import rearrange

# 辅助函数

# 检查值是否存在
def exists(val):
    return val is not None

# 返回输入值
def identity(t, *args, **kwargs):
    return t

# 将输入值转换为元组
def pair(val):
    return val if isinstance(val, tuple) else (val, val)

# 在指定维度上填充张量
def pad_at_dim(t, pad, dim = -1, value = 0.):
    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = ((0, 0) * dims_from_right)
    return F.pad(t, (*zeros, *pad), value = value)

# 调整张量的帧数
def cast_num_frames(t, *, frames):
    f = t.shape[-3]

    if f == frames:
        return t

    if f > frames:
        return t[..., :frames, :, :]

    return pad_at_dim(t, (0, frames - f), dim = -3)
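
下面用一个小例子说明 cast_num_frames 的行为(仅为示意):帧数少于目标时在帧维度(dim = -3)末尾补零,多于目标时直接截断。这里直接用 F.pad 与切片演示等价效果,便于独立运行。

# 行为示意(非原始源码)
import torch
import torch.nn.functional as F

video = torch.randn(3, 5, 32, 32)            # (channels, frames, height, width)

# 等价于 cast_num_frames(video, frames = 8): 在帧维度末尾补 3 帧零
padded = F.pad(video, (0, 0, 0, 0, 0, 3))
assert padded.shape == (3, 8, 32, 32)

# 等价于 cast_num_frames(video, frames = 4): 截断到前 4 帧
truncated = video[..., :4, :, :]
assert truncated.shape == (3, 4, 32, 32)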

# 将图像转换为指定格式
def convert_image_to_fn(img_type, image):
    if not exists(img_type) or image.mode == img_type:
        return image

    return image.convert(img_type)

# 如果路径没有后缀,则添加后缀
def append_if_no_suffix(path: str, suffix: str):
    path = Path(path)

    if path.suffix == '':
        path = path.parent / (path.name + suffix)

    assert path.suffix == suffix, f'{str(path)} needs to have suffix {suffix}'

    return str(path)

# 通道到图像模式的映射
CHANNEL_TO_MODE = {
    1: 'L',
    3: 'RGB',
    4: 'RGBA'
}

# 图像相关的辅助函数和数据集

# 图像数据集类
class ImageDataset(Dataset):
    def __init__(
        self,
        folder,
        image_size,
        channels = 3,
        convert_image_to = None,
        exts = ['jpg', 'jpeg', 'png']
    ):
        super().__init__()
        folder = Path(folder)
        assert folder.is_dir(), f'{str(folder)} must be a folder containing images'
        self.folder = folder

        self.image_size = image_size

        exts = exts + [ext.upper() for ext in exts]
        self.paths = [p for ext in exts for p in folder.glob(f'**/*.{ext}')]

        print(f'{len(self.paths)} training samples found at {folder}')

        if exists(channels) and not exists(convert_image_to):
            convert_image_to = CHANNEL_TO_MODE.get(channels)

        self.transform = T.Compose([
            T.Lambda(partial(convert_image_to_fn, convert_image_to)),
            T.Resize(image_size, antialias = True),
            T.RandomHorizontalFlip(),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(path)
        return self.transform(img)
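
ImageDataset 的最小用法示意(目录 ./images 与各参数均为假设,非原始源码):

# 用法示意
from torch.utils.data import DataLoader

dataset = ImageDataset('./images', image_size = 128, channels = 3)
loader = DataLoader(dataset, batch_size = 4, shuffle = True)

images = next(iter(loader))    # 形状为 (4, 3, 128, 128)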

# 张量的形状 (channels, frames, height, width) -> gif

# 处理读取和写入 gif

# 逐帧读取图像
def seek_all_images(img: Tensor, channels = 3):
    mode = CHANNEL_TO_MODE.get(channels)

    assert exists(mode), f'channels {channels} invalid'

    i = 0
    while True:
        try:
            img.seek(i)
            yield img.convert(mode)
        except EOFError:
            break
        i += 1

# 张量的形状 (channels, frames, height, width) -> gif

# 将视频张量转换为 gif
@beartype
def video_tensor_to_gif(
    tensor: Tensor,
    path: str,
    duration = 120,
    loop = 0,
    optimize = True
):
    path = append_if_no_suffix(path, '.gif')
    images = map(T.ToPILImage(), tensor.unbind(dim = 1))
    first_img, *rest_imgs = images
    first_img.save(str(path), save_all = True, append_images = rest_imgs, duration = duration, loop = loop, optimize = optimize)
    return images

# gif -> 张量 (channels, frame, height, width)

# 将 gif 转换为张量
def gif_to_tensor(
    path: str,
    channels = 3,
    transform = T.ToTensor()
):
    img = Image.open(path)
    tensors = tuple(map(transform, seek_all_images(img, channels = channels)))
    return torch.stack(tensors, dim = 1)
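
gif 读写的往返示意(文件路径仅为假设,依赖上文定义的 video_tensor_to_gif 与 gif_to_tensor):

# 用法示意
import torch

video = torch.rand(3, 10, 64, 64)             # (channels, frames, height, width),值域 [0, 1]

video_tensor_to_gif(video, './sample.gif')    # 保存为 gif
recovered = gif_to_tensor('./sample.gif')     # 读回张量,形状为 (3, 10, 64, 64)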

# 处理读取和写入 mp4

# 将视频转换为张量
def video_to_tensor(
    path: str,              # 视频文件的路径
    num_frames = -1,        # 要存储在输出张量中的帧数,默认为 -1 表示存储所有帧
    crop_size = None        # 裁剪尺寸,默认为 None 表示不进行裁剪
) -> Tensor:                # 返回形状为 (通道数, 帧数, 高度, 宽度) 的张量

    # 使用 OpenCV 打开视频文件
    video = cv2.VideoCapture(path)

    frames = []  # 存储视频帧的列表
    check = True

    # 循环读取视频帧
    while check:
        check, frame = video.read()

        if not check:
            continue

        # 如果存在裁剪尺寸,则对帧进行中心裁剪
        if exists(crop_size):
            frame = crop_center(frame, *pair(crop_size))

        # 将帧重新排列为 (1, ...) 的形状并添加到 frames 列表中
        frames.append(rearrange(frame, '... -> 1 ...'))

    # 将帧列表转换为 numpy 数组,然后合并帧并转换为 numpy 数组
    frames = np.array(np.concatenate(frames[:-1], axis=0))
    frames = rearrange(frames, 'f h w c -> c f h w')

    # 将 numpy 数组转换为 PyTorch 张量并转换为浮点数类型
    frames_torch = torch.tensor(frames).float()

    # 将张量值归一化到 [0, 1] 范围
    frames_torch /= 255.
    # 将张量沿着第一个维度翻转,从 BGR 格式转换为 RGB 格式
    frames_torch = frames_torch.flip(dims=(0,))

    # 返回指定数量的帧数
    return frames_torch[:, :num_frames, :, :]

# 定义一个函数,将张量转换为视频文件
@beartype
def tensor_to_video(
    tensor: Tensor,        # PyTorch 视频张量
    path: str,             # 要保存的视频路径
    fps=25,                # 保存视频的帧率
    video_format='MP4V'    # 视频格式,默认为 MP4
):
    # 如果路径没有后缀,则添加 .mp4 后缀
    path = append_if_no_suffix(path, '.mp4')

    # 将张量移动到 CPU
    tensor = tensor.cpu()

    # 获取张量的帧数、高度和宽度
    num_frames, height, width = tensor.shape[-3:]

    # 使用指定的视频格式创建 VideoWriter 对象
    fourcc = cv2.VideoWriter_fourcc(*video_format)
    video = cv2.VideoWriter(str(path), fourcc, fps, (width, height))

    frames = []  # 存储视频帧的列表

    # 遍历每一帧,将张量转换为 numpy 数组并写入视频
    for idx in range(num_frames):
        numpy_frame = tensor[:, idx, :, :].numpy()
        numpy_frame = np.uint8(rearrange(numpy_frame, 'c h w -> h w c'))
        video.write(numpy_frame)

    # 释放 VideoWriter 对象
    video.release()

    # 关闭所有 OpenCV 窗口
    cv2.destroyAllWindows()

    return video
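
mp4 读写的示意用法(假设本地存在 ./input.mp4 且已安装 opencv-python,路径与帧数仅为示例):

# 用法示意
# 读取视频为张量,形状 (通道, 帧, 高, 宽),像素值归一化到 [0, 1],通道为 RGB 顺序
frames = video_to_tensor('./input.mp4', num_frames = 16, crop_size = 128)

# 写回 mp4:cv2 按 BGR、0-255 处理,因此先翻转通道顺序并放大回 0-255
tensor_to_video(frames.flip(dims = (0,)) * 255, './output.mp4', fps = 25)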

# 定义一个函数,对图像进行中心裁剪
def crop_center(
    img: Tensor,  # 输入图像张量
    cropx: int,   # 最终图像在 x 方向上的长度
    cropy: int    # 最终图像在 y 方向上的长度
) -> Tensor:      # 返回裁剪后的图像张量
    y, x, c = img.shape
    startx = x // 2 - cropx // 2
    starty = y // 2 - cropy // 2
    return img[starty:(starty + cropy), startx:(startx + cropx), :]

# 视频数据集类
class VideoDataset(Dataset):
    def __init__(
        self,
        folder,              # 视频文件夹路径
        image_size,          # 图像尺寸
        channels=3,          # 通道数,默认为 3
        num_frames=17,       # 帧数,默认为 17
        force_num_frames=True,  # 是否强制指定帧数,默认为 True
        exts=['gif', 'mp4']  # 视频文件扩展名列表,默认为 ['gif', 'mp4']
    ):
        super().__init__()
        folder = Path(folder)
        assert folder.is_dir(), f'{str(folder)} must be a folder containing videos'
        self.folder = folder

        self.image_size = image_size
        self.channels = channels
        self.paths = [p for ext in exts for p in folder.glob(f'**/*.{ext}')]

        print(f'{len(self.paths)} training samples found at {folder}')

        # 定义图像转换操作
        self.transform = T.Compose([
            T.Resize(image_size, antialias=True),
            T.CenterCrop(image_size)
        ])

        # 定义将视频路径转换为张量的函数
        self.gif_to_tensor = partial(gif_to_tensor, channels=self.channels, transform=self.transform)
        self.mp4_to_tensor = partial(video_to_tensor, crop_size=self.image_size)

        # 定义将帧数转换为指定数量的函数
        self.cast_num_frames_fn = partial(cast_num_frames, frames=num_frames) if force_num_frames else identity

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        ext = path.suffix
        path_str = str(path)

        if ext == '.gif':
            tensor = self.gif_to_tensor(path_str)
        elif ext == '.mp4':
            tensor = self.mp4_to_tensor(path_str)
            frames = tensor.unbind(dim=1)
            tensor = torch.stack([*map(self.transform, frames)], dim=1)
        else:
            raise ValueError(f'unknown extension {ext}')

        return self.cast_num_frames_fn(tensor)
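
VideoDataset 的最小用法示意(目录 ./videos 与参数均为假设):

# 用法示意
dataset = VideoDataset('./videos', image_size = 128, num_frames = 17)
video = dataset[0]    # 形状为 (3, 17, 128, 128)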

# 重写数据加载器以能够整理张量和字符串
def collate_tensors_and_strings(data):
    if is_bearable(data, List[Tensor]):
        return (torch.stack(data),)

    data = zip(*data)
    output = []
    # 遍历数据列表中的每个元素
    for datum in data:
        # 检查数据是否为可接受的类型(元组中包含张量)
        if is_bearable(datum, Tuple[Tensor, ...]):
            # 如果是,则将张量堆叠成一个张量
            datum = torch.stack(datum)
        # 检查数据是否为可接受的类型(元组中包含字符串)
        elif is_bearable(datum, Tuple[str, ...]):
            # 如果是,则将元组转换为列表
            datum = list(datum)
        else:
            # 如果数据类型不符合要求,则引发值错误异常
            raise ValueError('detected invalid type being passed from dataset')

        # 将处理后的数据添加到输出列表中
        output.append(datum)

    # 将输出列表转换为元组并返回
    return tuple(output)
# 定义一个函数DataLoader,接受任意数量的位置参数和关键字参数
def DataLoader(*args, **kwargs):
    # 返回PytorchDataLoader对象,使用指定的参数和自定义的collate函数
    return PytorchDataLoader(*args, collate_fn = collate_tensors_and_strings, **kwargs)
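
配合上面的 collate 函数,这个 DataLoader 的用法与普通 PyTorch DataLoader 一致,只是每个 batch 被包成元组返回(示意,dataset 沿用上面 VideoDataset 的假设示例):

# 用法示意
loader = DataLoader(dataset, batch_size = 2, shuffle = True)
(videos,) = next(iter(loader))    # videos 形状为 (2, 3, 17, 128, 128)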

.\lucidrains\magvit2-pytorch\magvit2_pytorch\magvit2_pytorch.py

# 导入必要的库
import copy
from pathlib import Path
from math import log2, ceil, sqrt
from functools import wraps, partial

import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast
from torch import nn, einsum, Tensor
from torch.nn import Module, ModuleList
from torch.autograd import grad as torch_grad

import torchvision
from torchvision.models import VGG16_Weights

from collections import namedtuple

# 导入自定义模块
from vector_quantize_pytorch import LFQ, FSQ
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange
from beartype import beartype
from beartype.typing import Union, Tuple, Optional, List
from magvit2_pytorch.attend import Attend
from magvit2_pytorch.version import __version__
from gateloop_transformer import SimpleGateLoopLayer
from taylor_series_linear_attention import TaylorSeriesLinearAttn
from kornia.filters import filter3d

import pickle

# helper

# 检查变量是否存在
def exists(v):
    return v is not None

# 返回默认值
def default(v, d):
    return v if exists(v) else d

# 安全获取列表中的元素
def safe_get_index(it, ind, default = None):
    if ind < len(it):
        return it[ind]
    return default

# 将输入转换为元组
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# 返回输入本身
def identity(t, *args, **kwargs):
    return t

# 检查一个数是否可以被另一个数整除
def divisible_by(num, den):
    return (num % den) == 0

# 将输入打包成指定模式
def pack_one(t, pattern):
    return pack([t], pattern)

# 将输入解包成指定模式
def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

# 在张量的末尾添加指定维度
def append_dims(t, ndims: int):
    return t.reshape(*t.shape, *((1,) * ndims))

# 检查一个数是否为奇数
def is_odd(n):
    return not divisible_by(n, 2)

# 删除对象的属性
def maybe_del_attr_(o, attr):
    if hasattr(o, attr):
        delattr(o, attr)

# 将输入转换为元组
def cast_tuple(t, length = 1):
    return t if isinstance(t, tuple) else ((t,) * length)

# tensor helpers

# 对张量进行 L2 归一化
def l2norm(t):
    return F.normalize(t, dim = -1, p = 2)

# 在指定维度上对张量进行填充
def pad_at_dim(t, pad, dim = -1, value = 0.):
    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = ((0, 0) * dims_from_right)
    return F.pad(t, (*zeros, *pad), value = value)

# 从视频中选择指定帧
def pick_video_frame(video, frame_indices):
    batch, device = video.shape[0], video.device
    video = rearrange(video, 'b c f ... -> b f c ...')
    batch_indices = torch.arange(batch, device = device)
    batch_indices = rearrange(batch_indices, 'b -> b 1')
    images = video[batch_indices, frame_indices]
    images = rearrange(images, 'b 1 c ... -> b c ...')
    return images

# gan related

# 计算梯度惩罚
def gradient_penalty(images, output):
    batch_size = images.shape[0]

    gradients = torch_grad(
        outputs = output,
        inputs = images,
        grad_outputs = torch.ones(output.size(), device = images.device),
        create_graph = True,
        retain_graph = True,
        only_inputs = True
    )[0]

    gradients = rearrange(gradients, 'b ... -> b (...)')
    return ((gradients.norm(2, dim = 1) - 1) ** 2).mean()

# Leaky ReLU 激活函数
def leaky_relu(p = 0.1):
    return nn.LeakyReLU(p)

# Hinge 损失函数(判别器)
def hinge_discr_loss(fake, real):
    return (F.relu(1 + fake) + F.relu(1 - real)).mean()

# Hinge 损失函数(生成器)
def hinge_gen_loss(fake):
    return -fake.mean()
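
这些 GAN 相关函数在判别器一步训练中的组合方式示意(仅为草图,其中 discr 是随手写的占位判别器,梯度惩罚权重 10 也只是示例值):

# 判别器训练一步的示意(非原始训练代码)
import torch
from torch import nn

discr = nn.Sequential(nn.Flatten(), nn.Linear(3 * 64 * 64, 1))    # 占位判别器,仅为演示

real = torch.randn(4, 3, 64, 64).requires_grad_()
fake = torch.randn(4, 3, 64, 64)

real_logits = discr(real)
fake_logits = discr(fake.detach())

discr_loss = hinge_discr_loss(fake_logits, real_logits)
discr_loss = discr_loss + 10. * gradient_penalty(real, real_logits)
discr_loss.backward()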

# 计算损失对层的梯度
@autocast(enabled = False)
@beartype
def grad_layer_wrt_loss(
    loss: Tensor,
    layer: nn.Parameter
):
    return torch_grad(
        outputs = loss,
        inputs = layer,
        grad_outputs = torch.ones_like(loss),
        retain_graph = True
    )[0].detach()

# helper decorators

# 移除 VGG 属性
def remove_vgg(fn):
    @wraps(fn)
    def inner(self, *args, **kwargs):
        has_vgg = hasattr(self, 'vgg')
        if has_vgg:
            vgg = self.vgg
            delattr(self, 'vgg')

        out = fn(self, *args, **kwargs)

        if has_vgg:
            self.vgg = vgg

        return out
    return inner

# helper classes

# 顺序模块
def Sequential(*modules):
    modules = [*filter(exists, modules)]

    if len(modules) == 0:
        return nn.Identity()

    return nn.Sequential(*modules)

# 残差模块
class Residual(Module):
    @beartype
    def __init__(self, fn: Module):
        super().__init__()
        self.fn = fn
    # 定义一个前向传播函数,接受输入 x 和其他关键字参数
    def forward(self, x, **kwargs):
        # 调用函数 fn 对输入 x 进行处理,并将结果与输入 x 相加后返回
        return self.fn(x, **kwargs) + x
# 一系列张量操作,将张量转换为 (batch, time, feature dimension) 格式,然后再转回来

class ToTimeSequence(Module):
    @beartype
    def __init__(self, fn: Module):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        # 重新排列张量的维度,将其转换为 (batch, ..., feature, channel) 格式
        x = rearrange(x, 'b c f ... -> b ... f c')
        # 打包张量,将其转换为 (batch, ..., feature, channel) 格式
        x, ps = pack_one(x, '* n c')

        # 使用给定的函数对张量进行操作
        o = self.fn(x, **kwargs)

        # 解包张量,将其转换回原始格式
        o = unpack_one(o, ps, '* n c')
        # 重新排列张量的维度,将其转换回原始格式
        return rearrange(o, 'b ... f c -> b c f ...')


class SqueezeExcite(Module):
    # 全局上下文网络 - 基于注意力机制的 Squeeze-Excite 变种 (https://arxiv.org/abs/2012.13375)

    def __init__(
        self,
        dim,
        *,
        dim_out = None,
        dim_hidden_min = 16,
        init_bias = -10
    ):
        super().__init__()
        dim_out = default(dim_out, dim)

        # 创建卷积层,用于计算注意力权重
        self.to_k = nn.Conv2d(dim, 1, 1)
        dim_hidden = max(dim_hidden_min, dim_out // 2)

        # 创建包含卷积层和激活函数的网络结构
        self.net = nn.Sequential(
            nn.Conv2d(dim, dim_hidden, 1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(dim_hidden, dim_out, 1),
            nn.Sigmoid()
        )

        # 初始化网络参数
        nn.init.zeros_(self.net[-2].weight)
        nn.init.constant_(self.net[-2].bias, init_bias)

    def forward(self, x):
        orig_input, batch = x, x.shape[0]
        is_video = x.ndim == 5

        if is_video:
            # 重新排列视频张量的维度
            x = rearrange(x, 'b c f h w -> (b f) c h w')

        # 计算上下文信息
        context = self.to_k(x)

        # 计算注意力权重
        context = rearrange(context, 'b c h w -> b c (h w)').softmax(dim = -1)
        spatial_flattened_input = rearrange(x, 'b c h w -> b c (h w)')

        # 使用注意力权重对输入进行加权求和
        out = einsum('b i n, b c n -> b c i', context, spatial_flattened_input)
        out = rearrange(out, '... -> ... 1')
        gates = self.net(out)

        if is_video:
            # 将结果转换回视频张量的格式
            gates = rearrange(gates, '(b f) c h w -> b c f h w', b = batch)

        return gates * orig_input

# token shifting

class TokenShift(Module):
    @beartype
    def __init__(self, fn: Module):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        # 将输入张量分成两部分
        x, x_shift = x.chunk(2, dim = 1)
        # 在时间维度上进行填充,实现时间维度的位移
        x_shift = pad_at_dim(x_shift, (1, -1), dim = 2)
        # 将两部分张量连接起来
        x = torch.cat((x, x_shift), dim = 1)
        return self.fn(x, **kwargs)

# rmsnorm

class RMSNorm(Module):
    def __init__(
        self,
        dim,
        channel_first = False,
        images = False,
        bias = False
    ):
        super().__init__()
        broadcastable_dims = (1, 1, 1) if not images else (1, 1)
        shape = (dim, *broadcastable_dims) if channel_first else (dim,)

        self.channel_first = channel_first
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.

    def forward(self, x):
        # 对输入张量进行 RMS 归一化
        return F.normalize(x, dim = (1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
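
RMSNorm 的效果可以用一个小检查说明(示意):gamma 初始为 1、bias 为 0 时,归一化后最后一维向量的范数等于 sqrt(dim)。

# 行为示意
import torch

norm = RMSNorm(dim = 64)
x = torch.randn(2, 10, 64)
out = norm(x)

assert torch.allclose(out.norm(dim = -1), torch.full((2, 10), 64 ** 0.5), atol = 1e-4)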

class AdaptiveRMSNorm(Module):
    def __init__(
        self,
        dim,
        *,
        dim_cond,
        channel_first = False,
        images = False,
        bias = False
    ):
        super().__init__()
        broadcastable_dims = (1, 1, 1) if not images else (1, 1)
        shape = (dim, *broadcastable_dims) if channel_first else (dim,)

        self.dim_cond = dim_cond
        self.channel_first = channel_first
        self.scale = dim ** 0.5

        # 创建线性层,用于生成 gamma 和 bias
        self.to_gamma = nn.Linear(dim_cond, dim)
        self.to_bias = nn.Linear(dim_cond, dim) if bias else None

        # 初始化线性层参数
        nn.init.zeros_(self.to_gamma.weight)
        nn.init.ones_(self.to_gamma.bias)

        if bias:
            nn.init.zeros_(self.to_bias.weight)
            nn.init.zeros_(self.to_bias.bias)

    @beartype
    # 定义一个前向传播函数,接受输入张量 x 和条件张量 cond
    def forward(self, x: Tensor, *, cond: Tensor):
        # 获取批量大小
        batch = x.shape[0]
        # 断言条件张量的形状为 (batch, self.dim_cond)
        assert cond.shape == (batch, self.dim_cond)

        # 根据条件张量生成 gamma
        gamma = self.to_gamma(cond)

        # 初始化偏置为 0
        bias = 0.
        # 如果存在偏置生成函数
        if exists(self.to_bias):
            # 根据条件张量生成偏置
            bias = self.to_bias(cond)

        # 如果通道在前
        if self.channel_first:
            # 在 gamma 的维度前面添加维度,使其与输入张量 x 的维度相同
            gamma = append_dims(gamma, x.ndim - 2)

            # 如果存在偏置生成函数
            if exists(self.to_bias):
                # 在偏置的维度前面添加维度,使其与输入张量 x 的维度相同
                bias = append_dims(bias, x.ndim - 2)

        # 对输入张量 x 进行归一化,根据通道顺序选择归一化的维度,然后乘以缩放因子 scale 和 gamma,最后加上偏置 bias
        return F.normalize(x, dim = (1 if self.channel_first else -1)) * self.scale * gamma + bias
# 定义一个名为 Attention 的类,继承自 Module 类
class Attention(Module):
    # 初始化函数,接受多个参数
    @beartype
    def __init__(
        self,
        *,
        dim,
        dim_cond: Optional[int] = None,
        causal = False,
        dim_head = 32,
        heads = 8,
        flash = False,
        dropout = 0.,
        num_memory_kv = 4
    ):
        # 调用父类的初始化函数
        super().__init__()
        # 计算内部维度
        dim_inner = dim_head * heads

        # 检查是否需要条件
        self.need_cond = exists(dim_cond)

        # 根据是否需要条件选择不同的归一化方法
        if self.need_cond:
            self.norm = AdaptiveRMSNorm(dim, dim_cond = dim_cond)
        else:
            self.norm = RMSNorm(dim)

        # 构建 QKV 网络
        self.to_qkv = nn.Sequential(
            nn.Linear(dim, dim_inner * 3, bias = False),
            Rearrange('b n (qkv h d) -> qkv b h n d', qkv = 3, h = heads)
        )

        # 断言内存键值对数量大于 0
        assert num_memory_kv > 0
        # 初始化内存键值对
        self.mem_kv = nn.Parameter(torch.randn(2, heads, num_memory_kv, dim_head))

        # 构建 Attend 层
        self.attend = Attend(
            causal = causal,
            dropout = dropout,
            flash = flash
        )

        # 构建输出层
        self.to_out = nn.Sequential(
            Rearrange('b h n d -> b n (h d)'),
            nn.Linear(dim_inner, dim, bias = False)
        )

    # 前向传播函数
    @beartype
    def forward(
        self,
        x,
        mask: Optional[Tensor] = None,
        cond: Optional[Tensor] = None
    ):
        # 根据是否需要条件选择不同的参数
        maybe_cond_kwargs = dict(cond = cond) if self.need_cond else dict()

        # 对输入进行归一化
        x = self.norm(x, **maybe_cond_kwargs)

        # 获取 QKV
        q, k, v = self.to_qkv(x)

        # 重复内存键值对
        mk, mv = map(lambda t: repeat(t, 'h n d -> b h n d', b = q.shape[0]), self.mem_kv)
        k = torch.cat((mk, k), dim = -2)
        v = torch.cat((mv, v), dim = -2)

        # 进行注意力计算
        out = self.attend(q, k, v, mask = mask)
        return self.to_out(out)
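
Attention 模块的调用示意(参数与形状仅为示例,依赖上文的 Attend、RMSNorm 等定义):

# 用法示意
import torch

attn = Attention(dim = 256, dim_head = 32, heads = 8, flash = False)
x = torch.randn(2, 64, 256)    # (batch, seq, dim)
out = attn(x)                  # 形状仍为 (2, 64, 256)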

# 定义一个名为 LinearAttention 的类,继承自 Module 类
class LinearAttention(Module):
    """
    using the specific linear attention proposed in https://arxiv.org/abs/2106.09681
    """

    # 初始化函数,接受多个参数
    @beartype
    def __init__(
        self,
        *,
        dim,
        dim_cond: Optional[int] = None,
        dim_head = 8,
        heads = 8,
        dropout = 0.
    ):
        # 调用父类的初始化函数
        super().__init__()
        # 计算内部维度
        dim_inner = dim_head * heads

        # 检查是否需要条件
        self.need_cond = exists(dim_cond)

        # 根据是否需要条件选择不同的归一化方法
        if self.need_cond:
            self.norm = AdaptiveRMSNorm(dim, dim_cond = dim_cond)
        else:
            self.norm = RMSNorm(dim)

        # 构建 TaylorSeriesLinearAttn 层
        self.attn = TaylorSeriesLinearAttn(
            dim = dim,
            dim_head = dim_head,
            heads = heads
        )

    # 前向传播函数
    def forward(
        self,
        x,
        cond: Optional[Tensor] = None
    ):
        # 根据是否需要条件选择不同的参数
        maybe_cond_kwargs = dict(cond = cond) if self.need_cond else dict()

        # 对输入进行归一化
        x = self.norm(x, **maybe_cond_kwargs)

        return self.attn(x)

# 定义一个名为 LinearSpaceAttention 的类,继承自 LinearAttention 类
class LinearSpaceAttention(LinearAttention):
    # 重写前向传播函数
    def forward(self, x, *args, **kwargs):
        # 重新排列输入数据
        x = rearrange(x, 'b c ... h w -> b ... h w c')
        x, batch_ps = pack_one(x, '* h w c')
        x, seq_ps = pack_one(x, 'b * c')

        # 调用父类的前向传播函数
        x = super().forward(x, *args, **kwargs)

        x = unpack_one(x, seq_ps, 'b * c')
        x = unpack_one(x, batch_ps, '* h w c')
        return rearrange(x, 'b ... h w c -> b c ... h w')

# 定义一个名为 SpaceAttention 的类,继承自 Attention 类
class SpaceAttention(Attention):
    # 重写前向传播函数
    def forward(self, x, *args, **kwargs):
        # 重新排列输入数据
        x = rearrange(x, 'b c t h w -> b t h w c')
        x, batch_ps = pack_one(x, '* h w c')
        x, seq_ps = pack_one(x, 'b * c')

        # 调用父类的前向传播函数
        x = super().forward(x, *args, **kwargs)

        x = unpack_one(x, seq_ps, 'b * c')
        x = unpack_one(x, batch_ps, '* h w c')
        return rearrange(x, 'b t h w c -> b c t h w')

# 定义一个名为 TimeAttention 的类,继承自 Attention 类
class TimeAttention(Attention):
    # 重写前向传播函数
    def forward(self, x, *args, **kwargs):
        # 重新排列输入数据
        x = rearrange(x, 'b c t h w -> b h w t c')
        x, batch_ps = pack_one(x, '* t c')

        # 调用父类的前向传播函数
        x = super().forward(x, *args, **kwargs)

        x = unpack_one(x, batch_ps, '* t c')
        return rearrange(x, 'b h w t c -> b c t h w')
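
三种注意力的区别只在于对 5 维视频张量 (b, c, t, h, w) 的重排方式:SpaceAttention 在空间位置 h*w 上做注意力,TimeAttention 在时间维度 t 上做注意力,输出形状都保持不变。示意(参数仅为示例):

# 用法示意
import torch

x = torch.randn(1, 64, 4, 8, 8)    # (batch, channels, time, height, width)

space_attn = SpaceAttention(dim = 64, dim_head = 32, heads = 4)
time_attn = TimeAttention(dim = 64, dim_head = 32, heads = 4)

assert space_attn(x).shape == x.shape
assert time_attn(x).shape == x.shape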

# 定义一个名为 GEGLU 的类,继承自 Module 类
class GEGLU(Module):
    # 前向传播函数
    def forward(self, x):
        # 将输入数据分成两部分
        x, gate = x.chunk(2, dim = 1)
        return F.gelu(gate) * x

# 定义一个名为 FeedForward 的类,继承自 Module 类
class FeedForward(Module):
    @beartype
    # 初始化函数,设置神经网络的参数
    def __init__(
        self,
        dim,  # 输入数据的维度
        *,
        dim_cond: Optional[int] = None,  # 条件维度,默认为None
        mult = 4,  # 倍数,默认为4
        images = False  # 是否为图像数据,默认为False
    ):
        super().__init__()  # 调用父类的初始化函数
        # 根据是否为图像数据选择不同的卷积层类
        conv_klass = nn.Conv2d if images else nn.Conv3d

        # 根据条件维度是否存在选择不同的归一化层类
        rmsnorm_klass = RMSNorm if not exists(dim_cond) else partial(AdaptiveRMSNorm, dim_cond = dim_cond)

        # 创建可能的自适应归一化层类
        maybe_adaptive_norm_klass = partial(rmsnorm_klass, channel_first = True, images = images)

        # 计算内部维度
        dim_inner = int(dim * mult * 2 / 3)

        # 初始化归一化层
        self.norm = maybe_adaptive_norm_klass(dim)

        # 初始化神经网络结构
        self.net = Sequential(
            conv_klass(dim, dim_inner * 2, 1),  # 卷积层
            GEGLU(),  # 激活函数
            conv_klass(dim_inner, dim, 1)  # 卷积层
        )

    # 前向传播函数
    @beartype
    def forward(
        self,
        x: Tensor,  # 输入数据张量
        *,
        cond: Optional[Tensor] = None  # 条件张量,默认为None
    ):
        # 根据条件张量是否存在选择不同的参数
        maybe_cond_kwargs = dict(cond = cond) if exists(cond) else dict()

        # 对输入数据进行归一化处理
        x = self.norm(x, **maybe_cond_kwargs)
        return self.net(x)  # 返回神经网络处理后的结果
# 定义一个带有反锯齿下采样的鉴别器(模糊池 Zhang 等人)

class Blur(Module):
    def __init__(self):
        super().__init__()
        # 定义一个张量 f
        f = torch.Tensor([1, 2, 1])
        # 将张量 f 注册为缓冲区
        self.register_buffer('f', f)

    def forward(
        self,
        x,
        space_only = False,
        time_only = False
    ):
        # 断言空间和时间只能选择一个
        assert not (space_only and time_only)

        # 获取张量 f
        f = self.f

        if space_only:
            # 对 f 进行乘法操作
            f = einsum('i, j -> i j', f, f)
            # 重新排列张量 f
            f = rearrange(f, '... -> 1 1 ...')
        elif time_only:
            # 重新排列张量 f
            f = rearrange(f, 'f -> 1 f 1 1')
        else:
            # 对 f 进行乘法操作
            f = einsum('i, j, k -> i j k', f, f, f)
            # 重新排列张量 f
            f = rearrange(f, '... -> 1 ...')

        # 判断输入 x 是否为图像
        is_images = x.ndim == 4

        if is_images:
            # 重新排列输入 x
            x = rearrange(x, 'b c h w -> b c 1 h w')

        # 对输入 x 进行 3D 滤波
        out = filter3d(x, f, normalized = True)

        if is_images:
            # 重新排列输出 out
            out = rearrange(out, 'b c 1 h w -> b c h w')

        return out

class DiscriminatorBlock(Module):
    def __init__(
        self,
        input_channels,
        filters,
        downsample = True,
        antialiased_downsample = True
    ):
        super().__init__()
        # 定义卷积层 conv_res
        self.conv_res = nn.Conv2d(input_channels, filters, 1, stride = (2 if downsample else 1))

        # 定义神经网络结构 net
        self.net = nn.Sequential(
            nn.Conv2d(input_channels, filters, 3, padding = 1),
            leaky_relu(),
            nn.Conv2d(filters, filters, 3, padding = 1),
            leaky_relu()
        )

        # 如果需要反锯齿下采样,则定义模糊层 maybe_blur
        self.maybe_blur = Blur() if antialiased_downsample else None

        # 如果需要下采样,则定义下采样层 downsample
        self.downsample = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = 2, p2 = 2),
            nn.Conv2d(filters * 4, filters, 1)
        ) if downsample else None

    def forward(self, x):
        # 对输入 x 进行卷积操作,得到 res
        res = self.conv_res(x)

        # 对输入 x 进行神经网络结构操作
        x = self.net(x)

        if exists(self.downsample):
            if exists(self.maybe_blur):
                # 如果存在模糊层,则对 x 进行模糊操作
                x = self.maybe_blur(x, space_only = True)

            # 对 x 进行下采样操作
            x = self.downsample(x)

        # 对 x 进行加权求和并缩放操作
        x = (x + res) * (2 ** -0.5)
        return x

class Discriminator(Module):
    @beartype
    def __init__(
        self,
        *,
        dim,
        image_size,
        channels = 3,
        max_dim = 512,
        attn_heads = 8,
        attn_dim_head = 32,
        linear_attn_dim_head = 8,
        linear_attn_heads = 16,
        ff_mult = 4,
        antialiased_downsample = False
    ):
        # 调用父类的构造函数
        super().__init__()
        # 将图像大小转换为元组
        image_size = pair(image_size)
        # 计算图像分辨率的最小值
        min_image_resolution = min(image_size)

        # 计算层数
        num_layers = int(log2(min_image_resolution) - 2)

        blocks = []

        # 计算每一层的维度
        layer_dims = [channels] + [(dim * 4) * (2 ** i) for i in range(num_layers + 1)]
        # 将每一层的维度限制在最大维度内
        layer_dims = [min(layer_dim, max_dim) for layer_dim in layer_dims]
        # 将每一层的输入输出维度组成元组
        layer_dims_in_out = tuple(zip(layer_dims[:-1], layer_dims[1:]))

        blocks = []
        attn_blocks = []

        image_resolution = min_image_resolution

        for ind, (in_chan, out_chan) in enumerate(layer_dims_in_out):
            num_layer = ind + 1
            is_not_last = ind != (len(layer_dims_in_out) - 1)

            # 创建判别器块
            block = DiscriminatorBlock(
                in_chan,
                out_chan,
                downsample = is_not_last,
                antialiased_downsample = antialiased_downsample
            )

            # 创建注意力块
            attn_block = Sequential(
                Residual(LinearSpaceAttention(
                    dim = out_chan,
                    heads = linear_attn_heads,
                    dim_head = linear_attn_dim_head
                )),
                Residual(FeedForward(
                    dim = out_chan,
                    mult = ff_mult,
                    images = True
                ))
            )

            blocks.append(ModuleList([
                block,
                attn_block
            ]))

            image_resolution //= 2

        self.blocks = ModuleList(blocks)

        dim_last = layer_dims[-1]

        downsample_factor = 2 ** num_layers
        last_fmap_size = tuple(map(lambda n: n // downsample_factor, image_size))

        latent_dim = last_fmap_size[0] * last_fmap_size[1] * dim_last

        # 定义输出层
        self.to_logits = Sequential(
            nn.Conv2d(dim_last, dim_last, 3, padding = 1),
            leaky_relu(),
            Rearrange('b ... -> b (...)'),
            nn.Linear(latent_dim, 1),
            Rearrange('b 1 -> b')
        )

    def forward(self, x):

        # 遍历每个块和注意力块
        for block, attn_block in self.blocks:
            x = block(x)
            x = attn_block(x)

        return self.to_logits(x)
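
判别器整体的用法示意(参数仅为示例,假设上面 import 的第三方依赖均已安装):

# 用法示意
import torch

discr = Discriminator(dim = 32, image_size = 64, channels = 3)
images = torch.randn(2, 3, 64, 64)
logits = discr(images)    # 形状为 (2,)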
# 定义一个继承自 Module 的类 Conv3DMod,用于实现可调制的卷积,用于在潜变量上进行条件化
class Conv3DMod(Module):
    # 初始化函数
    @beartype
    def __init__(
        self,
        dim,
        *,
        spatial_kernel,
        time_kernel,
        causal = True,
        dim_out = None,
        demod = True,
        eps = 1e-8,
        pad_mode = 'zeros'
    ):
        super().__init__()
        dim_out = default(dim_out, dim)

        self.eps = eps

        # 断言空间和时间卷积核为奇数
        assert is_odd(spatial_kernel) and is_odd(time_kernel)

        self.spatial_kernel = spatial_kernel
        self.time_kernel = time_kernel

        # 根据是否因果,设置时间填充
        time_padding = (time_kernel - 1, 0) if causal else ((time_kernel // 2,) * 2)

        self.pad_mode = pad_mode
        self.padding = (*((spatial_kernel // 2,) * 4), *time_padding)
        self.weights = nn.Parameter(torch.randn((dim_out, dim, time_kernel, spatial_kernel, spatial_kernel)))

        self.demod = demod

        # 初始化权重
        nn.init.kaiming_normal_(self.weights, a = 0, mode = 'fan_in', nonlinearity = 'selu')

    # 前向传播函数
    @beartype
    def forward(
        self,
        fmap,
        cond: Tensor
    ):
        """
        notation

        b - batch
        n - convs
        o - output
        i - input
        k - kernel
        """

        b = fmap.shape[0]

        # 准备用于调制的权重
        weights = self.weights

        # 进行调制和解调制,类似 stylegan2 中的操作
        cond = rearrange(cond, 'b i -> b 1 i 1 1 1')

        weights = weights * (cond + 1)

        if self.demod:
            inv_norm = reduce(weights ** 2, 'b o i k0 k1 k2 -> b o 1 1 1 1', 'sum').clamp(min = self.eps).rsqrt()
            weights = weights * inv_norm

        fmap = rearrange(fmap, 'b c t h w -> 1 (b c) t h w')

        weights = rearrange(weights, 'b o ... -> (b o) ...')

        fmap = F.pad(fmap, self.padding, mode = self.pad_mode)
        fmap = F.conv3d(fmap, weights, groups = b)

        return rearrange(fmap, '1 (b o) ... -> b o ...', b = b)

# 定义一个继承自 Module 的类 SpatialDownsample2x,用于进行空间下采样
class SpatialDownsample2x(Module):
    def __init__(
        self,
        dim,
        dim_out = None,
        kernel_size = 3,
        antialias = False
    ):
        super().__init__()
        dim_out = default(dim_out, dim)
        self.maybe_blur = Blur() if antialias else identity
        self.conv = nn.Conv2d(dim, dim_out, kernel_size, stride = 2, padding = kernel_size // 2)

    # 前向传播函数
    def forward(self, x):
        x = self.maybe_blur(x, space_only = True)

        x = rearrange(x, 'b c t h w -> b t c h w')
        x, ps = pack_one(x, '* c h w')

        out = self.conv(x)

        out = unpack_one(out, ps, '* c h w')
        out = rearrange(out, 'b t c h w -> b c t h w')
        return out

# 定义一个继承自 Module 的类 TimeDownsample2x,用于进行时间下采样
class TimeDownsample2x(Module):
    def __init__(
        self,
        dim,
        dim_out = None,
        kernel_size = 3,
        antialias = False
    ):
        super().__init__()
        dim_out = default(dim_out, dim)
        self.maybe_blur = Blur() if antialias else identity
        self.time_causal_padding = (kernel_size - 1, 0)
        self.conv = nn.Conv1d(dim, dim_out, kernel_size, stride = 2)

    # 前向传播函数
    def forward(self, x):
        x = self.maybe_blur(x, time_only = True)

        x = rearrange(x, 'b c t h w -> b h w c t')
        x, ps = pack_one(x, '* c t')

        x = F.pad(x, self.time_causal_padding)
        out = self.conv(x)

        out = unpack_one(out, ps, '* c t')
        out = rearrange(out, 'b h w c t -> b c t h w')
        return out

# 定义一个继承自 Module 的类 SpatialUpsample2x,用于进行空间上采样
class SpatialUpsample2x(Module):
    def __init__(
        self,
        dim,
        dim_out = None
    ):
        super().__init__()
        dim_out = default(dim_out, dim)
        conv = nn.Conv2d(dim, dim_out * 4, 1)

        self.net = nn.Sequential(
            conv,
            nn.SiLU(),
            Rearrange('b (c p1 p2) h w -> b c (h p1) (w p2)', p1 = 2, p2 = 2)
        )

        self.init_conv_(conv)
    # 初始化卷积层的权重和偏置
    def init_conv_(self, conv):
        # 获取卷积层的输出通道数、输入通道数、高度和宽度
        o, i, h, w = conv.weight.shape
        # 创建一个与卷积层权重相同形状的张量
        conv_weight = torch.empty(o // 4, i, h, w)
        # 使用 Kaiming 初始化方法初始化权重
        nn.init.kaiming_uniform_(conv_weight)
        # 将权重张量重复4次,扩展为4倍的输出通道数
        conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')

        # 将初始化好的权重复制给卷积层的权重
        conv.weight.data.copy_(conv_weight)
        # 初始化卷积层的偏置为零
        nn.init.zeros_(conv.bias.data)

    # 前向传播函数
    def forward(self, x):
        # 重新排列输入张量的维度,将通道维度放到第二个位置
        x = rearrange(x, 'b c t h w -> b t c h w')
        # 将输入张量打包成一个元组,每个元素为一个通道的数据
        x, ps = pack_one(x, '* c h w')

        # 将打包后的输入张量传入网络进行前向传播
        out = self.net(x)

        # 将网络输出解包,恢复为原始形状
        out = unpack_one(out, ps, '* c h w')
        # 重新排列输出张量的维度,将通道维度放回最后一个位置
        out = rearrange(out, 'b t c h w -> b c t h w')
        # 返回前向传播结果
        return out
# 定义一个类 TimeUpsample2x,继承自 Module 类
class TimeUpsample2x(Module):
    # 初始化函数
    def __init__(
        self,
        dim,
        dim_out = None
    ):
        # 调用父类的初始化函数
        super().__init__()
        # 如果未指定输出维度,则默认与输入维度相同
        dim_out = default(dim_out, dim)
        # 创建一个 1 维卷积层,输入维度为 dim,输出维度为 dim_out * 2,卷积核大小为 1
        conv = nn.Conv1d(dim, dim_out * 2, 1)

        # 使用 nn.Sequential 定义网络结构
        self.net = nn.Sequential(
            conv,
            nn.SiLU(),  # 使用 SiLU 激活函数
            Rearrange('b (c p) t -> b c (t p)', p = 2)  # 重新排列张量维度
        )

        # 初始化卷积层的权重
        self.init_conv_(conv)

    # 初始化卷积层的权重
    def init_conv_(self, conv):
        o, i, t = conv.weight.shape
        # 创建一个与卷积层权重相同形状的张量
        conv_weight = torch.empty(o // 2, i, t)
        # 使用 kaiming_uniform_ 方法初始化权重
        nn.init.kaiming_uniform_(conv_weight)
        # 将权重张量重复一次
        conv_weight = repeat(conv_weight, 'o ... -> (o 2) ...')

        # 将初始化后的权重赋值给卷积层
        conv.weight.data.copy_(conv_weight)
        # 将偏置项初始化为零
        nn.init.zeros_(conv.bias.data)

    # 前向传播函数
    def forward(self, x):
        # 重新排列输入张量的维度
        x = rearrange(x, 'b c t h w -> b h w c t')
        # 打包输入张量
        x, ps = pack_one(x, '* c t')

        # 网络前向传播
        out = self.net(x)

        # 解包输出张量
        out = unpack_one(out, ps, '* c t')
        # 重新排列输出张量的维度
        out = rearrange(out, 'b h w c t -> b c t h w')
        return out
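
四个上、下采样模块只改变对应维度(空间或时间)的尺寸,其余维度保持不变。示意(参数仅为示例):

# 形状变化示意
import torch

x = torch.randn(1, 16, 8, 32, 32)    # (batch, channels, time, height, width)

assert SpatialDownsample2x(16)(x).shape == (1, 16, 8, 16, 16)
assert TimeDownsample2x(16)(x).shape == (1, 16, 4, 32, 32)
assert SpatialUpsample2x(16)(x).shape == (1, 16, 8, 64, 64)
assert TimeUpsample2x(16)(x).shape == (1, 16, 16, 32, 32)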

# 定义一个函数 SameConv2d,用于创建相同维度的二维卷积层
def SameConv2d(dim_in, dim_out, kernel_size):
    kernel_size = cast_tuple(kernel_size, 2)
    padding = [k // 2 for k in kernel_size]
    return nn.Conv2d(dim_in, dim_out, kernel_size = kernel_size, padding = padding)

# 定义一个类 CausalConv3d,继承自 Module 类
class CausalConv3d(Module):
    # 初始化函数
    @beartype
    def __init__(
        self,
        chan_in,
        chan_out,
        kernel_size: Union[int, Tuple[int, int, int]],
        pad_mode = 'constant',
        **kwargs
    ):
        # 调用父类的初始化函数
        super().__init__()
        kernel_size = cast_tuple(kernel_size, 3)

        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size

        # 确保高度和宽度的卷积核大小为奇数
        assert is_odd(height_kernel_size) and is_odd(width_kernel_size)

        dilation = kwargs.pop('dilation', 1)
        stride = kwargs.pop('stride', 1)

        # 设置时间维度的填充大小
        time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
        height_pad = height_kernel_size // 2
        width_pad = width_kernel_size // 2

        self.pad_mode = pad_mode
        self.time_pad = time_pad
        self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)

        stride = (stride, 1, 1)
        dilation = (dilation, 1, 1)
        # 创建一个三维卷积层
        self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride = stride, dilation = dilation, **kwargs)

    # 前向传播函数
    def forward(self, x):
        # 根据填充模式选择填充方式
        pad_mode = self.pad_mode if self.time_pad < x.shape[2] else 'constant'

        # 对输入张量进行填充
        x = F.pad(x, self.time_causal_padding, mode = pad_mode)
        return self.conv(x)
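
因果 3D 卷积的要点:只在时间维度的过去一侧做填充,因此输出的时间长度与输入一致,且每个时间步不会看到未来帧。示意(参数仅为示例):

# 用法示意
import torch

conv = CausalConv3d(8, 16, kernel_size = (3, 3, 3))
x = torch.randn(1, 8, 5, 32, 32)    # (batch, channels, time, height, width)
out = conv(x)

assert out.shape == (1, 16, 5, 32, 32)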

# 定义一个函数 ResidualUnit,用于创建残差单元
@beartype
def ResidualUnit(
    dim,
    kernel_size: Union[int, Tuple[int, int, int]],
    pad_mode: str = 'constant'
):
    # 构建残差单元网络结构
    net = Sequential(
        CausalConv3d(dim, dim, kernel_size, pad_mode = pad_mode),
        nn.ELU(),  # 使用 ELU 激活函数
        nn.Conv3d(dim, dim, 1),
        nn.ELU(),
        SqueezeExcite(dim)
    )

    return Residual(net)

# 定义一个类 ResidualUnitMod,继承自 Module 类
@beartype
class ResidualUnitMod(Module):
    # 初始化函数
    def __init__(
        self,
        dim,
        kernel_size: Union[int, Tuple[int, int, int]],
        *,
        dim_cond,
        pad_mode: str = 'constant',
        demod = True
    ):
        super().__init__()
        kernel_size = cast_tuple(kernel_size, 3)
        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
        assert height_kernel_size == width_kernel_size

        # 线性层,用于将条件信息转换为相同维度
        self.to_cond = nn.Linear(dim_cond, dim)

        # 创建一个 Conv3DMod 层
        self.conv = Conv3DMod(
            dim = dim,
            spatial_kernel = height_kernel_size,
            time_kernel = time_kernel_size,
            causal = True,
            demod = demod,
            pad_mode = pad_mode
        )

        # 创建一个 1x1x1 三维卷积层
        self.conv_out = nn.Conv3d(dim, dim, 1)

    # 前向传播函数
    @beartype
    def forward(
        self,
        x,
        cond: Tensor,
    ):
        res = x
        cond = self.to_cond(cond)

        # 进行卷积操作
        x = self.conv(x, cond = cond)
        x = F.elu(x)
        x = self.conv_out(x)
        x = F.elu(x)
        return x + res

# 定义一个类 CausalConvTranspose3d,继承自 Module 类
class CausalConvTranspose3d(Module):
    # 初始化函数,定义了一个卷积转置层
    def __init__(
        self,
        chan_in,
        chan_out,
        kernel_size: Union[int, Tuple[int, int, int]],
        *,
        time_stride,
        **kwargs
    ):
        # 调用父类的初始化函数
        super().__init__()
        # 将 kernel_size 转换为三元组
        kernel_size = cast_tuple(kernel_size, 3)

        # 分别获取时间、高度和宽度的卷积核大小
        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size

        # 断言高度卷积核大小和宽度卷积核大小为奇数
        assert is_odd(height_kernel_size) and is_odd(width_kernel_size)

        # 设置上采样因子为时间步长
        self.upsample_factor = time_stride

        # 计算高度和宽度的填充值
        height_pad = height_kernel_size // 2
        width_pad = width_kernel_size // 2

        # 设置步长和填充值
        stride = (time_stride, 1, 1)
        padding = (0, height_pad, width_pad)

        # 创建一个三维卷积转置层
        self.conv = nn.ConvTranspose3d(chan_in, chan_out, kernel_size, stride, padding = padding, **kwargs)

    # 前向传播函数
    def forward(self, x):
        # 断言输入张量 x 的维度为 5
        assert x.ndim == 5
        # 获取时间维度的大小
        t = x.shape[2]

        # 对输入张量进行卷积转置操作
        out = self.conv(x)

        # 裁剪输出张量的时间维度,保留 t * 上采样因子 个时间步
        out = out[..., :(t * self.upsample_factor), :, :]
        # 返回处理后的输出张量
        return out
# 定义了 LossBreakdown 命名元组,包含了不同损失的分解信息
LossBreakdown = namedtuple('LossBreakdown', [
    'recon_loss',
    'lfq_aux_loss',
    'quantizer_loss_breakdown',
    'perceptual_loss',
    'adversarial_gen_loss',
    'adaptive_adversarial_weight',
    'multiscale_gen_losses',
    'multiscale_gen_adaptive_weights'
])

# 定义了 DiscrLossBreakdown 命名元组,包含了鉴别器损失的分解信息
DiscrLossBreakdown = namedtuple('DiscrLossBreakdown', [
    'discr_loss',
    'multiscale_discr_losses',
    'gradient_penalty'
])

# 定义了 VideoTokenizer 类,继承自 Module 类
class VideoTokenizer(Module):
    # 初始化方法
    @beartype
    def __init__(
        self,
        *,
        image_size,
        layers: Tuple[Union[str, Tuple[str, int]], ...] = (
            'residual',
            'residual',
            'residual'
        ),
        residual_conv_kernel_size = 3,
        num_codebooks = 1,
        codebook_size: Optional[int] = None,
        channels = 3,
        init_dim = 64,
        max_dim = float('inf'),
        dim_cond = None,
        dim_cond_expansion_factor = 4.,
        input_conv_kernel_size: Tuple[int, int, int] = (7, 7, 7),
        output_conv_kernel_size: Tuple[int, int, int] = (3, 3, 3),
        pad_mode: str = 'constant',
        lfq_entropy_loss_weight = 0.1,
        lfq_commitment_loss_weight = 1.,
        lfq_diversity_gamma = 2.5,
        quantizer_aux_loss_weight = 1.,
        lfq_activation = nn.Identity(),
        use_fsq = False,
        fsq_levels: Optional[List[int]] = None,
        attn_dim_head = 32,
        attn_heads = 8,
        attn_dropout = 0.,
        linear_attn_dim_head = 8,
        linear_attn_heads = 16,
        vgg: Optional[Module] = None,
        vgg_weights: VGG16_Weights = VGG16_Weights.DEFAULT,
        perceptual_loss_weight = 1e-1,
        discr_kwargs: Optional[dict] = None,
        multiscale_discrs: Tuple[Module, ...] = tuple(),
        use_gan = True,
        adversarial_loss_weight = 1.,
        grad_penalty_loss_weight = 10.,
        multiscale_adversarial_loss_weight = 1.,
        flash_attn = True,
        separate_first_frame_encoding = False
    ):
        ...

    # 返回属性 device,返回 zero 属性的设备信息
    @property
    def device(self):
        return self.zero.device

    # 类方法,初始化并从路径加载模型
    @classmethod
    def init_and_load_from(cls, path, strict = True):
        path = Path(path)
        assert path.exists()
        pkg = torch.load(str(path), map_location = 'cpu')

        assert 'config' in pkg, 'model configs were not found in this saved checkpoint'

        config = pickle.loads(pkg['config'])
        tokenizer = cls(**config)
        tokenizer.load(path, strict = strict)
        return tokenizer

    # 返回模型参数
    def parameters(self):
        return [
            *self.conv_in.parameters(),
            *self.conv_in_first_frame.parameters(),
            *self.conv_out_first_frame.parameters(),
            *self.conv_out.parameters(),
            *self.encoder_layers.parameters(),
            *self.decoder_layers.parameters(),
            *self.encoder_cond_in.parameters(),
            *self.decoder_cond_in.parameters(),
            *self.quantizers.parameters()
        ]

    # 返回鉴别器参数
    def discr_parameters(self):
        return self.discr.parameters()

    # 复制模型用于评估
    def copy_for_eval(self):
        device = self.device
        vae_copy = copy.deepcopy(self.cpu())

        maybe_del_attr_(vae_copy, 'discr')
        maybe_del_attr_(vae_copy, 'vgg')
        maybe_del_attr_(vae_copy, 'multiscale_discrs')

        vae_copy.eval()
        return vae_copy.to(device)

    # 返回模型状态字典
    @remove_vgg
    def state_dict(self, *args, **kwargs):
        return super().state_dict(*args, **kwargs)

    # 加载模型状态字典
    @remove_vgg
    def load_state_dict(self, *args, **kwargs):
        return super().load_state_dict(*args, **kwargs)

    # 保存模型
    def save(self, path, overwrite = True):
        path = Path(path)
        assert overwrite or not path.exists(), f'{str(path)} already exists'

        pkg = dict(
            model_state_dict = self.state_dict(),
            version = __version__,
            config = self._configs
        )

        torch.save(pkg, str(path))
    # 加载模型参数
    def load(self, path, strict = True):
        # 将路径转换为 Path 对象
        path = Path(path)
        # 断言路径存在
        assert path.exists()

        # 加载模型参数
        pkg = torch.load(str(path))
        state_dict = pkg.get('model_state_dict')
        version = pkg.get('version')

        # 断言模型参数存在
        assert exists(state_dict)

        # 如果版本信息存在,则打印加载的 tokenizer 版本信息
        if exists(version):
            print(f'loading checkpointed tokenizer from version {version}')

        # 加载模型参数到当前模型
        self.load_state_dict(state_dict, strict = strict)

    # 编码视频
    @beartype
    def encode(
        self,
        video: Tensor,
        quantize = False,
        cond: Optional[Tensor] = None,
        video_contains_first_frame = True
    ):
        # 是否单独编码第一帧
        encode_first_frame_separately = self.separate_first_frame_encoding and video_contains_first_frame

        # 是否填充视频
        if video_contains_first_frame:
            video_len = video.shape[2]

            video = pad_at_dim(video, (self.time_padding, 0), value = 0., dim = 2)
            video_packed_shape = [torch.Size([self.time_padding]), torch.Size([]), torch.Size([video_len - 1])]

        # 条件编码
        assert (not self.has_cond) or exists(cond), '`cond` must be passed into tokenizer forward method since conditionable layers were specified'

        if exists(cond):
            assert cond.shape == (video.shape[0], self.dim_cond)

            cond = self.encoder_cond_in(cond)
            cond_kwargs = dict(cond = cond)

        # 初始卷积
        if encode_first_frame_separately:
            pad, first_frame, video = unpack(video, video_packed_shape, 'b c * h w')
            first_frame = self.conv_in_first_frame(first_frame)

        video = self.conv_in(video)

        if encode_first_frame_separately:
            video, _ = pack([first_frame, video], 'b c * h w')
            video = pad_at_dim(video, (self.time_padding, 0), dim = 2)

        # 编码器层
        for fn, has_cond in zip(self.encoder_layers, self.has_cond_across_layers):

            layer_kwargs = dict()

            if has_cond:
                layer_kwargs = cond_kwargs

            video = fn(video, **layer_kwargs)

        maybe_quantize = identity if not quantize else self.quantizers

        return maybe_quantize(video)

    # 从编码索引解码
    @beartype
    def decode_from_code_indices(
        self,
        codes: Tensor,
        cond: Optional[Tensor] = None,
        video_contains_first_frame = True
    ):
        assert codes.dtype in (torch.long, torch.int32)

        if codes.ndim == 2:
            video_code_len = codes.shape[-1]
            assert divisible_by(video_code_len, self.fmap_size ** 2), f'flattened video ids must have a length ({video_code_len}) that is divisible by the fmap size ({self.fmap_size}) squared ({self.fmap_size ** 2})'

            codes = rearrange(codes, 'b (f h w) -> b f h w', h = self.fmap_size, w = self.fmap_size)

        quantized = self.quantizers.indices_to_codes(codes)

        return self.decode(quantized, cond = cond, video_contains_first_frame = video_contains_first_frame)

    # 解码
    @beartype
    def decode(
        self,
        quantized: Tensor,
        cond: Optional[Tensor] = None,
        video_contains_first_frame = True
        ):
        # 检查是否需要单独解码第一帧
        decode_first_frame_separately = self.separate_first_frame_encoding and video_contains_first_frame

        # 获取批量大小
        batch = quantized.shape[0]

        # 条件输入,如果需要的话
        assert (not self.has_cond) or exists(cond), '`cond` must be passed into tokenizer forward method since conditionable layers were specified'

        if exists(cond):
            assert cond.shape == (batch, self.dim_cond)

            # 将条件输入传入条件编码器
            cond = self.decoder_cond_in(cond)
            cond_kwargs = dict(cond = cond)

        # 解码器层

        x = quantized

        for fn, has_cond in zip(self.decoder_layers, reversed(self.has_cond_across_layers)):

            layer_kwargs = dict()

            if has_cond:
                layer_kwargs = cond_kwargs

            # 逐层解码
            x = fn(x, **layer_kwargs)

        # 转换为像素

        if decode_first_frame_separately:
            left_pad, xff, x = x[:, :, :self.time_padding], x[:, :, self.time_padding], x[:, :, (self.time_padding + 1):]

            # 对输出进行卷积
            out = self.conv_out(x)
            outff = self.conv_out_first_frame(xff)

            # 将第一帧和其余帧打包
            video, _ = pack([outff, out], 'b c * h w')

        else:
            # 对输出进行卷积
            video = self.conv_out(x)

            # 如果视频包含第一帧,则移除填充
            if video_contains_first_frame:
                video = video[:, :, self.time_padding:]

        return video

    @torch.no_grad()
    def tokenize(self, video):
        # 设置为评估模式
        self.eval()
        return self.forward(video, return_codes = True)

    @beartype
    def forward(
        self,
        video_or_images: Tensor,
        cond: Optional[Tensor] = None,
        return_loss = False,
        return_codes = False,
        return_recon = False,
        return_discr_loss = False,
        return_recon_loss_only = False,
        apply_gradient_penalty = True,
        video_contains_first_frame = True,
        adversarial_loss_weight = None,
        multiscale_adversarial_loss_weight = None
    ):
        ...

# 主要类定义

class MagViT2(Module):
    # 初始化方法
    def __init__(self):
        # 调用父类的初始化方法
        super().__init__()

    # 前向传播方法
    def forward(self, x):
        # 返回输入数据 x,即不做任何处理
        return x