[cnn][julia]Flux实现卷积神经网络cnn预测手写MNIST

julia_Flux

1.导入Flux.jl和其他所需工具包

using Flux, MLDatasets, Statistics
using Flux: onehotbatch, onecold, logitcrossentropy, params
using MLDatasets: MNIST
using Base.Iterators: partition
using Printf, BSON
using Images
using CUDA,LinearAlgebra,Random
CUDA.allowscalar(false)
"""
    TrainArgs(; lr=3e-3, epochs=20, batch_size=128)

Mutable container holding default training hyperparameters. Every field can be
overridden by keyword at construction time and reassigned afterwards.

# Fields
- `lr::Float64`: learning rate.
- `epochs::Int`: number of training epochs.
- `batch_size::Int`: number of samples per mini-batch.
"""
Base.@kwdef mutable struct TrainArgs
    lr::Float64 = 3e-3
    epochs::Int = 20
    # Typed like its siblings: an untyped field is stored as `Any`, which is
    # type-unstable and silently accepts non-integer batch sizes.
    batch_size::Int = 128
end
TrainArgs

训练集测试集

MNIST 60000训练集 10000测试集

# Training split of MNIST (features and targets are accessed as fields below).
train_data = MNIST(split=:train)
dataset MNIST:
  metadata  =>    Dict{String, Any} with 3 entries
  split     =>    :train
  features  =>    28×28×60000 Array{Float32, 3}
  targets   =>    60000-element Vector{Int64}
# Test split of MNIST, used as the default data set for evaluation.
test_data = MNIST(split=:test)
dataset MNIST:
  metadata  =>    Dict{String, Any} with 3 entries
  split     =>    :test
  features  =>    28×28×10000 Array{Float32, 3}
  targets   =>    10000-element Vector{Int64}
train_data.features
28×28×60000 Array{Float32, 3}:
[:, :, 1] =
 0.0  0.0  0.0  0.0  0.0  0.0       …  0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0          0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0          0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0          0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0          0.215686  0.533333  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0       …  0.67451   0.992157  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0          0.886275  0.992157  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0          0.992157  0.992157  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0          0.992157  0.831373  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0          0.992157  0.529412  0.0  0.0  0.0
 ⋮                        ⋮         ⋱                      ⋮         
 0.0  0.0  0.0  0.0  0.0  0.101961     0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.65098   …  0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  1.0          0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.968627     0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.498039     0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0          0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0       …  0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0          0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0          0.0       0.0       0.0  0.0  0.0

[:, :, 2] =
 0.0  0.0  0.0  0.0  0.0       0.0       …  0.0        0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0        0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0        0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0        0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0        0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0       …  0.0        0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0        0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0980392  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.501961   0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.988235   0.0  0.0  0.0  0.0
 ⋮                             ⋮         ⋱                  ⋮         
 0.0  0.0  0.0  0.0  0.196078  0.929412     0.0        0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0       …  0.0        0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0        0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0        0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0        0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0        0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0       …  0.0        0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0        0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0        0.0  0.0  0.0  0.0

[:, :, 3] =
 0.0  0.0  0.0  0.0  0.0  0.0       0.0       …  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0       0.0          0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0       0.0          0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0       0.0          0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0       0.243137     0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0       0.317647  …  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0       0.0          0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0       0.0          0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0       0.0          0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0       0.0          0.0  0.0  0.0  0.0  0.0  0.0
 ⋮                        ⋮                   ⋱                 ⋮         
 0.0  0.0  0.0  0.0  0.0  0.0       0.0          0.6  0.6  0.6  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.262745  0.470588  …  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.909804  0.705882     0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.152941  0.152941     0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0       0.0          0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0       0.0          0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0       0.0       …  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0       0.0          0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0       0.0          0.0  0.0  0.0  0.0  0.0  0.0

;;; … 

[:, :, 59998] =
 0.0  0.0  0.0  0.0  0.0  0.0       …  0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0          0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0          0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0          0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0          0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0       …  0.45098   0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0          0.941176  0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0          0.988235  0.615686  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0          0.639216  0.992157  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0          0.576471  0.992157  0.0  0.0  0.0
 ⋮                        ⋮         ⋱                      ⋮         
 0.0  0.0  0.0  0.0  0.0  0.376471     0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.47451   …  0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.835294     0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  1.0          0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  1.0          0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.47451      0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0       …  0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0          0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0          0.0       0.0       0.0  0.0  0.0

[:, :, 59999] =
 0.0  0.0  0.0       0.0       0.0        …  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0       0.0       0.0           0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0       0.0       0.0           0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0       0.0       0.0           0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0       0.0       0.0           0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0       0.0       0.0        …  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0       0.0       0.0           0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0       0.0       0.0           0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0       0.0       0.0           0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0       0.0       0.0           0.0  0.0  0.0  0.0  0.0  0.0
 ⋮                                        ⋱                 ⋮         
 0.0  0.0  0.752941  0.988235  0.745098      0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.901961  0.756863  0.0352941  …  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.105882  0.105882  0.0           0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0       0.0       0.0           0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0       0.0       0.0           0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0       0.0       0.0           0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0       0.0       0.0        …  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0       0.0       0.0           0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0       0.0       0.0           0.0  0.0  0.0  0.0  0.0  0.0

