matplotlib绘图学习

最近频繁用到matplotlib绘图,梳理了下官网的tutorial,记录下学习笔记。主要是对下面链接的翻译和个人理解整理。

https://matplotlib.org/3.5.0/tutorials/introductory/usage.html

1. 基础知识

matplotlib绘图中包括两个概念:figure, axes. 其中figure表示一张图,axes对应一个绘图区域,一个figure中可以包括多个axes, 如下代码:

fig = plt.figure()  # an empty figure with no Axes
fig, ax = plt.subplots()  # a figure with a single Axes
fig, axs = plt.subplots(2, 2)  # a figure with a 2x2 grid of Axes

一个figure可以包含的组成成分如下:

注意:matplotlib绘制图片时,接受的数据最好是numpy格式(list和pandas等array-like数据,可能会出现异常情况)

1.1 面向对象接口和pyplot接口

matplotlib提供了两种形式的接口供调用:

  • 面向对象接口(object-oriented interface): 显示的创建figure对象和axes对象,调用figure和axes对象的方法
  • pyplot接口(pyplot interface):依靠pyplot来创建和管理figure和axes对象,调用pyplot函数

面向对象接口风格代码如下:

x = np.linspace(0, 2, 100)  # Sample data.

# Note that even in the OO-style, we use `.pyplot.figure` to create the figure.
fig, ax = plt.subplots()  # Create a figure and an axes.
ax.plot(x, x, label='linear')  # Plot some data on the axes.
ax.plot(x, x**2, label='quadratic')  # Plot more data on the axes...
ax.plot(x, x**3, label='cubic')  # ... and some more.
ax.set_xlabel('x label')  # Add an x-label to the axes.
ax.set_ylabel('y label')  # Add a y-label to the axes.
ax.set_title("Simple Plot")  # Add a title to the axes.
ax.legend()  # Add a legend.

pyplot接口风格代码:

x = np.linspace(0, 2, 100)  # Sample data.

plt.plot(x, x, label='linear')  # Plot some data on the (implicit) axes.
plt.plot(x, x**2, label='quadratic')  # etc.
plt.plot(x, x**3, label='cubic')
plt.xlabel('x label')
plt.ylabel('y label')
plt.title("Simple Plot")
plt.legend()

2. pyplot接口

matplotlib.pyplot 是一个使matplotlib像MATLAB一样工作的绘图接口(很多函数的集合),pyplot会自动追踪当前figure和axes, 其调用函数也是作用于当前axes。

示例一

import matplotlib.pyplot as plt
import numpy as np
# 1. 定义一个图形窗口
plt.figure()

# 2. 绘制图形
plt.plot([1, 3, 5, 4])   # [1, 3, 5, 4]会被当作y,x会被自动设置成[0, 1, 2,3],(x从0开始递增)
plt.ylabel('some label')

plt.plot([1, 2, 3, 4], [1, 3, 5, 4])   #x=[1, 2, 3, 4], y=[1, 3, 5, 4]

# 3. 显示绘制图形
plt.show()

2.1 控制plot曲线的格式(style)

对于每一组x,y, plot函数接受一个字符串参数fmt,设置绘制曲线的格式,其中fmt格式如下:

fmt = '[marker][line][color]'     
或者fmt = '[color][marker][line]'
fmt默认设置:'b-

marker:

character description
'.' point marker
',' pixel marker
'o' circle marker
'v' triangle_down marker

line:

character description
'-' solid line style
'--' dashed line style
'-.' dash-dot line style
':' dotted line style

color:

character color
'b' blue
'g' green
'r' red

所有marker, line, color参考如下链接:https://matplotlib.org/3.5.0/api/_as_gen/matplotlib.pyplot.plot.html#matplotlib.pyplot.plot

常用fmt示例:

'b'    # blue markers with default shape
'or'   # red circles
'-g'   # green solid line
'--'   # dashed line with default color
'^k:'  # black triangle_up markers connected by a dotted line

示例一

import matplotlib.pyplot as plt

plt.figure()
# 绘制红色的圆圈
plt.plot([1, 2, 3, 4], [1, 4, 9, 16], 'ro')
plt.show()

示例二

import matplotlib.pyplot as plt
import numpy as np

plt.figure()
t = np.arange(0., 5., 0.2)
# 'r--'红色虚线; 'bs':蓝色的方框('s':square marker); 'g^':绿色上三角形('^': 上三角形marker)
plt.plot(t, t, 'r--', t, t**2, 'bs', t, t**3, 'g^')

