The Logistic Regression Model

The Sigmoid Function
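The snippet below plots the logistic (sigmoid) function, which maps any real input into the open interval (0, 1):

$$\sigma(x) = \frac{1}{1 + e^{-x}}$$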

import numpy as np
import matplotlib.pyplot as plt

def sigmoid(x):
    y = 1.0 / (1.0 + np.exp(-x))
    return y

plot_x = np.linspace(-10, 10, 100)
plot_y = sigmoid(plot_x)
plt.plot(plot_x, plot_y)
plt.title('Sigmoid')
plt.show()
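One caveat: for large negative inputs (roughly x < -709 in double precision), np.exp(-x) overflows and NumPy emits a runtime warning, even though the result still rounds to the correct value of 0. A minimal numerically stable variant, as a sketch (stable_sigmoid is a hypothetical helper, not part of the original post):

import numpy as np

def stable_sigmoid(x):
    # Only ever call np.exp on non-positive arguments, so it cannot overflow.
    x = np.asarray(x, dtype=float)
    out = np.empty_like(x)
    pos = x >= 0
    out[pos] = 1.0 / (1.0 + np.exp(-x[pos]))
    ex = np.exp(x[~pos])          # x < 0 here, so exp(x) lies in (0, 1)
    out[~pos] = ex / (1.0 + ex)   # algebraically equal to 1/(1+exp(-x))
    return out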

Gradient Descent and Learning-Rate Analysis
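The demo below minimizes the one-dimensional quadratic $J(\theta) = (\theta - 2.5)^2 - 1$ with the standard gradient descent update

$$\theta \leftarrow \theta - \eta\,J'(\theta), \qquad J'(\theta) = 2(\theta - 2.5),$$

and compares how three different learning rates $\eta$ behave.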

import numpy as np
import matplotlib.pyplot as plt

def J(theta):  # loss function: a simple quadratic with its minimum at theta = 2.5
    try:
        return (theta - 2.5)**2 - 1
    except OverflowError:
        return float('inf')

def dJ(theta):  # derivative of the loss function
    return 2 * (theta - 2.5)

def iteration(eta):
    theta = 0.0  # starting point
    theta_history = [theta]
    epsilon = 1e-8

    i_iter = 0
    n_iters = 450
    while i_iter < n_iters:
        gradient = dJ(theta)
        last_theta = theta
        theta = theta - eta * gradient
        i_iter += 1
        theta_history.append(theta)
        if abs(J(theta) - J(last_theta)) < epsilon:
            break  # stop once the loss barely changes between iterations
    return theta_history


if __name__ == '__main__':
    plot_x = np.linspace(-1, 6, 141)
    etas = [0.01, 0.8, 1.1]
    titles = ['0.01', '0.8', '1.1']

    for eta, title in zip(etas, titles):
        theta_history = iteration(eta)
        plt.plot(plot_x, J(plot_x), color='r')
        plt.plot(np.array(theta_history), J(np.array(theta_history)), color='b', marker='x')
        # Label the figure
        plt.title(title)
        plt.xlabel('theta', fontsize=15)
        plt.ylabel('loss function', fontsize=15)
        plt.savefig('{}.png'.format(title))
        plt.clf()
        print('When eta={}, total steps of gradient descent is {}'.format(eta, len(theta_history)))
        # plt.show()
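Why the three learning rates behave so differently follows directly from the update rule: for this quadratic, $\theta_{t+1} - 2.5 = (1 - 2\eta)(\theta_t - 2.5)$, so the distance to the minimum shrinks only when $|1 - 2\eta| < 1$, i.e. $0 < \eta < 1$. With $\eta = 0.01$ the factor is $0.98$, giving slow but steady convergence; with $\eta = 0.8$ the factor is $-0.6$, so the iterate overshoots and oscillates around $2.5$ but still converges quickly; with $\eta = 1.1$ the factor is $-1.2$ and the iterate diverges, so the loop only stops at the 450-iteration cap.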

Implementing Logistic Regression
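The class below fits the model by batch gradient descent on the cross-entropy loss. Writing $\hat{p} = \sigma(X_b\theta)$, where $X_b$ is the design matrix with a leading column of ones and $y \in \{0, 1\}^m$, the loss and its gradient are

$$J(\theta) = -\frac{1}{m}\sum_{i=1}^{m}\Big[y^{(i)}\log\hat{p}^{(i)} + (1 - y^{(i)})\log(1 - \hat{p}^{(i)})\Big], \qquad \nabla J(\theta) = \frac{1}{m}X_b^{T}(\hat{p} - y),$$

which is exactly what the inner functions J and dJ compute.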

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

'''Logistic regression model'''
class LogisticRegression:
    def __init__(self):
        self.coef_ = None       # feature weights
        self.intercept_ = None  # intercept
        self._theta = None      # intercept + weights

    # sigmoid function
    def _sigmoid(self, x):
        y = 1.0 / (1.0 + np.exp(-x))
        return y

    '''
    X_train: training inputs x
    y_train: labels y for the inputs x
    eta:     learning rate
    n_iters: maximum number of iterations
    '''
    def fit(self, X_train, y_train, eta=0.1, n_iters=1e4):
        assert X_train.shape[0] == y_train.shape[0], \
            'the training set and the label vector must have the same length'
    
        # loss function (cross-entropy)
        def J(theta, X_b, y):
            p_predict = self._sigmoid(X_b.dot(theta))
            try:
                return - np.sum(y*np.log(p_predict) + (1-y)*np.log(1-p_predict)) / len(y)
            except Exception:
                return float('inf')

        # gradient of the loss function
        def dJ(theta, X_b, y):
            p = self._sigmoid(X_b.dot(theta))
            return X_b.T.dot(p - y) / len(X_b)
    
        # gradient descent
        def gradient_descent(X_b, y, initial_theta, eta, n_iters=1e4, epsilon=1e-8):
            theta = initial_theta
            i_iter = 0
            while i_iter < n_iters:
                gradient = dJ(theta, X_b, y)
                last_theta = theta
                theta = theta - eta * gradient
                i_iter += 1
                if abs(J(theta, X_b, y) - J(last_theta, X_b, y)) < epsilon:
                    break  # stop once the loss barely changes
            return theta

        # prepend a column of ones so theta[0] acts as the intercept
        X_b = np.hstack([np.ones((X_train.shape[0], 1)), X_train])
        initial_theta = np.zeros(X_b.shape[1])
        self._theta = gradient_descent(X_b, y_train, initial_theta, eta, n_iters)
        self.intercept_ = self._theta[0]
        self.coef_ = self._theta[1:]
        return self

    # predicted probabilities
    def predict_proba(self, X_predict):
        X_b = np.hstack([np.ones((X_predict.shape[0], 1)), X_predict])
        return self._sigmoid(X_b.dot(self._theta))

    # predicted class labels (threshold at 0.5)
    def predict(self, X_predict):
        proba = self.predict_proba(X_predict)
        return np.array(proba > 0.5, dtype='int')
    
    # scatter plot of the dataset
    def scatter_data(self):
        Data = pd.read_csv('ex2data1.txt', sep=',', header=None, names=['test 1', 'test 2', 'Admitted'])
        positive = Data[Data['Admitted'] == 1]
        negative = Data[Data['Admitted'] == 0]
        plt.scatter(positive['test 1'], positive['test 2'], s=30,
                    c='b', marker='o', label='Admitted')
        plt.scatter(negative['test 1'], negative['test 2'], s=30,
                    c='r', marker='x', label='Not Admitted')
        plt.xlabel('test 1 Score')
        plt.ylabel('test 2 Score')
        plt.legend()  # show the Admitted / Not Admitted labels
        # plt.show()
        return Data

if __name__ == '__main__':
    LR = LogisticRegression()
    Data = LR.scatter_data()
    cols = Data.shape[1]
    # fit() must be called, not assigned to; take the label column as a 1-D array
    LR.fit(Data.values[:, 0:cols-1], Data.values[:, cols-1])
    print(LR.predict(Data.values[:, 0:cols-1]))
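As a quick sanity check (my sketch, not part of the original post; it reuses LR, Data, and cols from the block above), training-set accuracy can be computed by comparing the predictions against the label column. Note that with raw scores in the 30-100 range, eta=0.1 can saturate the sigmoid during training, and standardizing the features first is a common remedy; the original post fits on the raw values.

y_true = Data.values[:, cols-1].astype(int)      # ground-truth labels
y_pred = LR.predict(Data.values[:, 0:cols-1])    # model predictions
print('Training accuracy: {:.2%}'.format(np.mean(y_pred == y_true)))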
Contents of the **ex2data1.txt** data file:
34.62365962451697,78.0246928153624,0
30.28671076822607,43.89499752400101,0
35.84740876993872,72.90219802708364,0
60.18259938620976,86.30855209546826,1
79.0327360507101,75.3443764369103,1
45.08327747668339,56.3163717815305,0
61.10666453684766,96.51142588489624,1
75.02474556738889,46.55401354116538,1
76.09878670226257,87.42056971926803,1
84.43281996120035,43.53339331072109,1
95.86155507093572,38.22527805795094,0
75.01365838958247,30.60326323428011,0
82.30705337399482,76.48196330235604,1
69.36458875970939,97.71869196188608,1
39.53833914367223,76.03681085115882,0
53.9710521485623,89.20735013750205,1
69.07014406283025,52.74046973016765,1
67.94685547711617,46.67857410673128,0
70.66150955499435,92.92713789364831,1
76.97878372747498,47.57596364975532,1
67.37202754570876,42.83843832029179,0
89.67677575072079,65.79936592745237,1
50.534788289883,48.85581152764205,0
34.21206097786789,44.20952859866288,0
77.9240914545704,68.9723599933059,1
62.27101367004632,69.95445795447587,1
80.1901807509566,44.82162893218353,1
93.114388797442,38.80067033713209,0
61.83020602312595,50.25610789244621,0
38.78580379679423,64.99568095539578,0
61.379289447425,72.80788731317097,1
85.40451939411645,57.05198397627122,1
52.10797973193984,63.12762376881715,0
52.04540476831827,69.43286012045222,1
40.23689373545111,71.16774802184875,0
54.63510555424817,52.21388588061123,0
33.91550010906887,98.86943574220611,0
64.17698887494485,80.90806058670817,1
74.78925295941542,41.57341522824434,0
34.1836400264419,75.2377203360134,0
83.90239366249155,56.30804621605327,1
51.54772026906181,46.85629026349976,0
94.44336776917852,65.56892160559052,1
82.36875375713919,40.61825515970618,0
51.04775177128865,45.82270145776001,0
62.22267576120188,52.06099194836679,0
77.19303492601364,70.45820000180959,1
97.77159928000232,86.7278223300282,1
62.07306379667647,96.76882412413983,1
91.56497449807442,88.69629254546599,1
79.94481794066932,74.16311935043758,1
99.2725269292572,60.99903099844988,1
90.54671411399852,43.39060180650027,1
34.52451385320009,60.39634245837173,0
50.2864961189907,49.80453881323059,0
49.58667721632031,59.80895099453265,0
97.64563396007767,68.86157272420604,1
32.57720016809309,95.59854761387875,0
74.24869136721598,69.82457122657193,1
71.79646205863379,78.45356224515052,1
75.3956114656803,85.75993667331619,1
35.28611281526193,47.02051394723416,0
56.25381749711624,39.26147251058019,0
30.05882244669796,49.59297386723685,0
44.66826172480893,66.45008614558913,0
66.56089447242954,41.09209807936973,0
40.45755098375164,97.53518548909936,1
49.07256321908844,51.88321182073966,0
80.27957401466998,92.11606081344084,1
66.74671856944039,60.99139402740988,1
32.72283304060323,43.30717306430063,0
64.0393204150601,78.03168802018232,1
72.34649422579923,96.22759296761404,1
60.45788573918959,73.09499809758037,1
58.84095621726802,75.85844831279042,1
99.82785779692128,72.36925193383885,1
47.26426910848174,88.47586499559782,1
50.45815980285988,75.80985952982456,1
60.45555629271532,42.50840943572217,0
82.22666157785568,42.71987853716458,0
88.9138964166533,69.80378889835472,1
94.83450672430196,45.69430680250754,1
67.31925746917527,66.58935317747915,1
57.23870631569862,59.51428198012956,1
80.36675600171273,90.96014789746954,1
68.46852178591112,85.59430710452014,1
42.0754545384731,78.84478600148043,0
75.47770200533905,90.42453899753964,1
78.63542434898018,96.64742716885644,1
52.34800398794107,60.76950525602592,0
94.09433112516793,77.15910509073893,1
90.44855097096364,87.50879176484702,1
55.48216114069585,35.57070347228866,0
74.49269241843041,84.84513684930135,1
89.84580670720979,45.35828361091658,1
83.48916274498238,48.38028579728175,1
42.2617008099817,87.10385094025457,1
99.31500880510394,68.77540947206617,1
55.34001756003703,64.9319380069486,1
74.77589300092767,89.52981289513276,1