[:, :, 60000] =
 0.0  0.0  0.0  0.0  0.0  0.0  …  0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  …  0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0     0.101961  0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0     0.898039  0.286275  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0     0.976471  0.756863  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0     0.690196  0.772549  0.0  0.0  0.0
 ⋮                        ⋮    ⋱                      ⋮         
 0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  …  0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  …  0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0  0.0  0.0
# Insert a singleton third (channel) dimension: 28×28×N becomes 28×28×1×N.
X_train = reshape(train_data.features, 28,28,1,:)
size(X_train)
# First sample, kept as a 28×28×1 array.
pic_1 = X_train[:,:,:,1]
28×28×1 Array{Float32, 3}:
[:, :, 1] =
 0.0  0.0  0.0  0.0  0.0  0.0       …  0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0          0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0          0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0          0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0          0.215686  0.533333  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0       …  0.67451   0.992157  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0          0.886275  0.992157  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0          0.992157  0.992157  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0          0.992157  0.831373  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0          0.992157  0.529412  0.0  0.0  0.0
 ⋮                        ⋮         ⋱                      ⋮         
 0.0  0.0  0.0  0.0  0.0  0.101961     0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.65098   …  0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  1.0          0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.968627     0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.498039     0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0          0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0       …  0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0          0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0          0.0       0.0       0.0  0.0  0.0
# Index with two colons to drop the trailing singleton dimension (28×28×1 -> 28×28).
img = pic_1[:,:]
28×28 Matrix{Float32}:
 0.0  0.0  0.0  0.0  0.0  0.0       …  0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0          0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0          0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0          0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0          0.215686  0.533333  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0       …  0.67451   0.992157  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0          0.886275  0.992157  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0          0.992157  0.992157  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0          0.992157  0.831373  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0          0.992157  0.529412  0.0  0.0  0.0
 ⋮                        ⋮         ⋱                      ⋮         
 0.0  0.0  0.0  0.0  0.0  0.101961     0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.65098   …  0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  1.0          0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.968627     0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.498039     0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0          0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0       …  0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0          0.0       0.0       0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0          0.0       0.0       0.0  0.0  0.0
# Render the matrix as a grayscale image. `img'` transposes it first —
# presumably because the stored axes are swapped relative to the display
# orientation (TODO confirm against the MLDatasets docs).
colorview(Gray,img')
train_data.targets
60000-element Vector{Int64}:
 5
 0
 4
 1
 9
 2
 1
 3
 1
 4
 ⋮
 2
 9
 5
 1
 8
 3
 5
 6
 8
Flux.onehotbatch(train_data.targets, 0:9)
10×60000 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
 ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  …  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  1  ⋅  ⋅  1  ⋅  1  ⋅  ⋅  ⋅  ⋅     ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅     ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  1  ⋅  1     ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅
 ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅     ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  …  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  1  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅     ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅     1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅     ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  1
 ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅     ⋅  ⋅  1  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
"""
    loader(data::MNIST=train_data; batchsize::Int=64)

Build a shuffled mini-batch iterator over `data` and pass it through `gpu`.

Features are reshaped to 28×28×1×N and targets are one-hot encoded over the
labels `0:9`; each iteration yields an `(images, labels)` tuple of `batchsize`
samples.
"""
function loader(data::MNIST=train_data; batchsize::Int=64)
    images = reshape(data.features, 28, 28, 1, :)
    labels = Flux.onehotbatch(data.targets, 0:9)
    batches = Flux.DataLoader((images, labels); batchsize=batchsize, shuffle=true)
    return gpu(batches)
end

loader()

# Grab the first mini-batch (default batchsize of 64) from a freshly shuffled loader.
x1, y1 = first(loader()) 
x1
28×28×1×64 CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}:
[:, :, 1, 1] =
 0.0  0.0  0.0  0.0  0.0  0.0  0.0       …  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0          0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0          0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0          0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0          0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0       …  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0          0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0          0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0          0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.247059     0.0  0.0  0.0  0.0  0.0  0.0
 ⋮                        ⋮              ⋱                 ⋮         
 0.0  0.0  0.0  0.0  0.0  0.0  0.0          0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0       …  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0          0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0          0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0          0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0          0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0       …  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0          0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0          0.0  0.0  0.0  0.0  0.0  0.0

[:, :, 1, 2] =
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 ⋮                        ⋮         ⋱                      ⋮         
 0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0

[:, :, 1, 3] =
 0.0  0.0  0.0  0.0  0.0       0.0       …  0.0        0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0        0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0        0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0        0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0        0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0       …  0.0509804  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.435294   0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.627451   0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.533333   0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.270588   0.0  0.0  0.0  0.0
 ⋮                             ⋮         ⋱                  ⋮         
 0.0  0.0  0.0  0.0  0.27451   0.992157     0.0745098  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.101961  0.937255  …  0.0        0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.478431     0.0        0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0        0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0        0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0        0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0       …  0.0        0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0        0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0        0.0  0.0  0.0  0.0

;;;; … 

[:, :, 1, 62] =
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0
 ⋮                        ⋮                   ⋱                 ⋮         
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0

