NTU Machine Learning: The Perceptron

NTU Machine Learning Notes: The Perceptron

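I recently realized that my knowledge of machine learning is not systematic enough and has many gaps, so I decided to find a moderately challenging course and work through it carefully. NTU's machine learning course goes fairly deep while staying accessible, so I plan to take notes as I follow it. I have also decided to keep these notes in an IPython notebook.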

Course Introduction

What is machine learning?

Learning: acquiring skill with experience accumulated from observation

%%dot 
digraph G {
       rankdir=LR; observations -> learning -> skill;
    }

Machine learning: acquiring skill with experience accumulated/computed from data

%%dot 
digraph G {
        rankdir=LR;
        data -> ML -> skill;
    }

Skill: improving some performance measure (e.g. prediction accuracy)

%%dot 
digraph G {
        rankdir=LR;
        data -> ML -> "improved performance measure";
    }

Why use machine learning

  • when humans cannot program the system manually
  • when humans cannot 'define the solution' easily
  • when rapid decisions are needed that humans cannot make (high-frequency trading)
  • when users must be served at massive scale (large-scale personalized services)

Key Essence of Machine Learning

  • there exists some 'underlying pattern' to be learned (so the 'performance measure' can be improved)
  • but there is no programmable (easy) definition of it
  • somehow there is data about the pattern

Formalizing the Learning Problem

  • input: $ x\in X$ (customer application)
  • output: $ y\in Y$ (good/bad after approving credit card)
  • unknown pattern to be learned \(\Leftrightarrow\) target function: $ f: X \to Y $ (ideal credit approval formula)
  • data \(\Leftrightarrow\) training examples: $ D=\{(x_1, y_1), (x_2, y_2), \ldots , (x_N, y_N)\} $
  • hypothesis \(\Leftrightarrow\) skill with hopefully good performance: $g : X \to Y $ ('learned' formula to be used)

The General Machine Learning Workflow

%%dot
digraph G {

rankdir=LR;a -> b;
b->c;
c -> d;
e ->c;
a [shape=box,sides=4,skew=.4,color=lightblue,style=filled,label="unknown target function\n f: X -> Y"];
b [shape=box,sides=4,skew=.4,label="training examples\n D: (x_1,y_1),...,(x_N,y_N)"];
c [label="learning algorithm A"];
d [shape=box, label="final hypothesis\n g"];
e [label="hypothesis set\n H"];
}
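
To make the workflow concrete, here is a minimal Python sketch. Everything in it is an illustrative assumption rather than part of the course material: the hypothesis set H is a handful of one-dimensional threshold rules, and the learning algorithm A simply picks the rule with the fewest errors on D.

```python
from typing import Callable, List, Tuple

Example = Tuple[float, int]          # (x, y) with y in {-1, +1}
Hypothesis = Callable[[float], int]  # h: X -> Y


def make_threshold(theta: float) -> Hypothesis:
    """Hypothetical hypothesis family: predict +1 when x exceeds theta."""
    return lambda x: 1 if x > theta else -1


def training_error(h: Hypothesis, data: List[Example]) -> int:
    """Number of training examples that h misclassifies."""
    return sum(1 for x, y in data if h(x) != y)


def learn(data: List[Example], thetas: List[float]) -> Tuple[float, Hypothesis]:
    """Toy learning algorithm A: pick the threshold with the fewest training errors."""
    best_theta = min(thetas, key=lambda t: training_error(make_threshold(t), data))
    return best_theta, make_threshold(best_theta)


# Training examples D, labelled by a target f that the learner never sees
D = [(1.0, -1), (2.0, -1), (3.0, 1), (4.0, 1)]
H_THETAS = [0.5, 1.5, 2.5, 3.5]      # hypothesis set H, indexed by theta

theta_g, g = learn(D, H_THETAS)
print(theta_g, training_error(g, D))  # prints: 2.5 0
```

The learner only ever touches D and H; the target function f stays hidden, which is exactly the situation the diagram describes.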

Perceptron Learning Algorithm (PLA)

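The perceptron works as follows.
We want to find a weight vector \(w\) such that the hyperplane \(w^Tx = 0\) separates the dataset exactly, i.e. \(sign(w^Tx_n) = y_n\) for every training example.
Initialize \(w_0\) to 0.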

For t = 0,1,...

  • find a point \((x_{n(t)}, y_{n(t)})\) that \(w_t\) misclassifies, i.e. \(sign(w_t^Tx_{n(t)})\not=y_{n(t)}\)
  • try to correct the mistake by updating \(w\): \(w_{t+1}\gets w_t + y_{n(t)}x_{n(t)}\)

The essence of the method is learning from its mistakes: a fault confessed is half redressed!
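
The loop above can be sketched in NumPy as follows. This is a minimal sketch, not the course's reference implementation: the function name `pla`, the `max_updates` safety cap, the choice of the first misclassified point, and the toy data are all assumptions made here. The bias is folded into \(w\) by prepending a constant \(x_0 = 1\) to every input.

```python
import numpy as np


def pla(X, y, max_updates=10_000):
    """Perceptron Learning Algorithm.

    X: (N, d) array of inputs, y: (N,) array of labels in {-1, +1}.
    Returns (w, updates); w[0] plays the role of the bias.
    """
    Xb = np.hstack([np.ones((X.shape[0], 1)), X])  # prepend x_0 = 1 for the bias
    w = np.zeros(Xb.shape[1])                      # w_0 = 0
    for updates in range(max_updates):
        # sign(0) != y_n, so points on the boundary count as mistakes
        mistakes = np.flatnonzero(np.sign(Xb @ w) != y)
        if mistakes.size == 0:
            return w, updates                      # no mistakes left: halt
        n = mistakes[0]                            # first misclassified point
        w = w + y[n] * Xb[n]                       # w_{t+1} <- w_t + y_n x_n
    raise RuntimeError("no separator found within max_updates")


# Hypothetical, linearly separable toy data
X = np.array([[2.0, 2.0], [3.0, 1.0], [-1.0, -2.0], [-2.0, -1.0]])
y = np.array([1, 1, -1, -1])

w, updates = pla(X, y)
print(w, updates)
```

Picking the first misclassified point is only one option; visiting the points in a fixed cycle or at random are equally valid variants of the same update rule.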

Guarantee of PLA

For PLA to halt, the dataset must be linearly separable. Conversely, if the dataset is linearly separable, is PLA guaranteed to halt?

Let \(w_f\) be a perfect separator of the dataset, i.e. one that classifies every example correctly. Then:

\(y_{n(t)}w_f^Tx_{n(t)}\ge\min\limits_{n}y_nw_f^Tx_n>0\)

From this we can show that \(w_f^Tw_t\) grows with every update on a misclassified point \((x_{n(t)},y_{n(t)})\):

\[w_f^Tw_{t+1} = w_f^T(w_t + y_{n(t)}x_{n(t)}) \ge w_f^Tw_t + \min\limits_{n}y_nw_f^Tx_n > w_f^Tw_t + 0 \]

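A steadily growing inner product \(w_f^Tw_t\) can mean one of two things (or both):

  1. the two vectors are becoming more closely aligned
  2. the length of \(w_t\) is growing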

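Next we show that the growth of the length of \(w_t\) is bounded from above. PLA updates only on a mistake, so \(y_{n(t)}w_t^Tx_{n(t)}\le 0\), and therefore: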

\[\begin{align} ||w_{t+1}||^2 &= || w_t + y_{n(t)}x_{n(t)}||^2 \\ &= ||w_t||^2 + 2y_{n(t)}w_t^Tx_{n(t)} + ||y_{n(t)}x_{n(t)}||^2 \\ &\le ||w_t||^2 + 0 + ||y_{n(t)}x_{n(t)}||^2 \\ &\le ||w_t||^2 + \max\limits_{n}||y_nx_n||^2 \end{align} \]

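That is, after \(T\) updates starting from \(w_0 = 0\), \(||w_T||^2 \le T\max\limits_{n}||y_nx_n||^2\), where the maximum is a constant for a fixed training set. Since the inner product with \(w_f\) grows linearly in \(T\) while the length of \(w_T\) grows at most like \(\sqrt{T}\), \(w_t\) becomes increasingly aligned with \(w_f\) as the iterations proceed.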

Moreover, we can show that the number of updates \(T\) is bounded from above. The key inequality is:

\[\frac{w_f^T}{||w_f||}\frac{w_T}{||w_T||}\ge\sqrt{T}\cdot constant \]

The proof is as follows. With \(w_0 = 0\):

\[\begin{align} w_f^Tw_T &= w_f^T(w_{T-1} + y_{n(T-1)}x_{n(T-1)}) \\ &\ge w_f^Tw_{T-1} + \min\limits_{n}y_nw_f^Tx_n \\ &\ge w_f^Tw_0 + T\min\limits_{n}y_nw_f^Tx_n \\ &= T\min\limits_{n}y_nw_f^Tx_n \end{align} \]

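For \(||w_T||^2\), using \(||y_nx_n|| = ||x_n||\) (because \(y_n = \pm 1\)), we have

\[ \begin{align} ||w_T||^2 &= ||w_{T-1} + y_{n(T-1)}x_{n(T-1)}||^2 \\ & \le ||w_{T-1}||^2 + \max\limits_{n}||x_n||^2 \\ & \le T\max\limits_{n}||x_n||^2 \end{align} \]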

Dividing the first bound by \(||w_f||\) and by the square root of the second gives

\[\frac{w_f^T}{||w_f||}\frac{w_T}{||w_T||}\ge\sqrt{T}\cdot \frac{\min\limits_{n}y_nw_f^Tx_n}{||w_f||\sqrt{\max\limits_{n}||x_n||^2}} \]

Since the left-hand side is the cosine of the angle between \(w_f\) and \(w_T\), we also have

\[\frac{w_f^T}{||w_f||}\frac{w_T}{||w_T||}\le1 \]

Finally, we obtain

\[T\le\frac{\max\limits_{n}||x_n||^2\cdot ||w_f||^2}{\left(\min\limits_{n}y_nw_f^Tx_n\right)^2} \]
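
As a sanity check, the bound can be compared against an actual run. The sketch below builds a hypothetical separable dataset from a known \(w_f\) (the values of `w_f`, the margin threshold 0.1, and the random seed are arbitrary assumptions), runs PLA from \(w_0 = 0\), and prints the number of updates next to the bound. As before, a constant \(x_0 = 1\) is prepended so that the bias is part of \(w\).

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical separable data: label points with a known w_f, keep a margin
w_f = np.array([0.5, 1.0, -1.0])                 # (bias, w_1, w_2), chosen arbitrarily
X = rng.uniform(-1, 1, size=(200, 2))
Xb = np.hstack([np.ones((len(X), 1)), X])        # x_0 = 1 absorbs the bias
keep = np.abs(Xb @ w_f) > 0.1                    # drop points too close to the line
Xb = Xb[keep]
y = np.sign(Xb @ w_f)

# Run PLA from w_0 = 0 and count the updates T
w, T = np.zeros(3), 0
while True:
    mistakes = np.flatnonzero(np.sign(Xb @ w) != y)
    if mistakes.size == 0:
        break
    n = mistakes[0]
    w += y[n] * Xb[n]
    T += 1

R2 = np.max(np.sum(Xb ** 2, axis=1))             # max_n ||x_n||^2
rho = np.min(y * (Xb @ w_f))                     # min_n y_n w_f^T x_n
bound = R2 * (w_f @ w_f) / rho ** 2
print(f"updates T = {T}, bound = {bound:.1f}")
```

In practice the bound cannot be computed in advance, because it depends on \(w_f\), which is exactly what we are trying to learn. It is available here only because the data were generated from a known \(w_f\).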

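Pros and Cons of PLA

  • Pros: simple, fast, and easy to extend to high dimensions
  • Cons: PLA halts only if the dataset is linearly separable, and we cannot tell in advance how long it will run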

Pocket Algorithm

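In practice we rarely get a perfectly linearly separable dataset, because the data may contain noise. To cope with noise, we adopt a modified weight-update strategy.

Initialize the weights to \(\hat{w}\) and keep a copy of them, which amounts to putting the best \(w\) found so far in your pocket.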

For t = 0,1,...

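  • pick at random a point \((x_{n(t)}, y_{n(t)})\) that \(w_t\) misclassifies, i.e. \(sign(w_t^Tx_{n(t)})\not=y_{n(t)}\)
  • try to correct the mistake by updating \(w\): \(w_{t+1}\gets w_t + y_{n(t)}x_{n(t)}\)
  • if \(w_{t+1}\) makes fewer mistakes on the dataset than \(\hat{w}\), replace \(\hat{w}\) with \(w_{t+1}\)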
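
A minimal NumPy sketch of these steps follows. The names `pocket` and `count_errors`, the fixed iteration budget `max_iters`, the zero initialization of \(\hat{w}\), and the toy data (whose last point carries a deliberately flipped label) are assumptions made for illustration.

```python
import numpy as np


def count_errors(w, Xb, y):
    """Number of points that w misclassifies (sign(0) counts as a mistake)."""
    return int(np.sum(np.sign(Xb @ w) != y))


def pocket(X, y, max_iters=200, seed=0):
    """Pocket algorithm: PLA updates, but remember the best w seen so far."""
    rng = np.random.default_rng(seed)
    Xb = np.hstack([np.ones((X.shape[0], 1)), X])   # prepend x_0 = 1 for the bias
    w = np.zeros(Xb.shape[1])
    w_hat, best = w.copy(), count_errors(w, Xb, y)  # the weights in the pocket
    for _ in range(max_iters):
        mistakes = np.flatnonzero(np.sign(Xb @ w) != y)
        if mistakes.size == 0:
            break                                   # current w is already perfect
        n = rng.choice(mistakes)                    # a random misclassified point
        w = w + y[n] * Xb[n]                        # ordinary PLA update
        errors = count_errors(w, Xb, y)
        if errors < best:                           # better than the pocket: swap
            w_hat, best = w.copy(), errors
    return w_hat, best


# Hypothetical noisy data: the last point has a flipped label,
# so no line classifies every point correctly
X = np.array([[2.0, 2.0], [3.0, 1.0], [1.0, 3.0],
              [-1.0, -2.0], [-2.0, -1.0], [-3.0, -3.0], [2.5, 2.5]])
y = np.array([1, 1, 1, -1, -1, -1, -1])

w_hat, best = pocket(X, y)
print(w_hat, best)
```

Each comparison against the pocket requires a full pass over the dataset to count errors, which makes the pocket algorithm noticeably slower per iteration than plain PLA.
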
posted @ 2016-10-31 00:44  liujshi