Java 版 TensorFlow 模型推理实现(基于 BERT 的命名实体识别、基于 Transformer 的文本分类)
最近在做文本分类任务。由于实际工程中需要以服务的形式对外提供功能,故采用 Java 加载 TensorFlow 导出的 pb 模型完成推理,特将过程记录如下:
1. 基于 Transformer 的文本分类
package com.techwolf.transformer; import com.alibaba.fastjson.*; import com.alibaba.fastjson.parser.Feature; import org.tensorflow.Graph; import org.tensorflow.Session; import org.tensorflow.Tensor; //import com.alibaba.fastjson.JSONPObject; //import org.json.JSONObject; import java.io.*; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; import java.util.*; public class JobPredict { private static String jsonPath = "src/main/resources/resource.json"; private static String modelPath = "src/main/resources/model.pb"; private static Map<String, Object> positionToFeature = new HashMap<String, Object>(); private static Map<String, Object> jobMapping = new HashMap<String, Object>(); private static Map<String, Object> mergeMapping = new HashMap<String, Object>(); private static Map<String, Object> featureToId = new HashMap<String, Object>(); private static Map<String, Object> idToCode = new HashMap<String, Object>(); private static Map<String, Object> codeToLabel = new HashMap<String, Object>(); public static String readJsonFile(String fileName) throws FileNotFoundException { String jsonStr = ""; try { File jsonFile = new File(fileName); FileReader fileReader = new FileReader(jsonFile); Reader reader = new InputStreamReader(new FileInputStream(jsonFile), "utf-8"); int ch = 0; StringBuffer sb = new StringBuffer(); while ((ch = reader.read()) != -1) { sb.append((char) ch); } fileReader.close(); reader.close(); jsonStr = sb.toString(); return jsonStr; } catch (IOException e) { e.printStackTrace(); return null; } } private static Map<String, Object> jsonTOMap(JSONObject jsobj) { Map<String, Object> data = new HashMap<String, Object>(); Iterator it = jsobj.entrySet().iterator(); while (it.hasNext()) { Map.Entry<String, Object> entry = (Map.Entry<String, Object>) it.next(); data.put(entry.getKey(), entry.getValue()); } return data; } private static void getConfig() throws FileNotFoundException { String jsonStr = readJsonFile(jsonPath); JSONObject 
obj = JSON.parseObject(jsonStr); positionToFeature = jsonTOMap(obj.getJSONObject("position2feature")); featureToId = jsonTOMap(obj.getJSONObject("feature2id")); jobMapping = jsonTOMap(obj.getJSONObject("job_mapping")); mergeMapping = jsonTOMap(obj.getJSONObject("merge_mapping")); idToCode = jsonTOMap(obj.getJSONObject("id2position")); codeToLabel = jsonTOMap(obj.getJSONObject("position_mapping")); System.out.println("config data loaded!"); } public static String convert(String utfString) { StringBuilder sb = new StringBuilder(); int i = -1; int pos = 0; int iint = 0; while ((i = utfString.indexOf("\\u", pos)) != -1) { String sd = utfString.substring(pos, i); sb.append(sd); iint = i + 5; if (iint < utfString.length()) { pos = i + 6; sb.append((char) Integer.parseInt(utfString.substring(i + 2, i + 6), 16)); } } String endStr = utfString.substring(iint + 1, utfString.length()); return sb + "" + endStr; } private static Map<String, List> getCodeAndScore(JSONArray jsonArray) throws FileNotFoundException { List<Integer> codes = new ArrayList<Integer>(); List<Float> scores = new ArrayList<Float>(); Integer codeFlag = -1; float scoreFlag = (float) .0; for (int i = 0; i < jsonArray.size(); i++) { JSONObject skillsItem = (JSONObject) jsonArray.get(i); String code = (skillsItem.get("code")).toString(); Float score = Float.parseFloat((String) skillsItem.get("score")); boolean isReplace = mergeMapping.containsKey(code); if (isReplace) { code = (mergeMapping.get(code)).toString(); System.out.println("replace id:" + code); } String position = (jobMapping.get(code)).toString(); Integer featSeq = (Integer) positionToFeature.get(position); if (featSeq == null) { codes.add((Integer) featureToId.get(codeFlag.toString())); scores.add(scoreFlag); } else { Integer x = (Integer) featureToId.get(featSeq.toString()); codes.add((Integer) featureToId.get(featSeq.toString())); scores.add(score); } } if (jsonArray.size() < 3) { for(int i=0; i< (3-jsonArray.size()); i++) { codes.add((Integer) 
featureToId.get(codeFlag.toString())); scores.add(scoreFlag); } } Map<String, List> result = new HashMap<String, List>(); result.put("codes", codes); result.put("scores", scores); return result; } private static byte[] readAllByteOrExit(Path path){ try{ return Files.readAllBytes(path); }catch (IOException e){ System.out.println("Failed to read[" + path + "]:" + e.getMessage()); System.exit(1); } return null; } private static Map<String, List> getDataContent(String testFile) throws FileNotFoundException { String jsonStr = readJsonFile(testFile); JSONObject obj = JSON.parseObject(jsonStr, Feature.OrderedField); JSONObject objNew = JSON.parseObject(obj.toJSONString(), Feature.OrderedField); ArrayList<List> sampleCode = new ArrayList<List>(); ArrayList<List> sampleScore = new ArrayList<List>(); Map<String, List> samples = new HashMap<String, List>(); for (String userId: objNew.keySet()) { ArrayList<List> codeList = new ArrayList<List>(); ArrayList<Double> scoresList = new ArrayList<Double>(); JSONObject itemTags = (JSONObject) ((JSONObject)((JSONObject)objNew.get(userId)).get("_source")).get("tags"); JSONArray skills = (JSONArray) itemTags.get("skills"); JSONArray title = (JSONArray) itemTags.get("title"); JSONArray desc = (JSONArray) itemTags.get("desc"); Map<String, List> skillsResult = getCodeAndScore(skills); Map<String, List> titleResult = getCodeAndScore(title); Map<String, List> descResult = getCodeAndScore(desc); codeList.addAll(skillsResult.get("codes")); codeList.addAll(titleResult.get("codes")); codeList.addAll(descResult.get("codes")); scoresList.addAll(skillsResult.get("scores")); scoresList.addAll(titleResult.get("scores")); scoresList.addAll(descResult.get("scores")); sampleCode.add(codeList); sampleScore.add(scoresList); } samples.put("sampleCode", sampleCode); samples.put("sampleScore", sampleScore); System.out.println("ok! 
sample feature created."); return samples; } public static int[] arraySort(float[] arr, boolean desc) { float temp; int index; int k = arr.length; int[] Index = new int[k]; for (int i = 0; i < k; i++) { Index[i] = i; } for (int i = 0; i < arr.length; i++) { for (int j = 0; j < arr.length - i - 1; j++) { if (desc) { if (arr[j] < arr[j + 1]) { temp = arr[j]; arr[j] = arr[j + 1]; arr[j + 1] = temp; index = Index[j]; Index[j] = Index[j + 1]; Index[j + 1] = index; } } else { if (arr[j] > arr[j + 1]) { temp = arr[j]; arr[j] = arr[j + 1]; arr[j + 1] = temp; index = Index[j]; Index[j] = Index[j + 1]; Index[j + 1] = index; } } } } return Index; } private static void featToTensor(float[][][] indexes, int[][] codes, float[][] scores, Map<String, List> data) { List<Integer> featCode = data.get("sampleCode"); List<Float> featScore = data.get("sampleScore"); int size = 9; for(int i=0; i < featCode.size(); i++) { Object eachCode = featCode.get(i); Object eachScore = featScore.get(i); float [][] positionResult = new float[size][]; for(int step=0; step < size; step++) { float[] positionVector = new float[size]; positionVector[step] = 1; positionResult[step] = positionVector; } indexes[i] = positionResult; Integer[] targetInter = ((List<Integer>)eachCode).toArray(new Integer[size]); int[] codeResult = Arrays.stream(targetInter).mapToInt(Integer::valueOf).toArray(); Float[] targetFloat = ((List<Float>)eachScore).toArray(new Float[size]); double[] scoreResult = Arrays.stream(targetFloat).mapToDouble(Double::valueOf).toArray(); float[] scoreFloat = new float[size]; for(int j=0; j < scoreResult.length; j++) { scoreFloat[j] = (float) scoreResult[j]; } System.arraycopy(codeResult,0,codes[i], 0, codeResult.length); System.arraycopy(scoreFloat,0,scores[i], 0, scoreResult.length); } } private static List<HashMap<String, Float>> modelInfer(Map<String, List> data) { int batchSize = data.get("sampleCode").size(); int padLength = 9; int returnNum = 5; int classNum = 868; float[][][] indexes = 
new float[batchSize][padLength][padLength]; int[][] codes = new int[batchSize][padLength]; float[][] scores = new float[batchSize][padLength]; float transKeepProb = (float) 1.0; float multiKeepProb = (float) 1.0; byte[] graphDef = readAllByteOrExit(Paths.get(modelPath)); Graph g = new Graph(); g.importGraphDef(graphDef); Session sess = new Session(g); featToTensor(indexes, codes, scores, data); Tensor tensorIndex = Tensor.create(indexes); Tensor tensorCode = Tensor.create(codes); Tensor tensorScore = Tensor.create(scores); Tensor tensorTransProb = Tensor.create(transKeepProb); Tensor tensorMultiProb = Tensor.create(multiKeepProb); Tensor tensorClassResult = sess.runner(). feed("input_x:0", tensorCode). feed("input_x_score:0", tensorScore). feed("embed_position:0", tensorIndex). feed("trans_keep_prob:0", tensorTransProb). feed("multi_keep_prob:0", tensorMultiProb). fetch("discriminator/softmax_score:0").run().get(0); float[][] result = (float[][]) tensorClassResult.copyTo(new float[batchSize][classNum]); List<HashMap<String, Float>> predictResult = new ArrayList(); for(int i=0; i<result.length; i++){ float[] resultVec = result[i]; int[] resultIndex = new int[classNum]; HashMap<String, Float> predictSample = new HashMap<String, Float>(); resultIndex = arraySort(resultVec, true); for(int s=0; s < returnNum; s++) { String sampleCode = Integer.toString(resultIndex[s]); String label = (String) codeToLabel.get(Integer.toString((Integer) idToCode.get(sampleCode))); predictSample.put(label, resultVec[s]); } predictResult.add(predictSample); } tensorClassResult.close(); tensorMultiProb.close(); tensorTransProb.close(); tensorScore.close(); tensorCode.close(); tensorIndex.close(); return predictResult; } public static void main (String[]args) throws IOException { String testFile = "src/main/data/predict_data.json"; getConfig(); Map<String, List> samples = getDataContent(testFile); List<HashMap<String, Float>> result = modelInfer(samples); System.out.println(result); } }
2. 基于 BERT 的命名实体识别(NER)
package com.techwolf.bert; import org.tensorflow.Graph; import org.tensorflow.Session; import org.tensorflow.Tensor; import java.io.BufferedReader; import java.io.FileInputStream; import java.io.IOException; import java.io.InputStreamReader; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; import java.util.*; public class BertNerPredict { private static String vocabPath = "src/main/resources/vocab.txt"; private static Map<String, Integer> word2id = new HashMap<String, Integer>(); static { try { BufferedReader buffer = null; buffer = new BufferedReader(new InputStreamReader(new FileInputStream(vocabPath))); int i = 0; String line = buffer.readLine().trim(); while (line!=null){ word2id.put(line, i++); line = buffer.readLine().trim(); } buffer.close(); }catch (Exception e){ } // System.out.println("word2id size is:"+word2id.size()); } private static byte[] readAllByteOrExit(Path path){ try{ return Files.readAllBytes(path); }catch (IOException e){ System.out.println("Failed to read[" + path + "]:" + e.getMessage()); System.exit(1); } return null; } public static void getTextToId(int[][] inputIds, int[][] inputMask, String[] text){ for(int i=0; i<text.length; i++){ char[] chs = text[i].trim().toLowerCase().toCharArray(); List<Integer> list = new ArrayList<>(); List<Integer> mask = new ArrayList<>(); list.add(word2id.get("[CLS]")); mask.add(1); for(int j=0; j<chs.length; j++){ String element = Character.toString(chs[j]); if(word2id.containsKey(element)){ list.add(word2id.get(element)); mask.add(1); } } list.add(word2id.get("[SEP]")); mask.add(1); int size = list.size(); Integer[] targetInter = list.toArray(new Integer[size]); int[] idResult = Arrays.stream(targetInter).mapToInt(Integer::valueOf).toArray(); Integer[] maskInter = mask.toArray(new Integer[size]); int[] maskResult = Arrays.stream(maskInter).mapToInt(Integer::valueOf).toArray(); System.arraycopy(idResult,0,inputIds[i], 0, idResult.length); 
System.arraycopy(maskResult,0,inputMask[i], 0, maskResult.length); } } public static void main(String[] args) { String[] query = new String[]{"中华人民共和国", "新疆大学"}; String resourceDir = "src/main/resources"; String modelName = "model.pb"; int batchSize = query.length; int padLength = 25; int[][] indexes = new int[batchSize][padLength]; int[][] mask = new int[batchSize][padLength]; byte[] graphDef = readAllByteOrExit(Paths.get(resourceDir, modelName)); Graph g = new Graph(); g.importGraphDef(graphDef); Session sess = new Session(g); if (query.length>0){ System.out.println("Ok! Start predicting...\n"); }else { System.exit(0); } getTextToId(indexes, mask, query); Tensor tensorInputIds = Tensor.create(indexes); Tensor tensorMask = Tensor.create(mask); Tensor tensorSeqResult = sess.runner().feed("input_ids:0", tensorInputIds). feed("input_mask:0", tensorMask).fetch("viterbi/ReverseSequence_1:0").run().get(0); Tensor tensorScoreResult = sess.runner().feed("input_ids:0", tensorInputIds). feed("input_mask:0", tensorMask).fetch("viterbi/Max:0").run().get(0); int[][] sequenceId = (int[][]) tensorSeqResult.copyTo(new int[batchSize][padLength]); float[] sequenceScore = (float[]) tensorScoreResult.copyTo(new float[batchSize]); for(int i=0; i<sequenceId.length; i++){ System.out.println("query: "+query[i]); System.out.println("sequence result: "+ Arrays.toString(sequenceId[i])); System.out.println("sequence score: "+ sequenceScore[i]+'\n'); } tensorScoreResult.close(); tensorSeqResult.close(); tensorMask.close(); tensorInputIds.close(); } }
时刻记着自己要成为什么样的人!