pytorch如何批量reshape,如何每batch_size进行reshape
假设我有一个tensor,它的batch_size是2:
tensor = torch.randn([2, 6])
print(tensor.shape)
输出是
torch.Size([2, 6])
其中tensor.shape[0]
代表tensor的batch_size
如果我要把其中每个Batch的数据从6转换成[2,3],怎么写?循环遍历tensor然后循环内用reshape吗?不!
看下面的操作,很简单:
tensor = torch.randn([2, 6])
print(tensor)
tensor = tensor.reshape(tensor.shape[0], 2, 3) # 将每个批次的数据转换成2,3的形状
print(tensor)
tensor = tensor.reshape(tensor.shape[0], 6) # 恢复原来的形状
print(tensor)
输出是:
tensor([[-0.7920, -0.7887, -0.7362, 0.2238, 0.3442, 1.5486],
[ 1.7589, -0.3414, 0.4499, -0.0228, 0.4032, 0.3730]])
tensor([[[-0.7920, -0.7887, -0.7362],
[ 0.2238, 0.3442, 1.5486]],
[[ 1.7589, -0.3414, 0.4499],
[-0.0228, 0.4032, 0.3730]]])
tensor([[-0.7920, -0.7887, -0.7362, 0.2238, 0.3442, 1.5486],
[ 1.7589, -0.3414, 0.4499, -0.0228, 0.4032, 0.3730]])
Process finished with exit code 0
但是要注意!需要改变形状的tensor里面的东西要符合要求!数量不够会报错!
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 分享一个免费、快速、无限量使用的满血 DeepSeek R1 模型,支持深度思考和联网搜索!
· 基于 Docker 搭建 FRP 内网穿透开源项目(很简单哒)
· ollama系列01:轻松3步本地部署deepseek,普通电脑可用
· 25岁的心里话
· 按钮权限的设计及实现