from __future__ import print_function
from __future__ import division
import tensorflow as tf
import numpy as np
import os
import argparse
def dense_to_one_hot(input_data, class_num):
    """Convert a 1-D vector of integer class labels into a one-hot matrix.

    Args:
        input_data: numpy array of integer labels; one label per sample
            (first axis is the sample axis).
        class_num: total number of classes (width of the one-hot rows).

    Returns:
        A float (data_num, class_num) numpy array with a single 1 per row.
    """
    num_samples = input_data.shape[0]
    one_hot = np.zeros((num_samples, class_num))
    # Fancy indexing: row i gets a 1 in the column named by label i.
    one_hot[np.arange(num_samples), input_data.ravel()] = 1
    return one_hot
def build_parser():
    """Parse the required command-line arguments for this script.

    Returns:
        argparse.Namespace with train_path, test_path, model_path and
        board_dir, all required string paths.
    """
    arg_parser = argparse.ArgumentParser()
    # All four flags share the same declaration, so register them in a loop.
    for flag in ('--train_path', '--test_path', '--model_path', '--board_dir'):
        arg_parser.add_argument(flag, type=str, required=True)
    return arg_parser.parse_args()
def variable_summaries(var):
    """Attach TensorBoard summaries (mean, stddev, max, min, histogram) to a tensor.

    Everything is grouped under a 'summaries' name scope; the stddev ops live
    in their own nested 'stddev' scope, which also determines the summary tag
    layout shown in TensorBoard.
    """
    with tf.name_scope('summaries'):
        mean_op = tf.reduce_mean(var)
        tf.summary.scalar('mean', mean_op)
        with tf.name_scope('stddev'):
            # Population standard deviation: sqrt(E[(x - mean)^2]).
            stddev_op = tf.sqrt(tf.reduce_mean(tf.square(var - mean_op)))
            tf.summary.scalar('stddev', stddev_op)
        tf.summary.scalar('max', tf.reduce_max(var))
        tf.summary.scalar('min', tf.reduce_min(var))
        tf.summary.histogram('histogram', var)
p = build_parser()
# Start each run with a clean TensorBoard log directory.
log_root = p.board_dir
if tf.gfile.Exists(log_root):
    tf.gfile.DeleteRecursively(log_root)
tf.gfile.MakeDirs(log_root)
# Load CSV data: columns 0-1 are the 2-D features, column 2 is the class label.
origin_train = np.genfromtxt(p.train_path, delimiter=',')
data_train = origin_train[:, 0:2]
labels_train = origin_train[:, 2]
# BUG FIX: the test set was previously read from p.train_path and then sliced
# from origin_train, so data_test/labels_test silently duplicated the training
# data. It must come from p.test_path / origin_test.
origin_test = np.genfromtxt(p.test_path, delimiter=',')
data_test = origin_test[:, 0:2]
labels_test = origin_test[:, 2]
# Training hyperparameters.
learning_rate = 0.001
training_epochs = 5000
display_step = 1  # NOTE(review): defined but unused; the loop logs every 100 epochs instead.
# Model dimensions: 2 input features, 2 output classes (binary classification).
n_features = 2
n_class = 2
# Graph inputs: x holds feature rows, y holds one-hot labels.
# The third positional argument to tf.placeholder is the op name ("input").
x = tf.placeholder(tf.float32, [None, n_features], "input")
y = tf.placeholder(tf.float32, [None, n_class])
# Softmax (logistic) regression: a single affine layer scores = x @ W + b.
# Weights and bias start at zero; summaries are attached for TensorBoard.
with tf.name_scope('W'):
    W = tf.Variable(tf.zeros([n_features, n_class]), name="w")
    variable_summaries(W)
with tf.name_scope('b'):
    b = tf.Variable(tf.zeros([n_class]), name="b")
    variable_summaries(b)
scores = tf.nn.xw_plus_b(x, W, b, name='scores')
# Mean softmax cross-entropy over the batch, computed from raw logits.
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=scores, labels=y))
tf.summary.scalar('cross_entropy', cost)
# Plain gradient descent on the cross-entropy loss.
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
# Collect every summary op registered above into a single fetchable op.
merged = tf.summary.merge_all()
# Separate event-file writers so TensorBoard shows train/test as two runs.
train_writer = tf.summary.FileWriter(os.path.join(p.board_dir, 'train'))
test_writer = tf.summary.FileWriter(os.path.join(p.board_dir, 'test'))
saver = tf.train.Saver()
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    # The label conversion is loop-invariant, so build both feeds once.
    train_feed = {x: data_train,
                  y: dense_to_one_hot(labels_train.astype(int), 2)}
    test_feed = {x: data_test,
                 y: dense_to_one_hot(labels_test.astype(int), 2)}
    for epoch in range(training_epochs):
        # One full-batch gradient-descent step on the training data.
        _, c = sess.run([optimizer, cost], feed_dict=train_feed)
        if epoch % 100 == 0:
            # Log train metrics to the train writer.
            summary, c = sess.run([merged, cost], feed_dict=train_feed)
            train_writer.add_summary(summary, epoch)
            # BUG FIX: the test writer previously received the same train
            # summary, so the "test" curve was a copy of the train curve.
            # Evaluate the summaries on the test feed instead.
            test_summary = sess.run(merged, feed_dict=test_feed)
            test_writer.add_summary(test_summary, epoch)
    # Persist the trained variables once training finishes.
    saver.save(sess, p.model_path)
train_writer.close()
test_writer.close()
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· .NET Core 中如何实现缓存的预热?
· 从 HTTP 原因短语缺失研究 HTTP/2 和 HTTP/3 的设计差异
· AI与.NET技术实操系列:向量存储与相似性搜索在 .NET 中的实现
· 基于Microsoft.Extensions.AI核心库实现RAG应用
· Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露
· TypeScript + Deepseek 打造卜卦网站:技术与玄学的结合
· 阿里巴巴 QwQ-32B真的超越了 DeepSeek R-1吗?
· 【译】Visual Studio 中新的强大生产力特性
· 【设计模式】告别冗长if-else语句:使用策略模式优化代码结构
· 10年+ .NET Coder 心语 ── 封装的思维:从隐藏、稳定开始理解其本质意义