plt.show()

2.3 plot接受字典和字符串形式输入

示例一

plt.scatter()函数参数data,可以接受一个字典输入,根据字典的key来索引字典中的数据,如下代码:

import matplotlib.pyplot as plt
import numpy as np

plt.figure()
data = {'a': np.arange(50),
        'c': np.random.randint(0, 50, 50),
        'd': np.random.randn(50),
}
data['b'] = data['a']+10*np.random.randn(50)
data['d'] = np.abs(data['d'])*100
# 采用data['a'],data['b']表示每个点的x,y
# c='c', 表示采用data['c']的值设置每个点的颜色
# s='d', 表示采用data['d']的值设置每个点的大小
plt.scatter('a', 'b', c='c', s='d', data=data)
plt.xlabel('entry a')
plt.ylabel('entry b')
plt.show()

示例二

plt.plot()的横坐标可以接收字符串形式的输入,如下:

import matplotlib.pyplot as plt
import numpy as np

names = ['group_a', 'group_b', 'group_c']
values = [1, 10, 100]
plt.figure(figsize=(9, 3))

# 131: 表示有1行3列,共3个子图,在3个子图的第一个子图中绘制
plt.subplot(131)
plt.bar(names, values)

plt.subplot(132)
plt.scatter(names, values)

plt.subplot(133)
plt.plot(names, values)

plt.show()

2.4 设置line的属性

直线有很多属性可以设置,线的宽度,颜色,风格等,有下面三种设置方式,如下面代码:

import matplotlib.pyplot as plt
import numpy as np
plt.figure()

x1 = np.arange(10)
x2 = np.arange(10)+np.random.rand(10)
y1 = x1**2
y2 = x2+np.random.randint(3, 10)
# 方式一:通过关键字参数设置
# plt.plot(x1, y1, linewidth=5.0)

# 方式二:通过返回的line2D对象设置  (返回包含line2D对象的列表,有两条直线,所以列表里有两个对象)
# line1, line2 = plt.plot(x1, y1, '-', x2, y2)
# line1.set_antialiased(False)  # turn off antialiasing

lines = plt.plot(x1, y1, x2, y2)
# 方式三:通过plt.setp设置
plt.setp(lines, color='r', linewidth=5.0)
plt.setp(lines)  # plt.setp(lines):查看所有可以设置的属性名称
plt.show()

2.5 多个figure和多个axes

pyplot中有figure和axes的概念,可以有多个figure,每个figure可以有多个axes,pyplot.gca()返回当前的axes对象( matplotlib.axes.Axes ), pyplot.gcf()返回当前的figure对象( matplotlib.figure.Figure ),pyplot总是绘制在当前axes。

下面代码中利用plot.subplot()创建多个axes,其参数含义如下:

plot.subplot(numrows, numcols, plot_number):

  • subplot(211): 等同于subplot(2, 1, 1), 表示创建2行1列,共2个axes, 在其中的第一个axes进行绘制

示例1

import matplotlib.pyplot as plt
import numpy as np

def func(t):
    return np.exp(-t)*np.cos(2*np.pi*t)

t1 = np.arange(0.0, 5.0, 1.0)
t2 = np.arange(0.0, 5.0, 0.02)
plt.figure()
plt.subplot(211)   # 相当于plt.subplot(2, 1, 1), 第一个axes
plt.plot(t1, func(t1), 'bo', t2, func(t2), 'k')

plt.subplot(212)  # 相当于plt.subplot(2, 1, 2),第二个axes
plt.plot(t2, np.cos(2*np.pi*t2), 'r--')

plt.show()

示例2

下面示例代码中,绘制了两个figure, 每个figure中有自己的多个axes:

import matplotlib.pyplot as plt
import numpy as np

plt.figure(1)                # the first figure
plt.subplot(211)             # the first subplot in the first figure
plt.plot([1, 2, 3])
plt.subplot(212)             # the second subplot in the first figure
plt.plot([4, 5, 6])

plt.figure(2)                # a second figure
plt.plot([4, 5, 6])          # creates a subplot() by default

plt.figure(1)                # figure 1 current; subplot(212) still current
plt.subplot(211)             # make subplot(211) in figure1 current
plt.title('Easy as 1, 2, 3') # subplot 211 title

plt.show()

plt.clf() 清理当前figure中内容
plt.cla() # 清理当前axes中内容
plt.close(2) # 关闭figure2

