机器学习算法原理实现——EM算法
【EM算法简介】
EM算法,全称为期望最大化算法(Expectation-Maximization Algorithm),是一种迭代优化算法,主要用于含有隐变量的概率模型参数的估计。EM算法的基本思想是:如果给定模型的参数,那么可以根据模型计算出隐变量的期望值;反过来,如果给定隐变量的值,那么可以通过最大化似然函数来估计模型的参数。EM算法就是通过交替进行这两步来找到参数的最大似然估计。
EM算法的基本步骤如下:
1. 初始化模型参数
2. E步:计算隐变量的期望值
3. M步:最大化似然函数,更新模型参数
4. 重复步骤2和3,直到模型参数收敛
【EM算法举例】
K-means算法可以被看作是一种特殊的EM算法。在K-means算法中,我们试图找到一种方式将数据点分配到K个集群中,使得每个数据点到其所在集群中心的距离之和最小。
如果我们将集群分配看作是隐变量,那么K-means算法就可以看作是EM算法:
1. E步:期望步骤。给定当前的集群中心(模型参数),我们可以计算每个数据点最近的集群中心,也就是将每个数据点分配到一个集群中。这个步骤就是计算隐变量的期望值。
2. M步:最大化步骤。给定当前的集群分配(隐变量的值),我们可以计算新的集群中心,也就是每个集群中所有数据点的均值。这个步骤就是最大化似然函数,更新模型参数。
通过交替进行E步和M步,K-means算法可以找到一种集群分配和集群中心,使得每个数据点到其所在集群中心的距离之和最小。这就是K-means算法使用EM算法的地方。
【再举一个例子】
见:https://zhuanlan.zhihu.com/p/78311644 写得非常好,关键摘录如下:
【python编程实现】
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 | import math import random def coin_em(rolls, theta_A = None , theta_B = None , maxiter = 10000 , tol = 1e - 6 ): # 初始化参数 theta_A = theta_A or random.random() theta_B = theta_B or random.random() loglike_old = 0 for i in range (maxiter): # E步 heads_A, tails_A, heads_B, tails_B = e_step(rolls, theta_A, theta_B) # M步 theta_A, theta_B = m_step(heads_A, tails_A, heads_B, tails_B) # 计算对数似然 loglike_new = loglikelihood(rolls, theta_A, theta_B) # 检查收敛 if abs (loglike_new - loglike_old) < tol: break else : loglike_old = loglike_new return theta_A, theta_B def e_step(rolls, theta_A, theta_B): heads_A, tails_A, heads_B, tails_B = 0 , 0 , 0 , 0 for trial in rolls: likelihood_A = likelihood(trial, theta_A) likelihood_B = likelihood(trial, theta_B) p_A = likelihood_A / (likelihood_A + likelihood_B) p_B = 1 - p_A heads_A + = p_A * trial.count( "H" ) tails_A + = p_A * trial.count( "T" ) heads_B + = p_B * trial.count( "H" ) tails_B + = p_B * trial.count( "T" ) return heads_A, tails_A, heads_B, tails_B def m_step(heads_A, tails_A, heads_B, tails_B): theta_A = heads_A / (heads_A + tails_A) theta_B = heads_B / (heads_B + tails_B) return theta_A, theta_B def likelihood(roll, theta): numHeads = roll.count( "H" ) flips = len (roll) return (theta * * numHeads) * (( 1 - theta) * * (flips - numHeads)) def loglikelihood(rolls, theta_A, theta_B): total = 0 for roll in rolls: heads = roll.count( "H" ) tails = roll.count( "T" ) total + = math.log( 0.5 * likelihood(roll, theta_A) + 0.5 * likelihood(roll, theta_B)) return total # 测试 rolls = [ "HTTTHHTHTH" , "HHHHTHHHHH" , "HTHHHHHTHH" , "HTHTTTHHTT" , "THHHTHHHTH" ] print (coin_em(rolls)) |
输出:
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 记一次.NET内存居高不下排查解决与启示
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
· DeepSeek 开源周回顾「GitHub 热点速览」
2022-10-01 恶意代码分析 动态行为分析 Lab3-1 Lab3-2 Lab3-3 Lab3-4
2022-10-01 INetSim模拟C2 这玩意比起nc来说更专业!
2022-10-01 nc这个工具用于伪造c2服务器 做c2初始连接的抓包分析实在是太tm好用了!必要时候配合APATEDNS
2022-10-01 如何查看加壳的恶意软件 Lab1-2 Lab1-3 恶意代码分析
2022-10-01 VirSCAN.org
2020-10-01 gdb list命令查看源码 break设置断点可以通过源码也可以根据汇编代码地址设置
2020-10-01 Error disabling address space randomization: Operation not permitted