MXNet series: Reading the fully connected layer code
The fully connected operation (the fully connected layer) also has a forward and a backward pass. The code is analyzed below.
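Before reading the code, it helps to write down the math that Forward and Backward implement. Denote the flattened input by X (the code's data, shape batch × input_dim), the weight by W (wmat, shape hidden × input_dim), the bias by b (bias, length hidden), and the gradient arriving from the layer above by G (grad, shape batch × hidden); these symbols are notation introduced here for illustration, not names from the source. The two passes then compute:

$$
Y = XW^{\top} + \mathbf{1}\,b^{\top},\qquad
\frac{\partial L}{\partial W} = G^{\top}X,\qquad
\frac{\partial L}{\partial b} = \sum_{n} G_{n,:},\qquad
\frac{\partial L}{\partial X} = G\,W
$$

These correspond directly to dot(data, wmat.T()) plus repmat(bias, ...) in Forward, and to dot(grad.T(), data), sum_rows(grad), and dot(grad, wmat) in Backward.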
virtual void Forward(const OpContext &ctx,
                     const std::vector<TBlob> &in_data,
                     const std::vector<OpReqType> &req,
                     const std::vector<TBlob> &out_data,
                     const std::vector<TBlob> &aux_args) {
  using namespace mshadow;
  using namespace mshadow::expr;
  if (req[fullc::kOut] == kNullOp) return;
  CHECK_EQ(req[fullc::kOut], kWriteTo);
  size_t expected = param_.no_bias ? 2 : 3;
  CHECK_EQ(in_data.size(), expected);
  CHECK_EQ(out_data.size(), 1);
  // TODO(bing): check the BLAS Handle, be careful
  // maybe need blas handle from context
  // TODO(bing): judge shape to remove flatten op
  Stream<xpu> *s = ctx.get_stream<xpu>();
#if defined(__CUDACC__)
  CHECK_EQ(s->blas_handle_ownership_, Stream<xpu>::OwnHandle)
      << "Must init CuBLAS handle in stream";
#endif  // __CUDACC__
  const TShape& ishape = in_data[fullc::kData].shape_;
  const TShape& oshape = out_data[fullc::kOut].shape_;

  Tensor<xpu, 2, DType> data = in_data[fullc::kData].get_with_shape<xpu, 2, DType>(  // input
      Shape2(ishape[0], ishape.ProdShape(1, ishape.ndim())), s);
  Tensor<xpu, 2, DType> wmat = in_data[fullc::kWeight].get<xpu, 2, DType>(s);  // weight
  Tensor<xpu, 2, DType> out = out_data[fullc::kOut].get_with_shape<xpu, 2, DType>(  // output
      Shape2(oshape[0], oshape.ProdShape(1, oshape.ndim())), s);
  out = dot(data, wmat.T());  // matrix product: data * weight^T
  if (!param_.no_bias) {
    Tensor<xpu, 1, DType> bias = in_data[fullc::kBias].get<xpu, 1, DType>(s);
    out += repmat(bias, data.size(0));  // broadcast the bias to every row of the batch
  }
}

virtual void Backward(const OpContext &ctx,
                      const std::vector<TBlob> &out_grad,
                      const std::vector<TBlob> &in_data,
                      const std::vector<TBlob> &out_data,
                      const std::vector<OpReqType> &req,
                      const std::vector<TBlob> &in_grad,
                      const std::vector<TBlob> &aux_args) {
  using namespace mshadow;
  using namespace mshadow::expr;
  CHECK_EQ(out_grad.size(), 1);
  size_t expected = param_.no_bias ? 2 : 3;
  CHECK(in_data.size() == expected && in_grad.size() == expected);
  CHECK_EQ(req.size(), expected);
  // TODO(bing): check the BLAS Handle, be careful
  // maybe need blas handle from context
  Stream<xpu> *s = ctx.get_stream<xpu>();
  const TShape& ishape = in_data[fullc::kData].shape_;
  const TShape& oshape = out_grad[fullc::kOut].shape_;

  Tensor<xpu, 2, DType> data = in_data[fullc::kData].get_with_shape<xpu, 2, DType>(  // input
      Shape2(ishape[0], ishape.ProdShape(1, ishape.ndim())), s);
  Tensor<xpu, 2, DType> wmat = in_data[fullc::kWeight].get<xpu, 2, DType>(s);  // weight
  Tensor<xpu, 2, DType> grad = out_grad[fullc::kOut].get_with_shape<xpu, 2, DType>(  // output gradient
      Shape2(oshape[0], oshape.ProdShape(1, oshape.ndim())), s);
#if defined(__CUDACC__)
  CHECK_EQ(s->blas_handle_ownership_, Stream<xpu>::OwnHandle)
      << "Must init CuBLAS handle in stream";
#endif
  // backprop
  CHECK_NE(req[fullc::kWeight], kWriteInplace) << "cannot write weight inplace";
  // gradient of weight
  Tensor<xpu, 2, DType> gwmat = in_grad[fullc::kWeight].get<xpu, 2, DType>(s);  // weight gradient
  Assign(gwmat, req[fullc::kWeight], dot(grad.T(), data));  // compute the weight gradient
  // gradient of bias
  if (!param_.no_bias) {
    Tensor<xpu, 1, DType> gbias = in_grad[fullc::kBias].get<xpu, 1, DType>(s);  // bias gradient
    Assign(gbias, req[fullc::kBias], sum_rows(grad));  // sum the output gradient over the batch
  }
  // gradient of data
  Tensor<xpu, 2, DType> gdata = in_grad[fullc::kData].get_with_shape<xpu, 2, DType>(  // input gradient
      Shape2(ishape[0], ishape.ProdShape(1, ishape.ndim())), s);
  Assign(gdata, req[fullc::kData], dot(grad, wmat));  // compute the input gradient
}
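To make the mshadow expressions above more concrete, here is a minimal CPU-only sketch of the same forward/backward math written with plain loops over row-major std::vector buffers. It is an illustration under my own assumptions, not MXNet code: the function names fc_forward/fc_backward, the buffer layout, and the shape parameters N (batch), D (input dim), K (hidden dim) are all hypothetical.

// Plain-loop sketch of the fully connected forward/backward pass (illustrative only).
#include <cstddef>
#include <vector>

// data: N x D, weight: K x D, bias: K, out: N x K  (all row-major)
void fc_forward(const std::vector<float>& data, const std::vector<float>& weight,
                const std::vector<float>& bias, std::vector<float>& out,
                std::size_t N, std::size_t D, std::size_t K) {
  for (std::size_t n = 0; n < N; ++n) {
    for (std::size_t k = 0; k < K; ++k) {
      float acc = bias[k];  // repmat(bias, N): the same bias is added to every row
      for (std::size_t d = 0; d < D; ++d)
        acc += data[n * D + d] * weight[k * D + d];  // dot(data, wmat.T())
      out[n * K + k] = acc;
    }
  }
}

// grad: N x K (gradient w.r.t. out); outputs gweight: K x D, gbias: K, gdata: N x D
void fc_backward(const std::vector<float>& data, const std::vector<float>& weight,
                 const std::vector<float>& grad, std::vector<float>& gweight,
                 std::vector<float>& gbias, std::vector<float>& gdata,
                 std::size_t N, std::size_t D, std::size_t K) {
  // gweight = dot(grad.T(), data)
  for (std::size_t k = 0; k < K; ++k)
    for (std::size_t d = 0; d < D; ++d) {
      float acc = 0.f;
      for (std::size_t n = 0; n < N; ++n)
        acc += grad[n * K + k] * data[n * D + d];
      gweight[k * D + d] = acc;
    }
  // gbias = sum_rows(grad): sum the output gradient over the batch dimension
  for (std::size_t k = 0; k < K; ++k) {
    float acc = 0.f;
    for (std::size_t n = 0; n < N; ++n) acc += grad[n * K + k];
    gbias[k] = acc;
  }
  // gdata = dot(grad, wmat)
  for (std::size_t n = 0; n < N; ++n)
    for (std::size_t d = 0; d < D; ++d) {
      float acc = 0.f;
      for (std::size_t k = 0; k < K; ++k)
        acc += grad[n * K + k] * weight[k * D + d];
      gdata[n * D + d] = acc;
    }
}

The real operator delegates these loops to BLAS (hence the CuBLAS handle check in the GPU build) and uses Assign together with req[] so that each gradient can be written, added to, or skipped according to the requested OpReqType.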