2.6 pyplot中text

pyplot中可以通过如下函数添加文本:

  • plt.xlabel('Smarts'): x坐标轴名称
  • plt.ylabel('Probability'): y坐标轴名称
  • plt.title('Histogram of IQ'): axes标题
  • plt.text(60, .025, r'$\mu=100,\ \sigma=15$'): 在坐标(60, 0.025)处添加文本 ($$ 表示markdown文本格式)

上面的四个text中都支持markdown格式,来设置数学表达式,如:

plt.title(r'$\sigma_i=15$')

示例1

import matplotlib.pyplot as plt
import numpy as np
plt.figure()
mu, sigma = 100, 15
x = mu + sigma*np.random.randn(10000)
n, bins, patches = plt.hist(x, 50, density=1, facecolor='g', alpha=0.75)
plt.xlabel('Smart')
plt.ylabel('Probability')
plt.title('Histogram of IQ')
plt.text(60,  0.025, r'$\mu=100,\ \sigma=15$')
plt.axis([40, 160, 0, 0.03])   # [xmin, xmax, ymin, ymax]
plt.grid(True)  # 设置网格
plt.show()

和之前介绍的line属性设置一样,也可以通过关键字,或者pyplot.setp()来设置文本的属性,如下:

t1 = plt.xlabel('my data', fontsize=14, color='red')
# plt.setp(t1)  # 打印t1可以设置的属性
t2 = plt.text(60,  0.025, r'$\mu=100,\ \sigma=15$', color='green')
# plt.setp(t2)   # 打印t2可以设置的属性

示例2

plt.annotate(text='local max', xy=(2, 1), xytext=(3, 1.5),):可以设置文本及其指向的位置

  • text: 文本字符串
  • xy: 文本指向的坐标
  • xytext:文本的坐标位置

annotate详细文档:https://matplotlib.org/3.5.0/tutorials/text/annotations.html#annotations-tutorial

plt.figure()
ax =plt.subplot()
t = np.arange(0, 5.0, 0.01)
s = np.cos(2*np.pi*t)
line, = plt.plot(t, s, lw=2)
plt.annotate('local max', xy=(2, 1), xytext=(3, 1.5),
             arrowprops=dict(facecolor='black', shrink=0.05),
             )

plt.ylim(-2, 2)
plt.show()

2.7 Logarithmic scale

pyplot的坐标轴除了支持linear scale,还支持log scale, symmetric log scale, logit scale:

关于scale文档:https://matplotlib.org/3.5.0/api/scale_api.html#module-matplotlib.scale

  • log scale: 即matplotlib.scale.LogScale, log坐标轴,注意其只绘制正数,会忽略负数

  • symmentric log scale: 即matplotlib.scale.SymmetricalLogScale, 对称log坐标轴, 支持正负数

  • logit scale: 即matplotlib.scale.LogitScale, logit坐标轴,[0,1]范围内,会将log数据映射[0, 1]范围内。

    ( logit = 1/(1+log(-x)) )

示例代码如下:

np.random.seed(19680801)
y = np.random.normal(loc=0.5, scale=0.4, size=1000)  # shape(1000,)
y = y[(y > 0) & (y < 1)]  # shape(799, ), 值在[0, 1]范围内
y.sort()
x = np.arange(len(y))

plt.figure()
# 线性坐标轴
plt.subplot(221)
plt.plot(x, y)
plt.yscale('linear')
plt.title('linear')
plt.grid(True)

# y轴为log坐标轴
plt.subplot(222)
plt.plot(x, y)
plt.yscale('log')
plt.title('log')
plt.grid(True)

# y轴为symlog坐标轴
plt.subplot(223)
plt.plot(x,  y - y.mean())  # y可以为正数和负数,均值变成0
plt.yscale('symlog', linthresh=0.01)  # x趋近于0时,log(x)会趋近于负无穷,linthresh=0.01设置(-0.01, 0.01)范围内为线性值
plt.title('symlog')
plt.grid(True)

# y轴为logit坐标轴
plt.subplot(224)
plt.plot(x, y)
plt.yscale('logit')
plt.title('logit')
plt.grid(True)

# 调整subplot之间的格式
plt.subplots_adjust(top=0.92, bottom=0.08, left=0.10, right=0.95, hspace=0.25,
                    wspace=0.35)
plt.show()

2. 8 综合案例

下面时工作中我绘制的一段代码截图,数据采用随机数进行了代替,包括了坐标轴的标题,图例显示,坐标轴刻度范围等。

