PCA04-数据降噪

源文件链接:NoiseFiltering

使用PCA对手写数字进行降维

完整代码

from sklearn.datasets import load_digits
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import numpy as np
digits = load_digits()
digits.data
def plot_digits(data):
    #dats的结构必须是(m,n),并且n要能够被分成(8,8)这样的结构
    fig, axes = plt.subplots(4,10,figsize=(10,4)
                             ,subplot_kw = {"xticks":[],"yticks":[]})
    for i, ax in enumerate(axes.flat):
        ax.imshow(data[i].reshape(8,8),cmap="binary")
plot_digits(digits.data)
#人为加上数据噪音
rng = np.random.RandomState(42)
#在指定的数据集中,随机抽取服从正态分布的数据
#两个参数,分别是指定的数据集,和抽取出来的正太分布的方差
noisy = np.random.normal(digits.data,2)#从输入的数据集中随机抽取一个,满足正态分布的另一个数据集,2是方差的大小
plot_digits(noisy)#画图,显示加噪音之后的数字
#进行降噪
#降噪的第一步是降维
pca = PCA(0.5,svd_solver="full").fit(noisy)#取出降维后带有原始特征50%的数
X_dr = pca.transform(noisy)
#降维之后应该包含的数图像的主要特征,不包含噪音
without_noise = pca.inverse_transform(X_dr)
#without_noise.shape由于data的结构必须是8*8的,所有不能直接打印出降噪之后的X_dr,必须先还原回去
plot_digits(without_noise)

part1 使用datasets里面的load_digits

from sklearn.datasets import load_digits
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import numpy as np
digits = load_digits()
digits.data

prat2 定义一个读取图像的函数

def plot_digits(data):
    #dats的结构必须是(m,n),并且n要能够被分成(8,8)这样的结构
    fig, axes = plt.subplots(4,10,figsize=(10,4)
                             ,subplot_kw = {"xticks":[],"yticks":[]})
    #图像按照4行,每行10个数字排列
    for i, ax in enumerate(axes.flat):
        ax.imshow(data[i].reshape(8,8),cmap="binary")
plot_digits(digits.data)#打印图像

part3 加上噪音

#人为加上数据噪音
rng = np.random.RandomState(42)
#在指定的数据集中,随机抽取服从正态分布的数据
#两个参数,分别是指定的数据集,和抽取出来的正太分布的方差
noisy = np.random.normal(digits.data,2)#从输入的数据集中随机抽取一个,满足正态分布的另一个数据集,2是方差的大小
plot_digits(noisy)#画图,显示加噪音之后的数字

part4 使用PCA进行降噪(关键步骤)

#进行降噪
#降噪的第一步是降维
pca = PCA(0.5,svd_solver="full").fit(noisy)#取出降维后带有原始特征50%的数
#噪音是几乎不带有特征的,使用pca的思想就是留下方差小的,去掉方差大的
X_dr = pca.transform(noisy)
plot_digits(without_noise)

噪音是几乎不带有特征的,使用pca的思想就是留下方差小的,去掉方差大的

X_dr = pca.transform(noisy)

这一步的转换是必须的,由于在降噪的本质是在降维,而我们定义的plot_digits(data)中,data必须是8*8,即64个维度的,只有在进行升高维度之后,才能把图像打印出来

X_dr.shape

通过shape方法可以看出数据是多少维的

posted @ 2022-03-22 22:07  Brett-Xie  阅读(148)  评论(0)    收藏  举报