[:, :, 1, 63] =
 0.0  0.0  0.0  0.0  0.0  0.0  0.0       …  0.0       0.0       0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0          0.0       0.0       0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0          0.0       0.0       0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0          0.0       0.0       0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0          0.0       0.0       0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0       …  0.0       0.0       0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0          0.0       0.0       0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0          0.396078  0.0       0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0          0.976471  0.882353  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0          0.988235  0.988235  0.0  0.0
 ⋮                        ⋮              ⋱            ⋮              
 0.0  0.0  0.0  0.0  0.0  0.0  0.345098     0.0       0.0       0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0       …  0.0       0.0       0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0          0.0       0.0       0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0          0.0       0.0       0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.886275     0.0       0.0       0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.466667     0.0       0.0       0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.662745  …  0.0       0.0       0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.439216     0.0       0.0       0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0          0.0       0.0       0.0  0.0

[:, :, 1, 64] =
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 ⋮                        ⋮              ⋱                      ⋮         
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
y1
10×64 OneHotMatrix(::CuArray{UInt32, 1, CUDA.Mem.DeviceBuffer}) with eltype Bool:
 ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  1  …  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅     ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅
 ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅     1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅     ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅     ⋅  ⋅  ⋅  1  ⋅  ⋅  1  1  ⋅  ⋅  ⋅  1
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  …  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 1  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  1  ⋅  ⋅  ⋅     ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  1  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅     ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  1  1  ⋅  ⋅  ⋅  ⋅     ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅
 ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅     ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅

卷积神经网络模型 LeNet,通常用于手写数字识别等任务。让我们来看一下每一层的作用:

  • 第一个卷积层:使用 Conv((5, 5), 1=>6, relu),对输入图像进行 6 次滤波操作,每个过滤器都是 5x5 大小,生成 6 个输出通道。ReLU 激活函数将输出非线性化;
  • 第一个最大池化层:使用 MaxPool((2, 2)),对输入的单通道或多通道特征图进行 2x2 的最大池化操作,从每个 2x2 的窗口中选出最大值,减少特征图的空间大小和计算复杂度;
  • 第二个卷积层:使用 Conv((5, 5), 6=>16, relu),输入为第一层的 6 个输出通道,经过 16 次 5x5 的卷积得到 16 个输出通道,ReLU 对其进行非线性化处理;
  • 第二个最大池化层:使用 MaxPool((2, 2)),同样对输入的 16 通道特征图进行 2x2 的最大池化操作,减小特征图大小;
  • Flatten 层:使用 Flux.flatten,将大小为 4×4×16 的特征图展平为长度为 256 的一维向量;
  • 全连接层 1:使用 Dense(256 => 120, relu),输入为展平后的 256 维向量,输出为大小为 120 的特征向量,ReLU 对其进行非线性化处理;
  • 全连接层 2:使用 Dense(120 => 84, relu),输入为大小为 120 的特征向量,输出为大小为 84 的特征向量,ReLU 对其进行非线性化处理;
  • 输出层:使用 Dense(84 => 10),输入为大小为 84 的特征向量,输出为大小为 10 的得分向量(logits)。模型中没有包含 softmax,因此这些得分还不是概率;训练时使用 logitcrossentropy 直接处理得分,预测时取得分最大的类别。

28×28×1 > 24×24×6 > 12×12×6 > 8×8×16 > 4×4×16 > 256

# LeNet-style CNN for 28×28×1 inputs; the whole chain is passed through `gpu`.
# The last layer has no activation, so the model outputs raw scores (logits).
model = gpu(
    Chain(
        Conv((5, 5), 1 => 6, relu),    # 28×28×1  -> 24×24×6
        MaxPool((2, 2)),               # 24×24×6  -> 12×12×6
        Conv((5, 5), 6 => 16, relu),   # 12×12×6  -> 8×8×16
        MaxPool((2, 2)),               # 8×8×16   -> 4×4×16
        Flux.flatten,                  # 4×4×16   -> 256
        Dense(256 => 120, relu),
        Dense(120 => 84, relu),
        Dense(84 => 10),               # one score per digit class
    )
)


Chain(
  Conv((5, 5), 1 => 6, relu),           # 156 parameters
  MaxPool((2, 2)),
  Conv((5, 5), 6 => 16, relu),          # 2_416 parameters
  MaxPool((2, 2)),
  Flux.flatten,
  Dense(256 => 120, relu),              # 30_840 parameters
  Dense(120 => 84, relu),               # 10_164 parameters
  Dense(84 => 10),                      # 850 parameters
)                   # Total: 10 arrays, 44_426 parameters, 2.086 KiB.
# Run the untrained model on x1 first to see what the predictions look like.
y1hat = model(x1) 
# Rows: the 10 outputs of the model.
# Columns: the 64 samples of one batch.
10×64 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
  0.0216377  -0.157416    -0.0696039  …  -0.113559     -0.169696
 -0.0367052  -0.022017     0.0686541     -0.000658836   0.0592685
  0.0315908   0.073026    -0.0919517     -0.0104487    -0.126225
 -0.127688   -0.101221    -0.0616018     -0.0647386    -0.02525
  0.0745182   0.00518273   0.106297       0.1257        0.0266679
  0.0660713  -0.018785    -0.0910264  …  -0.0812425    -0.126134
 -0.395019   -0.343077    -0.303919      -0.334892     -0.37829
 -0.0184298  -0.0638247   -0.0148924     -0.0276188    -0.0715621
  0.219471    0.22844      0.207235       0.126554      0.202388
 -0.172887   -0.0813897   -0.126973      -0.220921     -0.268266
