矩阵维度变换--einops库
import einops # 创建一个形状为(batch_size, seq_length, hidden_dim)的张量 tensor = tf.constant([[[1, 2, 3, 4], [5, 6, 7, 8]], [[9, 10, 11, 12], [13, 14, 15, 16]]]) # 使用einops进行维度交换和重塑 reshaped = einops.rearrange(tensor, 'batch seq dim -> batch dim seq') print(reshaped.shape)