HybridBlock supports forwarding with both Symbol and NDArray

gluon/image_classification.py代码有这么一段:

复制代码
import mxnet as mx
from mxnet.gluon.model_zoo import vision as models
...
net = models.get_model('vgg11', context, opt)     # 这里得到的是gluon的模型
...
data = mx.sym.var('data')      # symbol
out = net(data)                # 这里把symbol传到了gluon模型里?
softmax = mx.sym.SoftmaxOutput(out, name='softmax')
mod = mx.mod.Module(softmax, context=context)
mx.viz.plot_network(softmax).view()    #打印出来symbol模型结构
复制代码

symbol传到了gluon模型里?深感疑惑,看了一下api发现可以的:

首先是看到get_model方法返回的是gluon.HybridBlock类型

 

而HybridBlock是同时支持Symbol和NDArray的!:

 

 

所以这么搞是没问题的。 

所以这个代码其实有3种训练模式: symbolic(符号式-类似tf)、hybrid(混合式-mx特性)、imperative(交互式-类似pytorch)

复制代码
    if opt.mode == 'symbolic':          # 符号式
        data = mx.sym.var('data')
        if opt.dtype == 'float16':
            data = mx.sym.Cast(data=data, dtype=np.float16)
        out = net(data)
        if opt.dtype == 'float16':
            out = mx.sym.Cast(data=out, dtype=np.float32)
        softmax = mx.sym.SoftmaxOutput(out, name='softmax')
        mod = mx.mod.Module(softmax, context=context)
        train_data, val_data = get_data_iters(dataset, batch_size, opt)
        mod.fit(train_data,
                eval_data=val_data,
                num_epoch=opt.epochs,
                kvstore=kv,
                batch_end_callback = mx.callback.Speedometer(batch_size, max(1, opt.log_interval)),
                epoch_end_callback = mx.callback.do_checkpoint('image-classifier-%s'% opt.model),
                optimizer = 'sgd',
                optimizer_params = {'learning_rate': opt.lr, 'wd': opt.wd, 'momentum': opt.momentum, 'multi_precision': True},
                initializer = mx.init.Xavier(magnitude=2))
        mod.save_parameters('image-classifier-%s-%d-final.params'%(opt.model, opt.epochs))
    else:
        if opt.mode == 'hybrid':        # 混合式
            net.hybridize()
        train(opt, context)             # 交互式 
复制代码

 

posted @   三年一梦  阅读(274)  评论(0编辑  收藏  举报
编辑推荐:
· .NET Core 对象分配(Alloc)底层原理浅谈
· 聊一聊 C#异步 任务延续的三种底层玩法
· 敏捷开发:如何高效开每日站会
· 为什么 .NET8线程池 容易引发线程饥饿
· golang自带的死锁检测并非银弹
阅读排行:
· 聊一聊 C#异步 任务延续的三种底层玩法
· 上位机能不能替代PLC呢?
· 2024年终总结:5000 Star,10w 下载量,这是我交出的开源答卷
· 一个适用于 .NET 的开源整洁架构项目模板
· .NET Core:架构、特性和优势详解
点击右上角即可分享
微信分享提示