JadenFK
哪有什么岁月静好,只是有人替我们负重前行! http://JadenFK.github.io
随笔 - 50,  文章 - 1,  评论 - 9,  阅读 - 70280

本篇博客代码来自于《动手学深度学习》pytorch版,也是代码较多,解释较少的一篇。不过好多方法在我以前的博客都有提,所以这次没提。还有一个原因是,这篇博客的代码,只要好好看看肯定能看懂(前提是python语法大概了解),这是我不加很多解释的重要原因。

K折交叉验证实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
def get_k_fold_data(k, i, X, y):
    # 返回第i折交叉验证时所需要的训练和验证数据,分开放,X_train为训练数据,X_valid为验证数据
    assert k > 1
    fold_size = X.shape[0] // k  # 双斜杠表示除完后再向下取整
    X_train, y_train = None, None
    for j in range(k):
        idx = slice(j * fold_size, (j + 1) * fold_size)  #slice(start,end,step)切片函数
        X_part, y_part = X[idx, :], y[idx]
        if j == i:
            X_valid, y_valid = X_part, y_part
        elif X_train is None:
            X_train, y_train = X_part, y_part
        else:
            X_train = torch.cat((X_train, X_part), dim=0) #dim=0增加行数,竖着连接
            y_train = torch.cat((y_train, y_part), dim=0)
    return X_train, y_train, X_valid, y_valid
 
def k_fold(k, X_train, y_train, num_epochs,learning_rate, weight_decay, batch_size):
    train_l_sum, valid_l_sum = 0, 0
    for i in range(k):
        data = get_k_fold_data(k, i, X_train, y_train) # 获取k折交叉验证的训练和验证数据
        net = get_net(X_train.shape[1])  #get_net在这是一个基本的线性回归模型,方法实现见附录1
        train_ls, valid_ls = train(net, *data, num_epochs, learning_rate,
                                   weight_decay, batch_size)  #train方法见后面附录2
        train_l_sum += train_ls[-1]
        valid_l_sum += valid_ls[-1]
        if i == 0:
            d2l.semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'rmse',
                         range(1, num_epochs + 1), valid_ls,
                         ['train', 'valid'])   #画图,且是对y求对数了,x未变。方法实现见附录3
        print('fold %d, train rmse %f, valid rmse %f' % (i, train_ls[-1], valid_ls[-1]))
    return train_l_sum / k, valid_l_sum / k

 *args:表示接受任意长度的参数,然后存放入一个元组中;如def fun(*args) print(args),‘fruit','animal','human'作为参数传进去,输出(‘fruit','animal','human')

**kwargs:表示接受任意长的参数,然后存放入一个字典中;如

1
2
3
def fun(**kwargs):  
    for key, value in kwargs.items():
        print("%s:%s" % (key,value)

fun(a=1,b=2,c=3)会输出 a=1 b=2 c=3

附录1

1
2
3
4
5
6
7
loss = torch.nn.MSELoss()
 
def get_net(feature_num):
    net = nn.Linear(feature_num, 1)
    for param in net.parameters():
        nn.init.normal_(param, mean=0, std=0.01)
    return net

附录2

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def train(net, train_features, train_labels, test_features, test_labels, num_epochs, learning_rate,weight_decay, batch_size):
    train_ls, test_ls = [], []
    dataset = torch.utils.data.TensorDataset(train_features, train_labels)
    train_iter = torch.utils.data.DataLoader(dataset, batch_size, shuffle=True) #TensorDataset和DataLoader的使用请查看我以前的博客
     
    #这里使用了Adam优化算法
    optimizer = torch.optim.Adam(params=net.parameters(), lr= learning_rate, weight_decay=weight_decay)
    net = net.float()
    for epoch in range(num_epochs):
        for X, y in train_iter:
            l = loss(net(X.float()), y.float())
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
        train_ls.append(log_rmse(net, train_features, train_labels))
        if test_labels is not None:
            test_ls.append(log_rmse(net, test_features, test_labels))
    return train_ls, test_ls

 附录3

1
2
3
4
5
6
7
8
def semilogy(x_vals, y_vals, x_label, y_label, x2_vals=None, y2_vals=None, legend=None, figsize=(3.5, 2.5)):
    set_figsize(figsize)
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.semilogy(x_vals, y_vals)
    if x2_vals and y2_vals:
        plt.semilogy(x2_vals, y2_vals, linestyle=':')
        plt.legend(legend)

 注:由于最近有其他任务,所以此博客写的匆忙,等我有时间后会丰富,也可能加详细解释。

posted on   郭心全  阅读(5010)  评论(7编辑  收藏  举报
编辑推荐:
· AI与.NET技术实操系列:基于图像分类模型对图像进行分类
· go语言实现终端里的倒计时
· 如何编写易于单元测试的代码
· 10年+ .NET Coder 心语,封装的思维:从隐藏、稳定开始理解其本质意义
· .NET Core 中如何实现缓存的预热?
阅读排行:
· 分享一个免费、快速、无限量使用的满血 DeepSeek R1 模型,支持深度思考和联网搜索!
· 25岁的心里话
· 基于 Docker 搭建 FRP 内网穿透开源项目(很简单哒)
· ollama系列01:轻松3步本地部署deepseek,普通电脑可用
· 按钮权限的设计及实现

< 2025年3月 >
23 24 25 26 27 28 1
2 3 4 5 6 7 8
9 10 11 12 13 14 15
16 17 18 19 20 21 22
23 24 25 26 27 28 29
30 31 1 2 3 4 5
是我,还记得我吗? 快来逗我玩吧!
点击右上角即可分享
微信分享提示