Matplotlib快速入门
Matplotlib
可能还有小伙伴不知道Matplotlib
是什么,下面是维基百科的介绍。
Matplotlib 是Python编程语言的一个绘图库及其数值数学扩展 NumPy。它为利用通用的图形用户界面工具包,如Tkinter, wxPython, Qt或GTK+向应用程序嵌入式绘图提供了面向对象的应用程序接口。
简单说就是画图的工具包。本文将教会你如何使用Matplotlib
绘图,如果你没有Python
基础也没关系,依葫芦画瓢也完全OK的。关于如何安装Python以及Matplotlib,文末有链接。
绘制第一个图
- 如果给
plot
函数一个一维数组,则将该数组作为纵轴坐标,并且将数组中的每个数据点索引作为水平坐标
import matplotlib.pyplot as plt
plt.plot([1, 2, 4, 9, 5, 3])
plt.show()
- 如果提供两个数组,则将其分别作为x轴和y轴
plt.plot([-3, -2, 5, 0], [1, 6, 4, 3])
plt.show()
- 坐标轴会自动匹配数据的范围,不过我们可以调用
axis
函数来改变每个轴的范围[xmin, xmax, ymin, ymax]
plt.plot([-3, -2, 5, 0], [1, 6, 4, 3])
plt.axis([-4, 6, 0, 7])
plt.show()
- 我们用NumPy's的
linspace
函数在范围[-2, 2]内创建包含500个浮点数的数组x
,计算数组x
的平方作为数组y
import numpy as np
x = np.linspace(-2, 2, 500)
y = x**2
plt.plot(x, y)
plt.show()
- 添加标题,x、y轴标签,并绘制网格
plt.plot(x, y)
plt.title("Square function")
plt.xlabel("x")
plt.ylabel("y = x**2")
plt.grid(True)
plt.show()
线条样式和颜色
- 默认情况下,matplotlib在连续的点之间绘制一条线。
plt.plot([0, 100, 100, 0, 0, 100, 50, 0, 100], [0, 0, 100, 100, 0, 100, 130, 100, 0])
plt.axis([-10, 110, -10, 140])
plt.show()
- 可在第三个参数更改线条的样式和颜色,比如“g--”表示“绿色虚线”
plt.plot([0, 100, 100, 0, 0, 100, 50, 0, 100], [0, 0, 100, 100, 0, 100, 130, 100, 0], "g--")
plt.axis([-10, 110, -10, 140])
plt.show()
- 你可以绘制多条线在同一个图上,仅仅通过重复
x1, y1, [style1], x2, y2, [style2], ...
plt.plot([0, 100, 100, 0, 0], [0, 0, 100, 100, 0], "r-", [0, 100, 50, 0, 100], [0, 100, 130, 100, 0], "g--")
plt.axis([-10, 110, -10, 140])
plt.show()
- 也可以在
show
之前plot
多次
plt.plot([0, 100, 100, 0, 0], [0, 0, 100, 100, 0], "r-")
plt.plot([0, 100, 50, 0, 100], [0, 100, 130, 100, 0], "g--")
plt.axis([-10, 110, -10, 140])
plt.show()
- 你也可以绘制点而不只是绘制直线
x = np.linspace(-1.4, 1.4, 30)
plt.plot(x, x, 'g--', x, x**2, 'r:', x, x**3, 'b^')
plt.show()
接受以下格式字符来控制线条样式或标记
character | description |
---|---|
'-' | solid line style |
'--' | dashed line style |
'-.' | dash-dot line style |
':' | dotted line style |
'.' | point marker |
',' | pixel marker |
'o' | circle marker |
'v' | triangle_down marker |
'^' | triangle_up marker |
'<' | triangle_left marker |
'>' | triangle_right marker |
'1' | tri_down marker |
'2' | tri_up marker |
'3' | tri_left marker |
'4' | tri_right marker |
's' | square marker |
'p' | pentagon marker |
'*' | star marker |
'h' | hexagon1 marker |
'H' | hexagon2 marker |
'+' | plus marker |
'x' | x marker |
'D' | diamond marker |
'd' | thin_diamond marker |
' | ' |
'_' | hline marker |
支持以下颜色缩写
character | color |
---|---|
‘b’ | blue |
‘g’ | green |
‘r’ | red |
‘c’ | cyan |
‘m’ | magenta |
‘y’ | yellow |
‘k’ | black |
‘w’ | white |
plot
函数会返回一个Line2D
对象列表,你可以额外设置一些属性,例如线的宽度,虚线风格等等。
x = np.linspace(-1.4, 1.4, 30)
line1, line2, line3 = plt.plot(x, x, 'g--', x, x**2, 'r:', x, x**3, 'b^')
line1.set_linewidth(3.0)
line1.set_dash_capstyle("round")
line3.set_alpha(0.2)
plt.show()
Line2D
属性
Property | Value Type |
---|---|
alpha | float |
animated | [True / False] |
antialiased or aa | [True / False] |
clip_box | a matplotlib.transform.Bbox instance |
clip_on | [True |
clip_path | a Path instance and a Transform instance, a Patch |
color or c | any matplotlib color |
contains | the hit testing function |
dash_capstyle | ['butt' / 'round' / 'projecting'] |
dash_joinstyle | ['miter' / 'round' / 'bevel'] |
dashes | sequence of on/off ink in points |
data | (np.array xdata, np.array ydata) |
figure | a matplotlib.figure.Figure instance |
label | any string |
linestyle or ls | [ '-' / '--' / '-.' / ':' / 'steps' / ...] |
linewidth or lw | float value in points |
lod | [True / False] |
marker | [ '+' / ',' / '.' / '1' / '2' / '3' / '4' ] |
markeredgecolor or mec | any matplotlib color |
markeredgewidth or mew | float value in points |
markerfacecolor or mfc | any matplotlib color |
markersize or ms | float |
markevery | [ None / integer / (startind, stride) ] |
picker | used in interactive line selection |
pickradius | the line pick selection radius |
solid_capstyle | ['butt' / 'round' / 'projecting'] |
solid_joinstyle | ['miter' / 'round' / 'bevel'] |
transform | a matplotlib.transforms.Transform instance |
visible | [True / False] |
xdata | np.array |
ydata | np.array |
zorder | any number |
保存图像
- 可使用savefig函数保存图像
savefig(fname, dpi=None, facecolor='w', edgecolor='w',
orientation='portrait', papertype=None, format=None,
transparent=False, bbox_inches=None, pad_inches=0.1,
frameon=None)
参数
- fname: 包含文件名的路径字符串
- dpi: [ None | scalar > 0 | ‘figure’]
- format: 文件扩展名,大多数后端支持
png, pdf, ps, eps and svg
- transparent: 如果为True则轴部分的背景透明。
示例
x = np.linspace(-1.4, 1.4, 30)
plt.plot(x, x**2)
plt.savefig("my_square_function.png", transparent=True)
组合图
- 一个图可能需要包含多个子图,那么如何操作呢。要创建子图其实只需调用子图函数,并制定图中的行数和列数,以及要绘制的子图的索引(从1开始,然后从左到右,从上到下)。注意,pyplot会跟踪当前活动的子图(您可以调用
plt.gca()
来获得引用,可以借此添加额外属性),因此当您调用绘图函数时,它会绘制活动的子图。 - 注意:
subplot(224)
是subplot(2, 2, 4)
的缩写。
x = np.linspace(-1.4, 1.4, 30)
plt.subplot(2, 2, 1) # 2 rows, 2 columns, 1st subplot = top left
plt.plot(x, x)
plt.subplot(2, 2, 2) # 2 rows, 2 columns, 2nd subplot = top right
plt.plot(x, x**2)
plt.subplot(2, 2, 3) # 2 rows, 2 columns, 3rd subplot = bottow left
plt.plot(x, x**3)
plt.subplot(224) # 2 rows, 2 columns, 4th subplot = bottom right
plt.plot(x, x**4)
plt.show()
- 创建跨多个网格单元的子图也很容易
plt.subplot(2, 2, 1) # 2 rows, 2 columns, 1st subplot = top left
plt.plot(x, x)
plt.subplot(2, 2, 2) # 2 rows, 2 columns, 2nd subplot = top right
plt.plot(x, x**2)
plt.subplot(2, 1, 2) # 2 rows, *1* column, 2nd subplot = bottom
plt.plot(x, x**3)
plt.show()
- 如果你需要更复杂的子图定位,你可以使用
subplot2grid
,你可以指定格子行数和列数,然后在格子上绘制子图(左上 = (0, 0)),并且可以指定它能跨越多少行和多少列。
plt.subplot2grid((3,3), (0, 0), rowspan=2, colspan=2)
plt.plot(x, x**2)
plt.subplot2grid((3,3), (0, 2))
plt.plot(x, x**3)
plt.subplot2grid((3,3), (1, 2), rowspan=2)
plt.plot(x, x**4)
plt.subplot2grid((3,3), (2, 0), colspan=2)
plt.plot(x, x**5)
plt.show()
- 如果需要更灵活的子图定位,看这里
绘制文本
- 你可以调用
call
在图像任意位置添加文本。仅需指定坐标选择一些额外属性。关于TeX方程表达式的细节看这里
x = np.linspace(-1.5, 1.5, 30)
px = 0.8
py = px**2
plt.plot(x, x**2, "b-", px, py, "ro")
plt.text(0, 1.5, "Square function\n$y = x^2$", fontsize=20, color='blue', horizontalalignment="center")
plt.text(px - 0.08, py, "Beautiful point", ha="right", weight="heavy") # ha是horizontalalignment的别名。
plt.text(px, py, "x = %0.2f\ny = %0.2f"%(px, py), rotation=50, color='gray')
plt.show()
- 图像元素的注释使用非常频繁,
annotate
函数使得它非常简单,只需指定兴趣点的位置、文本的位置,加上文字和箭头的一些额外属性就能完成。
plt.plot(x, x**2, px, py, "ro")
plt.annotate("Beautiful point", xy=(px, py), xytext=(px-1.3,py+0.5),
color="green", weight="heavy", fontsize=14,
arrowprops={"facecolor": "lightgreen"})
plt.show()
- 你可以使用
bbox
属性在文本周围加上框。
plt.plot(x, x**2, px, py, "ro")
bbox_props = dict(boxstyle="rarrow,pad=0.3", ec="b", lw=2, fc="lightblue")
plt.text(px-0.2, py, "Beautiful point", bbox=bbox_props, ha="right")
bbox_props = dict(boxstyle="round4,pad=1,rounding_size=0.2", ec="black", fc="#EEEEFF", lw=5)
plt.text(0, 1.5, "Square function\n$y = x^2$", fontsize=20, color='black', ha="center", bbox=bbox_props)
plt.show()
- 如果为了好玩可以绘制漫画风格的图(xkcd-style),只需在
with plt.xkcd()
内写代码就好。
with plt.xkcd():
plt.plot(x, x**2, px, py, "ro")
bbox_props = dict(boxstyle="rarrow,pad=0.3", ec="b", lw=2, fc="lightblue")
plt.text(px-0.2, py, "Beautiful point", bbox=bbox_props, ha="right")
bbox_props = dict(boxstyle="round4,pad=1,rounding_size=0.2", ec="black", fc="#EEEEFF", lw=5)
plt.text(0, 1.5, "Square function\n$y = x^2$", fontsize=20, color='black', ha="center", bbox=bbox_props)
plt.show()
图例
- 添加图例最简单的方法是在对应位置添加标签,然后调用
legend
函数
x = np.linspace(-1.4, 1.4, 50)
plt.plot(x, x**2, "r--", label="Square function")
plt.plot(x, x**3, "g-", label="Cube function")
plt.legend(loc="best")
plt.grid(True)
plt.show()
非线性尺度
- Matplotlib支持非线性尺度,例如对数或logit尺度
x = np.linspace(0.1, 15, 500)
y = x**3/np.exp(2*x)
plt.figure(1)
plt.plot(x, y)
plt.yscale('linear')
plt.title('linear')
plt.grid(True)
plt.figure(2)
plt.plot(x, y)
plt.yscale('log')
plt.title('log')
plt.grid(True)
plt.figure(3)
plt.plot(x, y)
plt.yscale('logit')
plt.title('logit')
plt.grid(True)
plt.figure(4)
plt.plot(x, y - y.mean())
plt.yscale('symlog', linthreshy=0.05)
plt.title('symlog')
plt.grid(True)
plt.show()
Ticks and tickers 刻度和刻度控制器
- "ticks"是刻度的位置 (例如 (-1, 0, 1)),"tick lines"是在这些位置绘制的小线条(刻度线),"tick labels"实在刻度线旁边绘制的标签(刻度线标签)。"tickers" 是决定在哪能放置刻度的对象,默认的tickers通常在合理的距离放置5到8个刻度。但有时候你需要控制它,幸运的是,matplotlib可以让你完全控制刻度。
x = np.linspace(-2, 2, 100)
plt.figure(1, figsize=(15,10))
plt.subplot(131)
plt.plot(x, x**3)
plt.grid(True)
plt.title("Default ticks")
ax = plt.subplot(132)
plt.plot(x, x**3)
ax.xaxis.set_ticks(np.arange(-2, 2, 1))
plt.grid(True)
plt.title("Manual ticks on the x-axis")
ax = plt.subplot(133)
plt.plot(x, x**3)
plt.minorticks_on()
ax.tick_params(axis='x', which='minor', bottom='off')
ax.xaxis.set_ticks([-2, 0, 1, 2])
ax.yaxis.set_ticks(np.arange(-5, 5, 1))
ax.yaxis.set_ticklabels(["min", -4, -3, -2, -1, 0, 1, 2, 3, "max"])
plt.title("Manual ticks and tick labels\n(plus minor ticks) on the y-axis")
plt.grid(True)
plt.show()
极坐标投影
- 绘制极坐标图,只需在创建子图时,设定"projection"属性为"polar"即可。
radius = 1
theta = np.linspace(0, 2*np.pi*radius, 1000)
plt.subplot(111, projection='polar')
plt.plot(theta, np.sin(5*theta), "g-")
plt.plot(theta, 0.5*np.cos(20*theta), "b-")
plt.show()
3D投影
- 绘制3D图像非常直接,你需要导入
Axes3D
,以注册3d
投影。然后设定projection
属性为3d
。它将返回一个Axes3DSubplot
对象,你可以调用plot_surface
,给定x,y和z坐标加额外的属性来绘制。
from mpl_toolkits.mplot3d import Axes3D
x = np.linspace(-5, 5, 50)
y = np.linspace(-5, 5, 50)
X, Y = np.meshgrid(x, y)
R = np.sqrt(X**2 + Y**2)
Z = np.sin(R)
figure = plt.figure(1, figsize = (12, 4))
subplot3d = plt.subplot(111, projection='3d')
surface = subplot3d.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=matplotlib.cm.coolwarm, linewidth=0.1)
plt.show()
- 显示相同数据的另外一种方法是用等高线图
plt.contourf(X, Y, Z, cmap=matplotlib.cm.coolwarm)
plt.colorbar()
plt.show()
散点图
- 提供x和y的坐标就可以绘制散点图。
from numpy.random import rand
x, y = rand(2, 100)
plt.scatter(x, y)
plt.show()
- 你也可以提供每个点的比例
x, y, scale = rand(3, 100)
scale = 500 * scale ** 5
plt.scatter(x, y, s=scale)
plt.show()
- 还可以设置些额外的属性,例如填充颜色、边缘颜色、透明度。
for color in ['red', 'green', 'blue']:
n = 100
x, y = rand(2, n)
scale = 500.0 * rand(n) ** 5
plt.scatter(x, y, s=scale, c=color, alpha=0.3, edgecolors='blue')
plt.grid(True)
plt.show()
直线(工具函数)
- 创建一个工具函数来画图通常会更方便,该函数会在给定斜率和截距的情况下在图上绘制一条看似无限长的线。
from numpy.random import randn
def plot_line(axis, slope, intercept, **kargs):
xmin, xmax = axis.get_xlim()
plt.plot([xmin, xmax], [xmin*slope+intercept, xmax*slope+intercept], **kargs)
x = randn(1000)
y = 0.5*x + 5 + randn(1000)*2
plt.axis([-2.5, 2.5, -5, 15])
plt.scatter(x, y, alpha=0.2)
plt.plot(1, 0, "ro")
plt.vlines(1, -5, 0, color="red")
plt.hlines(0, -2.5, 1, color="red")
plot_line(axis=plt.gca(), slope=0.5, intercept=5, color="magenta")
plt.grid(True)
plt.show()
直方图
data = [1, 1.1, 1.8, 2, 2.1, 3.2, 3, 3, 3, 3]
plt.subplot(211)
plt.hist(data, bins = 10, rwidth=0.8)
plt.subplot(212)
plt.hist(data, bins = [1, 1.5, 2, 2.5, 3], rwidth=0.95)
plt.xlabel("Value")
plt.ylabel("Frequency")
plt.show()
data1 = np.random.randn(400)
data2 = np.random.randn(500) + 3
data3 = np.random.randn(450) + 6
data4a = np.random.randn(200) + 9
data4b = np.random.randn(100) + 10
plt.hist(data1, bins=5, color='g', alpha=0.75, label='bar hist') # default histtype='bar'
plt.hist(data2, color='b', alpha=0.65, histtype='stepfilled', label='stepfilled hist')
plt.hist(data3, color='r', histtype='step', label='step hist')
plt.hist((data4a, data4b), color=('r','m'), alpha=0.55, histtype='barstacked', label=('barstacked a', 'barstacked b'))
plt.xlabel("Value")
plt.ylabel("Frequency")
plt.legend()
plt.grid(True)
plt.show()
图像
- 读取图像;仅需导入
matplotlib.image
moudle,再调用imread
函数(传入文件名),它将以NumPy's数组形式返回图像数据。
import matplotlib.image as mpimg
img = mpimg.imread('my_square_function.png')
print(img.shape, img.dtype)
# Out:(288, 432, 4) float32
- 生成图像也很简单
img = np.arange(100*100).reshape(100, 100)
print(img)
plt.imshow(img)
plt.show()
- 由于我们没有提供RGB等级,
imshow
函数自动将值映射到颜色渐变。默认情况,颜色渐变从蓝色(低值)变为红色(高值),但你可以选择其他颜色映射。例如:
plt.imshow(img, cmap="hot")
plt.show()
- 你也可以直接产生RGB图像
img = np.empty((20,30,3))
img[:, :10] = [0, 0, 0.6]
img[:, 10:20] = [1, 1, 1]
img[:, 20:] = [0.6, 0, 0]
plt.imshow(img)
plt.show()
- 由于img数组非常小(20x30),所以当imshow函数显示它时,它会将图像增大图像的大小。 默认情况下,它使用双线性插值来填充所添加的像素。 这就是边缘看起来模糊的原因。 您可以选择另一种插值算法,例如复制最近像素的颜色:
plt.imshow(img, interpolation="nearest")
plt.show()
动画
绘制
FuncAnimation
构造函数接受一个图形,一个更新函数和可选参数。 我们指定我们需要一个100帧长的动画,每帧之间有20ms。 在每次迭代中,FuncAnimation调用我们的更新函数,并将帧号传递给num
(在我们的例子中是从0到99),接着是我们用fargs
指定额外的参数。- 我们的更新函数简单地将行数据设置为第一个数据点(所以数据将逐渐绘制),并且为了好玩,我们还为每个数据点添加一个小的随机数,这样该行似乎在摆动。
import matplotlib.animation as animation
x = np.linspace(-1, 1, 100)
y = np.sin(x**2*25)
data = np.array([x, y])
fig = plt.figure()
line, = plt.plot([], [], "r-") # start with an empty plot
plt.axis([-1.1, 1.1, -1.1, 1.1])
plt.plot([-0.5, 0.5], [0, 0], "b-", [0, 0], [-0.5, 0.5], "b-", 0, 0, "ro")
plt.grid(True)
plt.title("Marvelous animation")
# this function will be called at every iteration
def update_line(num, data, line):
line.set_data(data[..., :num] + np.random.rand(2, num) / 25) # we only plot the first `num` data points.
return line,
line_ani = animation.FuncAnimation(fig, update_line, frames=100, fargs=(data, line), interval=67)
plt.show()
保存
- Matplotlib依靠第三方库来编写视频,如FFMPEG或mencoder。 在这个例子中,我们将使用FFMPEG,所以一定要先安装它。
Writer = animation.writers['ffmpeg']
writer = Writer(fps=15, metadata=dict(artist='Me'), bitrate=1800)
line_ani.save('my_wiggly_animation.mp4', writer=writer)
更多
欢迎关注公众号:MachineEpoch,与新人一起学习。