# Column 1: predicted digit, column 2: true digit; the table is moved to the CPU.
y_hat = cpu(hcat(Flux.onecold(y1hat, 0:9), Flux.onecold(y1, 0:9)))
64×2 Matrix{Int64}:
 8  6
 8  9
 8  2
 8  0
 8  3
 8  8
 8  6
 8  8
 8  8
 8  6
 ⋮  
 4  4
 8  7
 8  6
 8  4
 8  4
 1  1
 1  6
 8  8
 8  4
size(y_hat)[1]
64
# Append a zero-filled third column that will hold the correctness flag.
result = hcat(y_hat, zeros(size(y_hat, 1), 1))
64×3 Matrix{Float64}:
 8.0  6.0  0.0
 8.0  9.0  0.0
 8.0  2.0  0.0
 8.0  0.0  0.0
 8.0  3.0  0.0
 8.0  8.0  0.0
 8.0  6.0  0.0
 8.0  8.0  0.0
 8.0  8.0  0.0
 8.0  6.0  0.0
 ⋮
 4.0  4.0  0.0
 8.0  7.0  0.0
 8.0  6.0  0.0
 8.0  4.0  0.0
 8.0  4.0  0.0
 1.0  1.0  0.0
 1.0  6.0  0.0
 8.0  8.0  0.0
 8.0  4.0  0.0

# Set the third column to 1 on every row where the prediction (column 1)
# equals the true label (column 2); all other rows are left untouched.
result[y_hat[:, 1] .== y_hat[:, 2], 3] .= 1
result
64×3 Matrix{Float64}:
 8.0  6.0  0.0
 8.0  9.0  0.0
 8.0  2.0  0.0
 8.0  0.0  0.0
 8.0  3.0  0.0
 8.0  8.0  1.0
 8.0  6.0  0.0
 8.0  8.0  1.0
 8.0  8.0  1.0
 8.0  6.0  0.0
 ⋮
 4.0  4.0  1.0
 8.0  7.0  0.0
 8.0  6.0  0.0
 8.0  4.0  0.0
 8.0  4.0  0.0
 1.0  1.0  1.0
 1.0  6.0  0.0
 8.0  8.0  1.0
 8.0  4.0  0.0
# Reassemble the three columns into one matrix for display.
check_display = [result[:,1] result[:,2] result[:,3]]
# Columns: predicted value | true value | whether the prediction is correct
# NOTE(review): `vscodedisplay` is not defined in this file — presumably the
# table viewer provided by the VS Code Julia environment; confirm before
# running elsewhere.
vscodedisplay(check_display)
using Statistics: mean

"""
    loss_and_accuracy(model, data::MNIST=test_data)

Evaluate `model` on all of `data` in one single batch.

Returns a NamedTuple `(; loss, acc, split)` where `loss` is the logit
cross-entropy, `acc` is the percentage of samples whose predicted class matches
the target (rounded to two digits), and `split` is `data.split`.
"""
function loss_and_accuracy(model, data::MNIST=test_data)
    # A batch size equal to the data set length yields exactly one batch.
    full_batch = loader(data; batchsize=length(data))
    inputs, targets = only(full_batch)
    logits = model(inputs)
    # The model has no softmax layer, hence the logit variant of the loss.
    loss = Flux.logitcrossentropy(logits, targets)
    hits = Flux.onecold(logits) .== Flux.onecold(targets)
    acc = round(100 * mean(hits); digits=2)
    return (loss=loss, acc=acc, split=data.split)
end
loss_and_accuracy (generic function with 2 methods)
@show loss_and_accuracy(model); 
loss_and_accuracy(model) = (loss = 2.3237605f0, acc = 11.78, split = :test)

记录执行一个使用LeNet神经网络在数据集上进行分类任务的训练过程,并记录每个 epoch 的训练损失、准确率和测试准确率。具体解释如下:

  • settings 为一个命名元组(named tuple),包含了该模型的相关设置,例如学习率 eta、权重衰减 lambda、批量大小 batchsize 等。
  • train_log 为空数组,后面将用它来存储每个 epoch 的日志信息。
  • opt_rule 使用 OptimiserChain 将权重衰减规则(WeightDecay)与 Adam 优化器串联成一个优化规则,用于更新神经网络的参数。
  • opt_state 利用 Flux.setup 函数根据模型的初始权重设置好优化器的状态变量。
  • 进入 for 循环,分别对于每个 epoch 训练,这里用 @time 宏可以显示程序执行时间。循环采用 loader() 函数在 batch 中随机加载数据。grads 包含了所有参数的梯度信息,通过 Flux.update! 函数将梯度传递给优化器进行参数更新。
  • 当 epoch 为奇数时,计算当前模型在训练集上的损失和准确率(loss、acc)以及在测试集上的损失和准确率(test_loss、test_acc),并将该 epoch 的信息作为一个命名元组 nt 保存到 train_log 数组中。
  • 每隔 5 个 epoch 将模型状态保存到文件 mymodel;所有 epoch 完成后,训练过程结束。
  • 实现了一个基本的分类任务模型,并对其进行训练。在训练过程中,每个 epoch 的信息都被记录和储存,以便于后续的统计和分析。
