bert文本分类模型保存为savedmodel方式
默认bert是ckpt,在进行后期优化和部署时,savedmodel方式更加友好写。
train完成后,调用如下函数:
def save_savedmodel(estimator, serving_dir, seq_length, is_tpu_estimator): feature_map = { "input_ids": tf.placeholder(tf.int32, shape=[None, seq_length], name='input_ids'), "input_mask": tf.placeholder(tf.int32, shape=[None, seq_length], name='input_mask'), "segment_ids": tf.placeholder(tf.int32, shape=[None, seq_length], name='segment_ids'), "label_ids": tf.placeholder(tf.int32, shape=[None], name='label_ids'), } serving_input_receiver_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(feature_map) estimator.export_savedmodel(serving_dir, serving_input_receiver_fn, strip_default_attrs=True) print("保存savedmodel")
estimator:estimator = Estimator(model_fn=model_fn,params={},config=run_config)
serving_dir:存储目录
seq_length:样本长度
is_tpu_estimator: tpu标志位
时刻记着自己要成为什么样的人!
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】凌霞软件回馈社区,博客园 & 1Panel & Halo 联合会员上线
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】博客园社区专享云产品让利特惠,阿里云新客6.5折上折
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步