golang调用tensor flow模型
1. 安装Go版TensorFlow
TensorFlow 提供了一个 Go API,该 API 特别适合加载用 Python 创建的模型并在 Go 应用中运行这些模型。
安装TensorFlow C库
下载地址
解压 :
tar -C $dir -xzf tar_file
添加到动态库:
export LIBRARY_PATH=$LIBRARY_PATH:$dir/lib
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$dir/lib
如果你已经解压到了/usr/local下,则不需要配置LIBRARY_PATH和LD_LIBRARY_PATH,只需要执行sudo ldconfig即可
安装 TensorFlow Go
go get github.com/tensorflow/tensorflow/tensorflow/go
2. 用Python训练tensor flow模型并保存
注意:python的tensor flow版本不能高于go的tensor flow版本,否则go加载模型文件时会报错。
import tensorflow as tf from keras import backend as K sess = tf.Session() K.set_session(sess) def build_deep_cross(): inputs = [] for i, feature_name in enumerate(feature_names): cate_in = Input((1,), name=feature_name) inputs.append(cate_in) # 此处略去很多代码 cross_net = build_cross_net(f_dim_vectors) # CrossNet deep_out = build_dnn(f_dim_vectors, continuous_input) # 深层网络 # 结合CrossNet和深层网络 concat_cross_deep = Concatenate()([cross_net, deep_out]) outputs = Dense(1, activation="sigmoid", name="output_layer")(concat_cross_deep) # 留意模型的inputs和outputs,写Golang时要用 model = Model(inputs=inputs, outputs=outputs) solver = Adam(lr=0.01, decay=0.1) model.compile(optimizer=solver, loss='binary_crossentropy', metrics=['acc']) return model model=build_deep_cross() model.fit(X_train, Y_train, batch_size=256, epochs=10) # 专门为Golang保存一个模型 builder = tf.compat.v1.saved_model.builder.SavedModelBuilder("dcnModel") # 必须为模型打个Tag,否则golang无法加载 builder.add_meta_graph_and_variables(sess, ["myTag"]) # 保存 builder.save()
3. Golang加载tensor flow模型
package main import ( tf "github.com/tensorflow/tensorflow/tensorflow/go" "strconv" "strings" "sync" "fmt" ) type DCN struct { model *tf.SavedModel featureNames []string } var ( dcn *DCN dcnOnce sync.Once ) func GetDCNInstance(modelFile string, tags []string) *DCN { if dcn != nil { return dcn } dcnOnce.Do(func() { dcn = &DCN{} //LoadSavedModel时使用的go tensorflow版本不能低于tf.saved_model.builder.SavedModelBuilder时使用的tensorflow版本 if model, err := tf.LoadSavedModel(modelFile, tags, nil); err == nil { dcn.model = model dcn.featureNames = []string{"age", "work_year", "gender"} //第一次执行model.Session.Run很耗时,所以初始化后先预热一下 X := []float32{0.31,0.09,1.0} input := [][]float32{X} dcn.Predict(input) } else { fmt.Printf("read dcn model file %s failed: %v", modelFile, err) return } }) return dcn } //Predict 预测点击率。X的连续特征需要事先做好归一化,离散特征要转成index func (self DCN) Predict(X [][]float32) []float32 { if len(X[0]) != len(self.featureNames) { fmt.Printf("feature number of x is %d, but should be %d", len(X[0]), len(self.featureNames)) return nil } input_layer := make(map[tf.Output]*tf.Tensor) for i := 0; i < len(X[0]); i++ { //第i列 input := [][]float32{} for j := 0; j < len(X); j++ { //第j行 input = append(input, []float32{X[j][i]}) } tensor, _ := tf.NewTensor(input) // python版tensorflow/keras中定义的输入层input_layer out := self.model.Graph.Operation(self.featureNames[i]).Output(0) input_layer[out] = tensor } output_layer := []tf.Output{ //python版tensorflow/keras中定义的输出层output_layer self.model.Graph.Operation("output_layer/Sigmoid").Output(0), } if result, err := self.model.Session.Run(input_layer, output_layer, nil); err == nil { //不论是1条数据还是300条数据,执行该行代码只需要2毫秒 scores := result[0].Value().([][]float32) rect := make([]float32, len(scores)) for i, arr := range scores { rect[i] = arr[0] } return rect } else { fmt.Printf("predict failed: %v", err) return nil } } func main(){ dcn := rank.GetDCNInstance("dcnModel", []string{"myTag"}) X := []float32{0.31,0.09,1.0} input := [][]float32{X} scores := dcn.Predict(input) }
4.一个比较偏官方的英文教程
https://github.com/tensorflow/build/tree/master/golang_install_guide
本文来自博客园,作者:高性能golang,转载请注明原文链接:https://www.cnblogs.com/zhangchaoyang/articles/11363726.html