# Hyperparameters read by the optimiser setup and the training loop.
settings = (
    eta = 0.001,      # learning rate
    # Weight-decay coefficient. Weight decay is a standard regularisation
    # technique: it puts an extra constraint on the network weights so they
    # drift towards smaller values during training, which helps prevent
    # overfitting and improves generalisation.
    lambda = 1e-2,
    batchsize = 128,  # samples per mini-batch
    epochs = 30,      # number of passes over the training data
)
# Collects one NamedTuple of metrics per logged epoch (appended with `push!`
# in the training loop). Untyped, so its element type is `Any`.
train_log = []
Any[]
# Optimiser rule: weight decay (coefficient `settings.lambda`) chained with
# Adam (learning rate `settings.eta`).
opt_rule = OptimiserChain(WeightDecay(settings.lambda), Adam(settings.eta))
# `Flux.setup` takes the optimiser rule and the model and builds the optimiser
# state that `Flux.update!` consumes during training.
opt_state = Flux.setup(opt_rule, model);
using JLD2

# Train for `settings.epochs` epochs. Metrics are logged on odd epochs and the
# model state is checkpointed every fifth epoch.
for epoch in 1:settings.epochs
    # One pass over the shuffled training batches; `@time` reports its duration.
    @time for (x, y) in loader(batchsize=settings.batchsize)
        # Gradient of the logit cross-entropy with respect to the model.
        grads = Flux.gradient(model) do m
            Flux.logitcrossentropy(m(x), y)
        end
        # Apply the optimiser step to the model parameters.
        Flux.update!(opt_state, model, first(grads))
    end

    if isodd(epoch)
        # Evaluate on both splits and append the metrics to `train_log`.
        loss, acc, _ = loss_and_accuracy(model, train_data)
        test_loss, test_acc, _ = loss_and_accuracy(model, test_data)
        @info "logging:" epoch acc test_acc
        push!(train_log, (; epoch, loss, acc, test_loss, test_acc))
    end

    if iszero(epoch % 5)
        # Save a CPU copy of the model state.
        JLD2.jldsave("mymodel"; model_state = cpu(Flux.state(model)))
        println("saved to ", "mymodel", " after ", epoch, " epochs")
    end
end


 23.558565 seconds (34.64 M allocations: 2.349 GiB, 3.66% gc time, 75.51% compilation time)


 Info: logging:
   epoch = 1
   acc = 96.07
   test_acc = 96.28
 @ Main c:\Users\HP\Desktop\模式识别\julia_flux\cnnn.ipynb:15


  1.207501 seconds (1.75 M allocations: 333.873 MiB, 7.44% gc time)
  1.046763 seconds (1.75 M allocations: 333.984 MiB, 5.96% gc time)


 Info: logging:
   epoch = 3
   acc = 96.86
   test_acc = 97.23
 @ Main c:\Users\HP\Desktop\模式识别\julia_flux\cnnn.ipynb:15


  1.056170 seconds (1.75 M allocations: 333.873 MiB, 6.33% gc time)
  1.022417 seconds (1.74 M allocations: 333.682 MiB, 4.46% gc time)


 Info: logging:
   epoch = 5
   acc = 97.2
   test_acc = 97.39
 @ Main c:\Users\HP\Desktop\模式识别\julia_flux\cnnn.ipynb:15


saved to mymodel after 5 epochs
  1.310450 seconds (1.76 M allocations: 334.547 MiB, 6.40% gc time)
  1.046871 seconds (1.75 M allocations: 333.978 MiB, 5.59% gc time)


 Info: logging:
   epoch = 7
   acc = 97.4
   test_acc = 97.68
 @ Main c:\Users\HP\Desktop\模式识别\julia_flux\cnnn.ipynb:15


  1.039215 seconds (1.75 M allocations: 333.876 MiB, 6.44% gc time)
  1.010181 seconds (1.74 M allocations: 333.688 MiB, 4.35% gc time)


 Info: logging:
   epoch = 9
   acc = 97.51
   test_acc = 97.76
 @ Main c:\Users\HP\Desktop\模式识别\julia_flux\cnnn.ipynb:15


  1.054506 seconds (1.75 M allocations: 333.883 MiB, 6.38% gc time)
saved to mymodel after 10 epochs
  1.010604 seconds (1.74 M allocations: 333.673 MiB, 4.28% gc time)


 Info: logging:
   epoch = 11
   acc = 97.32
   test_acc = 97.37
 @ Main c:\Users\HP\Desktop\模式识别\julia_flux\cnnn.ipynb:15


  1.073143 seconds (1.75 M allocations: 333.890 MiB, 6.37% gc time)
  1.021915 seconds (1.74 M allocations: 333.682 MiB, 4.27% gc time)


 Info: logging:
   epoch = 13
   acc = 97.64
   test_acc = 98.03
 @ Main c:\Users\HP\Desktop\模式识别\julia_flux\cnnn.ipynb:15


  1.062875 seconds (1.75 M allocations: 333.876 MiB, 6.31% gc time)
  1.035603 seconds (1.74 M allocations: 333.675 MiB, 4.21% gc time)
