torchvision.transforms.ToTensor(),torchvision.trasnsforms.Normalize()
用PyTorch进行神经网络训练时,如果训练用的数据是图像数据,则需要在训练之前对图像进行预处理。以MNIST数据为例:
1 2 3 4 5 6 7 | train_data = torchvision.datasets.MNIST( root = './mnist/' , train = True , transform = torchvision.transforms.ToTensor(), # Converts a PIL.Image or numpy.ndarray to # torch.FloatTensor of shape (C x H x W) and normalize in the range [0.0, 1.0] download = True , ) |
transform=torchvision.transforms.ToTensor()起到的作用是把PIL.Image或者numpy.narray数据类型转变为torch.FloatTensor类型,shape是C*H*W,数值范围缩小为[0.0, 1.0]。
如果想把数值范围调整为[-1.0, 1.0],则可加torchvision.transforms.Normalize([mean_channel1,mean_channel2,mean_channel3], [std_channel1,std_channel2,std_channel3]),如果是黑白图像,比如MNIST里的图像,只有一个通道,则mean只需要一个,std也只需要一个。
1 2 3 4 5 6 7 8 9 | im_tfs = torchvision.trasnsforms.Compose([ torchvision.trasnsforms.ToTensor(), torchvision.trasnsforms.Normalize([ 0.5 ], [ 0.5 ]) ])train_data = torchvision.datasets.MNIST( root = './mnist/' , train = True , transform = torchvision.transforms.ToTensor(), download = True , ) |
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 基于Microsoft.Extensions.AI核心库实现RAG应用
· Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· Linux系列:如何用 C#调用 C方法造成内存泄露
· Manus爆火,是硬核还是营销?
· 终于写完轮子一部分:tcp代理 了,记录一下
· 别再用vector<bool>了!Google高级工程师:这可能是STL最大的设计失误
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 单元测试从入门到精通