import numpy as np
import matplotlib.pyplot as plt

plt.figure(figsize=(12, 6))  # 设置fig的大小,width=12, height=6

# quaternion linear曲线插值
quater_linear_joint_y = np.random.randn(19)+0.2   # shape(19, )
quater_linear_sequence_y = np.random.randn(481)+0.2  # shape(481, )

# quaternion bezier曲线插值
quater_bezier_joint_y = np.random.randn(19)    # shape(19, )
quater_bezier_sequence_y = np.random.randn(481)   # shape(481, )

quater_linear_joint_x = np.arange(quater_linear_joint_y.shape[0]) + 1    # shape(19, )
quater_bezier_joint_x = np.arange(quater_bezier_joint_y.shape[0]) + 1     # shape(19, )
quater_linear_sequence_x = np.arange(quater_linear_sequence_y.shape[0]) + 1    # shape(481, )
quater_bezier_sequence_x = np.arange(quater_bezier_sequence_y.shape[0]) + 1   # shape(481, )

# 线性插值
pos3d_linear_joint_y = np.random.randn(19)+0.1   # shape(19, )
pos3d_linear_sequence_y = np.random.randn(481)+0.1  # shape(481, )

# bezier曲线插值
pos3d_bezier_joint_y = np.random.randn(19)    # shape(19, )
pos3d_bezier_sequence_y = np.random.randn(481)   # shape(481, )

pos3d_linear_joint_x = np.arange(pos3d_linear_joint_y.shape[0]) + 1   # shape(19, )
pos3d_bezier_joint_x = np.arange(pos3d_bezier_joint_y.shape[0]) + 1   # shape(19, )
pos3d_linear_sequence_x = np.arange(pos3d_linear_sequence_y.shape[0]) + 1  # shape(481, )
pos3d_bezier_sequence_x = np.arange(pos3d_bezier_sequence_y.shape[0]) + 1  # shape(481, )

# 绘制joint
plt.subplot(121)
plt.title(f"difference in joint")  # 设置subplot标题
plt.xlabel("joint number")    # 设置y轴标题
plt.ylabel("l2 distance")   # 设置x轴标题
ymin = min(np.min(quater_linear_joint_y), np.min(quater_bezier_joint_y),
           np.min(pos3d_linear_joint_y), np.min(pos3d_bezier_joint_y))
ymax = max(np.max(quater_linear_joint_y), np.max(quater_bezier_joint_y),
           np.max(pos3d_linear_joint_y), np.max(pos3d_bezier_joint_y))
plt.ylim((ymin, ymax))   # 设置y轴坐标轴刻度范围
plt.xlim((1, 19))       # 设置x轴坐标轴刻度范围
plt.plot(quater_linear_joint_x, quater_linear_joint_y, 'rs', label='quaternion linear')   #'rs'表示红色方框,label方便legend显示
plt.plot(quater_bezier_joint_x, quater_bezier_joint_y, 'r^', label='quaternion bezier')   #'r^'表示红色三角形,label方便legend显示

plt.plot(pos3d_linear_joint_x, pos3d_linear_joint_y, 'bs', label='pos3d linear')   #'bs'表示蓝方框,label方便legend显示
plt.plot(pos3d_bezier_joint_x, pos3d_bezier_joint_y, 'b^', label='pos3d bezier')   #'b^'表示蓝色三角形,label方便legend显示
plt.legend()  # 显示图例,根据plot时设置的label区分

# 绘制sequence
plt.subplot(122)
plt.title(f"difference in frames")
plt.xlabel("frames number")
plt.ylabel("l2 distance")
ymin = min(np.min(quater_linear_sequence_y), np.min(quater_bezier_sequence_y),
           np.min(pos3d_linear_sequence_y), np.min(pos3d_bezier_sequence_y))
ymax = max(np.max(quater_linear_sequence_y), np.max(quater_bezier_sequence_y),
           np.max(pos3d_linear_sequence_y), np.max(pos3d_bezier_sequence_y))
plt.ylim((ymin, ymax))   # 设置y轴坐标轴刻度范围
plt.xlim((1, 481))
plt.plot(quater_linear_sequence_x, quater_linear_sequence_y, 'rs', label='quaternion linear')
plt.plot(quater_bezier_sequence_x, quater_bezier_sequence_y, 'r^', label='quaternion bezier')