saved to mymodel after 15 epochs


 Info: logging:
   epoch = 15
   acc = 97.59
   test_acc = 97.91
 @ Main c:\Users\HP\Desktop\模式识别\julia_flux\cnnn.ipynb:15


  1.053837 seconds (1.75 M allocations: 333.882 MiB, 6.38% gc time)
  1.025212 seconds (1.74 M allocations: 333.676 MiB, 4.17% gc time)


 Info: logging:
   epoch = 17
   acc = 97.46
   test_acc = 97.85
 @ Main c:\Users\HP\Desktop\模式识别\julia_flux\cnnn.ipynb:15


  1.079404 seconds (1.75 M allocations: 333.880 MiB, 6.11% gc time)
  1.034417 seconds (1.74 M allocations: 333.682 MiB, 4.37% gc time)


 Info: logging:
   epoch = 19
   acc = 97.44
   test_acc = 97.66
 @ Main c:\Users\HP\Desktop\模式识别\julia_flux\cnnn.ipynb:15


  1.059716 seconds (1.75 M allocations: 333.880 MiB, 6.29% gc time)
saved to mymodel after 20 epochs
  1.056283 seconds (1.74 M allocations: 333.680 MiB, 4.25% gc time)


 Info: logging:
   epoch = 21
   acc = 97.74
   test_acc = 97.94
 @ Main c:\Users\HP\Desktop\模式识别\julia_flux\cnnn.ipynb:15


  1.089269 seconds (1.75 M allocations: 333.883 MiB, 6.26% gc time)
  1.012205 seconds (1.74 M allocations: 333.676 MiB, 4.36% gc time)


 Info: logging:
   epoch = 23
   acc = 97.55
   test_acc = 97.69
 @ Main c:\Users\HP\Desktop\模式识别\julia_flux\cnnn.ipynb:15


  1.059416 seconds (1.75 M allocations: 333.876 MiB, 6.33% gc time)
  1.026613 seconds (1.74 M allocations: 333.679 MiB, 4.77% gc time)
saved to mymodel after 25 epochs


 Info: logging:
   epoch = 25
   acc = 97.31
   test_acc = 97.63
 @ Main c:\Users\HP\Desktop\模式识别\julia_flux\cnnn.ipynb:15


  1.058952 seconds (1.75 M allocations: 333.879 MiB, 6.39% gc time)
  1.003798 seconds (1.74 M allocations: 333.676 MiB, 3.89% gc time)


 Info: logging:
   epoch = 27
   acc = 97.19
   test_acc = 97.44
 @ Main c:\Users\HP\Desktop\模式识别\julia_flux\cnnn.ipynb:15


  1.065478 seconds (1.75 M allocations: 333.880 MiB, 6.31% gc time)
  1.011256 seconds (1.74 M allocations: 333.679 MiB, 4.28% gc time)


 Info: logging:
   epoch = 29
   acc = 97.52
   test_acc = 97.73
 @ Main c:\Users\HP\Desktop\模式识别\julia_flux\cnnn.ipynb:15


  1.054528 seconds (1.75 M allocations: 333.890 MiB, 5.99% gc time)
saved to mymodel after 30 epochs
# Print the collected `train_log` entries; the trailing `;` suppresses the cell's own result display.
@show train_log;
train_log = Any[(epoch = 1, loss = 0.13786168f0, acc = 96.07, test_loss = 0.12124871f0, test_acc = 96.28), (epoch = 3, loss = 0.10716068f0, acc = 96.86, test_loss = 0.0937544f0, test_acc = 97.23), (epoch = 5, loss = 0.10237007f0, acc = 97.2, test_loss = 0.09357353f0, test_acc = 97.39), (epoch = 7, loss = 0.09652825f0, acc = 97.4, test_loss = 0.08653756f0, test_acc = 97.68), (epoch = 9, loss = 0.08729916f0, acc = 97.51, test_loss = 0.08031578f0, test_acc = 97.76), (epoch = 11, loss = 0.09641348f0, acc = 97.32, test_loss = 0.091933824f0, test_acc = 97.37), (epoch = 13, loss = 0.08514932f0, acc = 97.64, test_loss = 0.07468332f0, test_acc = 98.03), (epoch = 15, loss = 0.09086585f0, acc = 97.59, test_loss = 0.08101739f0, test_acc = 97.91), (epoch = 17, loss = 0.09147936f0, acc = 97.46, test_loss = 0.081585675f0, test_acc = 97.85), (epoch = 19, loss = 0.09023266f0, acc = 97.44, test_loss = 0.080876224f0, test_acc = 97.66), (epoch = 21, loss = 0.08312355f0, acc = 97.74, test_loss = 0.07422219f0, test_acc = 97.94), (epoch = 23, loss = 0.086933106f0, acc = 97.55, test_loss = 0.079941735f0, test_acc = 97.69), (epoch = 25, loss = 0.09566926f0, acc = 97.31, test_loss = 0.087120935f0, test_acc = 97.63), (epoch = 27, loss = 0.10115273f0, acc = 97.19, test_loss = 0.093344204f0, test_acc = 97.44), (epoch = 29, loss = 0.08723685f0, acc = 97.52, test_loss = 0.081499815f0, test_acc = 97.73)]
# Run the model on `x1`, then decode both the model output and `y1` back to digits 0-9.
y1hat = model(x1)
y_hat_new = cpu(hcat(Flux.onecold(y1hat, 0:9), Flux.onecold(y1, 0:9)))
# Column 1: predicted digit, column 2: digit decoded from `y1`,
# column 3: 1.0 where the two agree and 0.0 otherwise.
result_new = hcat(y_hat_new, Float64.(y_hat_new[:, 1] .== y_hat_new[:, 2]))
result_new
64×3 Matrix{Float64}:
 6.0  6.0  1.0
 9.0  9.0  1.0
 2.0  2.0  1.0
 0.0  0.0  1.0
 3.0  3.0  1.0
 8.0  8.0  1.0
 6.0  6.0  1.0
 8.0  8.0  1.0
 8.0  8.0  1.0
 6.0  6.0  1.0
 ⋮
 4.0  4.0  1.0
 7.0  7.0  1.0
 6.0  6.0  1.0
 4.0  4.0  1.0
 4.0  4.0  1.0
 1.0  1.0  1.0
 6.0  6.0  1.0
 8.0  8.0  1.0
 4.0  4.0  1.0
