模型节点操作学习笔记(Appendix)实验1 -- Tflite int8 删除最后的Round节点 (持续更新)

              

                                                    Tflite                                                                  onnx

  •  尝试一:
    •   根据tflite模型,找到onnx模型(注意,是找到,不是生成,个人感觉生成太难了,我遇到了很多bug)
      •      代码;
        import onnx
        from onnx import helper
        
        def remove_nodes_and_update_output(model_path, nodes_to_remove, new_model_path):
            # 加载模型
            model = onnx.load(model_path)
            graph = model.graph
        
            # 将节点名字添加到一个集合中
            nodes_to_remove_set = set(nodes_to_remove)
        
            # 获取新的节点列表,排除要删除的节点
            new_nodes = [node for node in graph.node if node.name not in nodes_to_remove_set]
        
            # 更新图中的节点列表
            graph.ClearField('node')
            graph.node.extend(new_nodes)
        
            # 删除相关的输出
            new_outputs = [output for output in graph.output if output.name not in nodes_to_remove_set]
            graph.ClearField('output')
            graph.output.extend(new_outputs)
            
            # 添加新的输出节点
            # for new_output in new_outputs:
            #     new_output_tensor = helper.make_tensor_value_info(new_output.name, onnx.TensorProto.FLOAT, None)
            #     graph.output.append(new_output_tensor)
            
            # 保存新的模型
            onnx.save(model, new_model_path)
        
        def main():
            # 输入和输出的路径
            original_model_path = r"./../hand_landmark_sparse_Nx3x224x224.onnx"
            new_model_path = r"./../hand_landmark_sparse_Nx3x224x224_modified.onnx"
            
            # 加载模型
            model = onnx.load(original_model_path)
            graph = model.graph
        
            # 记录要删除节点的名字和它们的前一个节点的输出
            nodes_to_remove = []
            new_output_names = []
        
            # 找到要删除的输出节点
            for output in graph.output: # 找到要删除的输出节点
                if output.name == 'lefthand_0_or_righthand_1':
                    nodes_to_remove.append(output.name)
            
            for node in graph.node:
                if node.name == "Round_0" or node.output[0] == 'lefthand_0_or_righthand_1': 
                    nodes_to_remove.append(node.name)
                    new_output_names.append(node.input[0])
        
            # 确保新输出的名字唯一
            new_output_names = list(set(new_output_names))
        
            # 删除指定的节点并保存新的模型
            remove_nodes_and_update_output(original_model_path, nodes_to_remove, new_model_path)
        
            # 再次加载模型,添加新的输出
            model = onnx.load(new_model_path)
            graph = model.graph
        
            for new_output_name in new_output_names:
                new_output_tensor = helper.make_tensor_value_info(new_output_name, onnx.TensorProto.FLOAT, None)
                graph.output.append(new_output_tensor)
        
            # 保存最终的模型
            onnx.save(model, new_model_path)
        
        if __name__ == "__main__":
            main()
        

          

      •      结果:(接下来继续验证tflite)

         也可以参考另外一个博客:ONNX删除节点示例(Deeplabv3plus)-CSDN博客 觉得这个也可以。

      • 其他
posted @ 2024-05-30 20:43  张幼安  阅读(4)  评论(0编辑  收藏  举报