plt.plot(pos3d_linear_sequence_x, pos3d_linear_sequence_y, 'bs', label='pos3d linear')
plt.plot(pos3d_bezier_sequence_x, pos3d_bezier_sequence_y, 'b^', label='pos3d bezier')
plt.legend()
# plt.savefig(f'./plot.png')  # 保存绘制图片
plt.show()

3. 图片显示

matplotlib.image可以用来读取图片为numpy数据(底层依赖Pillow),其读取的图片numpy格式为RGB, 通过plt.imshow能显示图片numpy数据。

注意:matplolib读取图片后,会将其缩放到[0, 1]范围内,并且转换为float32格式, 对应RGB图片matplotlib支持float32和uint8类型数据, 但对于灰度图,matplotlib只支持float32格式

(若采用opencv读取,需要将BGR其转换为RGB,并将其转换为float32类型,缩放到[0,1]范围)

示例代码如下:

import matplotlib.pyplot as plt
import matplotlib.image as mping

img_path = r"./stinkbug.png"
img = mping.imread(img_path)
print(img.shape, img.dtype)
imgplot = plt.imshow(img)
plt.show()

3.1 伪彩色机制

对于灰度图图片,matplotlib可以通过伪彩色机制显示,能够分辨出图片中亮暗区域。如下代码所示:

import matplotlib.pyplot as plt
import matplotlib.image as mping

img_path = r"./stinkbug.png"
img = mping.imread(img_path)
print(img.shape, img.dtype)

# example 2
gray_img = img[:, :, 0]
plt.imshow(gray_img)   # 默认采用viridis伪彩色展示灰度图
plt.colorbar()  # 显示颜色条
plt.show()

默认采用viridis伪彩色展示灰度图,可以设置伪彩色的机制, 如下所示:

import matplotlib.pyplot as plt
import matplotlib.image as mping

img_path = r"./stinkbug.png"
img = mping.imread(img_path)
print(img.shape, img.dtype)

gray_img = img[:, :, 0]
plt.imshow(gray_img, cmap='hot')   # 采用hot伪彩色展示灰度图

# plt.imshow(gray_img, cmap='gray')   # 展示原始灰度图

# imgplot = plt.imshow(gray_img)
# imgplot.set_cmap('nipy_spectral') # 采用nipy_spectral伪彩色展示灰度图

plt.colorbar()  # 显示颜色条

3.2 直方图统计和截取

plt.hist能绘制直方图,统计每个像素出现的次数,plt.imshow函数中的clim参数能设置显示的像素区间, 如下面代码:

plt.hist(lum_img.ravel(), bins=256, range=(0.0, 1.0), fc='k', ec='k')  # 绘制单通道灰度图直方图,256个区间
imgplot = plt.imshow(lum_img, clim=(0.0, 0.7))  # 图像像素范围在[0, 1]区间,只显[0, 0.7]区间像素

示例1

下面代码中,截取了原图像素(0, 0.7)的区间。(根据统计直方图可知,超过0.7范围内的像素很少, 相当于增加对比度?)

import matplotlib.pyplot as plt
import matplotlib.image as mping

img_path = r"./stinkbug.png"
img = mping.imread(img_path)
gray_img = img[:, :, 0]

fig = plt.figure()
ax = fig.add_subplot(1, 2, 1)
imgplot = plt.imshow(gray_img)
ax.set_title('Before')
plt.colorbar(ticks=[0.1, 0.3, 0.5, 0.7], orientation='horizontal')

ax2 = fig.add_subplot(1, 2, 2)
imgplot2 = plt.imshow(gray_img, clim=(0.0, 0.7))
ax2.set_title('After')
plt.colorbar(ticks=[0.1, 0.3, 0.5, 0.7], orientation='horizontal')

plt.show()

3.3 插值

plt.imshow显示图片时,若图片尺寸发生变化,可以设置插值的方式, 示例代码如下:

import matplotlib.pyplot as plt
from PIL import Image

img_path = r"./stinkbug.png"
img = Image.open(img_path)
print(img.size)  # shape(500, 375)
img.thumbnail((64, 64), Image.ANTIALIAS)   # 保持长宽比缩放: shape(64, 48)
print(img.size)
# plt.imshow(img)   # 默认采用bilinear 插值
plt.imshow(img, interpolation='nearest')   # 采用nearest插值
# plt.imshow(img, interpolation='bicubic')  # 采用bicubic插值

plt.show()

posted @ 2022-04-13 13:37  silence_cho  阅读(406)  评论(0编辑  收藏  举报