# Columns: predicted value | true value | correct?
check_display_new = hcat(result_new[:, 1], result_new[:, 2], result_new[:, 3])
vscodedisplay(check_display_new)
using ImageCore, ImageInTerminal
# Ask the loader for a single batch as large as the test set; `only` then unwraps
# that one batch (and would throw if the loader produced more than one).
xtest, ytest = only(loader(test_data, batchsize=length(test_data))); # all test samples: images + labels
size(xtest)
(28, 28, 1, 10000)
index = 17
# Take sample `index`, wrap every pixel in `Gray`, and transpose it for display.
get_image = cpu(transpose(Gray.(xtest[:, :, 1, index])))
colorview(Gray, get_image)

# Decode the test labels back to digits 0-9 and look up the one for `index`.
y_label = cpu(Flux.onecold(ytest, 0:9))
y_label[index]
5

查找分类最不确定的图像。
首先,在概率的每一列中,寻找概率最大的一个。
然后,在所有图像中寻找最低的该概率,并确定其索引。

# Class probabilities for the whole test batch (one column per image).
ptest = softmax(model(xtest))
# Largest class probability within each column.
max_p = maximum(ptest; dims=1)
# Smallest of those maxima and its position: the image the model is least sure about.
_, i = findmin(vec(max_p))
(0.19307606f0, 1018)
# Show the least confident test image (sample `i`), transposed for display.
cpu(transpose(Gray.(xtest[:, :, 1, i])))
using JLD2
loaded_state = JLD2.load("mymodel", "model_state"); # load the saved model state
# Build a fresh network to receive the stored state (presumably the same layout as the
# trained `model` — confirm against its definition). `Flux.@autosize` fills in every
# `_` input size from the (28, 28, 1, 1) sample shape.
model2 = Flux.@autosize (28, 28, 1, 1) Chain(
    Conv((5, 5), 1=>6, relu),
    MaxPool((2, 2)),
    Conv((5, 5), _=>16, relu),
    MaxPool((2, 2)),
    Flux.flatten,
    Dense(_ => 120, relu),
    Dense(_ => 84, relu), 
    Dense(_ => 10),
)
# Copy the loaded state into the new network and move the result to the CPU.
model2 = Flux.loadmodel!(model2, loaded_state) |>cpu
Chain(
  Conv((5, 5), 1 => 6, relu),           # 156 parameters
  MaxPool((2, 2)),
  Conv((5, 5), 6 => 16, relu),          # 2_416 parameters
  MaxPool((2, 2)),
  Flux.flatten,
  Dense(256 => 120, relu),              # 30_840 parameters
  Dense(120 => 84, relu),               # 10_164 parameters
  Dense(84 => 10),                      # 850 parameters
)                   # Total: 10 arrays, 44_426 parameters, 174.867 KiB.
# Sanity check: the reloaded CPU model should give (approximately) the same output as `model` on `x1`.
@show model2(cpu(x1)) ≈ cpu(model(x1))
model2(cpu(x1)) ≈ cpu(model(x1)) = true



true
using Images

