Gram矩阵(pytorch)
Gram矩阵在风格迁移中被广泛使用,用于计算风格损失。
def gram_matrix(y):
    """Return the normalized Gram matrix of a batch of feature maps.

    Used to compute the style loss in style transfer.

    Args:
        y: 4-D tensor of shape (b, c, h, w).

    Returns:
        Tensor of shape (b, c, c). Entry [n, i, j] is the dot product of the
        flattened channels i and j of sample n, divided by c * h * w.
    """
    b, c, h, w = y.size()
    # reshape (not view) so non-contiguous inputs, e.g. transposed or sliced
    # feature maps, are accepted; it still returns a view whenever possible.
    features = y.reshape(b, c, h * w)
    # Swap the channel axis with the flattened spatial axis: (b, h*w, c).
    features_t = features.transpose(1, 2)
    # Batched matrix product: (b, c, h*w) @ (b, h*w, c) -> (b, c, c).
    gram = features.bmm(features_t) / (c * h * w)
    return gram