# Load the picture and shrink it to a 28×28 matrix.
img = load("4.png")
img_28_28 = imresize(img, (28, 28))
# Threshold the grayscale picture at 0.5 so every pixel becomes either 0 or 1.
img_gray = Gray.(Gray.(img_28_28) .> 0.5)
# Transposed Float32 copy of the thresholded picture.
input_img = Float32.(img_gray')
28×28 Matrix{Float32}:
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  1.0  1.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  1.0  1.0  1.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  1.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 ⋮                        ⋮              ⋱                      ⋮         
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
# Preview the picture (transposed back for viewing).
colorview(Gray, input_img')
# Add the singleton third and fourth dimensions so the array is 28×28×1×1.
input = cpu(reshape(input_img, 28, 28, 1, :))
28×28×1×1 Array{Float32, 4}:
[:, :, 1, 1] =
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  1.0  1.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  1.0  1.0  1.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  1.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 ⋮                        ⋮              ⋱                      ⋮         
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
colorview(Gray, input)  # what is actually fed to the model
# Raw model output for the single input image (one score per class).
load_result = cpu(model2(input))
10×1 Matrix{Float32}:
 -4.2211957
 -0.56267506
 -0.40410087
 -3.0764287
  9.327405
 -1.5239846
  0.65510744
  0.44357768
 -2.260195
  1.5781431
# Normalise the raw scores with softmax.
output = cpu(softmax(load_result))
10×1 Matrix{Float32}:
 1.3047671f-6
 5.063005f-5
 5.9330312f-5
 4.0992018f-6
 0.9991155
 1.9360517f-5
 0.0001711137
 0.00013849005
 9.272249f-6
 0.00043067936
Flux.onecold(output, 0:9) # predicted label
1-element Vector{Int64}:
 4

julia图像切割

using Images,ImageFiltering, ImageView, ImageMorphology,ImageSegmentation
# Read the image to be cut into individual digits.
img = load("qq_hand.jpg")

# Pre-processing: threshold at 0.5, then flip the mask so that pixels that were NOT
# brighter than 0.5 become 1 (presumably dark ink on a light background — confirm
# for new inputs). Three closings followed by two erosions clean up the mask.
binary_img = map(Float32, 1 * (Gray.(img) .> 0.5))
carplate_img_binary = Gray.(Gray.(binary_img) .< 0.5)
carplate_img_binary_c = closing(closing(closing(carplate_img_binary)))
carplate_img_binary_e = erode(erode(carplate_img_binary_c))
input_img = map(Float32, carplate_img_binary_e)

# Projection profiles of the mask: `column` has one sum per image column (1×W),
# `raw` has one sum per image row (H×1).
column = sum(input_img, dims=1)
raw = sum(input_img, dims=2)

# Fail with a readable message instead of indexing into `nothing` below.
any(!iszero, raw) || error("no foreground pixels left after pre-processing")

# First and last row that contain any foreground pixel.
raw_first_index = findfirst(raw .!= 0)[1]
raw_last_index = findlast(raw .!= 0)[1]

# Positions of all columns that contain any foreground pixel.
column_index = findall(column .!= 0)

# Record the start and the stop column of every run of consecutive non-empty columns.
# Both ends are tested independently, so a run that is a single column wide still
# contributes a complete (start, stop) pair.
column_index_list = Int[]
for k in eachindex(column_index)
    if k == firstindex(column_index) || column_index[k - 1][2] + 1 != column_index[k][2]
        push!(column_index_list, column_index[k][2])  # a run starts here
    end
    if k == lastindex(column_index) || column_index[k + 1][2] - 1 != column_index[k][2]
        push!(column_index_list, column_index[k][2])  # a run ends here
    end
end

# One column per run: row 1 = start column, row 2 = stop column. The number of
# runs is taken from the data instead of being fixed at 10.
cut_column = reshape(column_index_list, 2, :)

# Crop every run to the occupied rows. The container stays untyped because the
# entries are overwritten in place further down in this file.
test_pic = Any[]
for i in axes(cut_column, 2)
    push!(
        test_pic,
        input_img[raw_first_index:raw_last_index, cut_column[1, i]:cut_column[2, i]],
    )
end
# Inspect the per-column sums computed above.
column
1×2145 Matrix{Float32}:
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
using Plots
# Row profile: the sum of every image row against its row number.
plot(eachindex(raw), vec(raw))
using Plots
# Column profile: the sum of every image column against its column number.
plot(eachindex(column), vec(column))
using Images # turn each crop into a 28×28×1 array that can be fed to the network
result = []
fillcolor = RGB{Float32}(0)
for i in eachindex(test_pic)
    # A crop more than four times as tall as it is wide gets zero margins, each
    # twice the crop's width, attached on both sides before it is resized.
    if size(test_pic[i], 1) > 4 * size(test_pic[i], 2)
        z1 = zeros(Float32, size(test_pic[i], 1), 2 * size(test_pic[i], 2))
        test_pic[i] = hcat(z1, test_pic[i], z1)
    end
    test_pic[i] = imresize(test_pic[i],(28,28),Pad=true,padcolor=fillcolor,stretch=false)
    # Transpose and add the two trailing singleton dimensions expected by the model.
    input = cpu(reshape(test_pic[i]', 28, 28, 1, :))
    load_result = cpu(model2(input)) # run the prediction
    output = cpu(softmax(load_result))
    push!(result, Flux.onecold(output, 0:9))
end

# Show the fourth crop as it was fed to the network.
colorview(Gray,test_pic[4])
# Predicted digit for every crop (each entry is a one-element vector).
result
10-element Vector{Any}:
 [1]
 [4]
 [5]
 [0]
 [5]
 [3]
 [3]
 [4]
 [3]
 [2]
# Display the source image loaded from "qq_hand.jpg".
img

posted @   J1nWan  阅读(171)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· TypeScript + Deepseek 打造卜卦网站:技术与玄学的结合
· Manus的开源复刻OpenManus初探
· AI 智能体引爆开源社区「GitHub 热点速览」
· C#/.NET/.NET Core技术前沿周刊 | 第 29 期(2025年3.1-3.9)
· 从HTTP原因短语缺失研究HTTP/2和HTTP/3的设计差异
点击右上角即可分享
微信分享提示