局部放大画图
学姐改好的版本,存一下以便以后使用。主要用于训练曲线末端差异太小、难以分辨时,对局部区域进行放大显示。
import random
import csv
import matplotlib.pyplot as plt
from matplotlib.patches import ConnectionPatch
import numpy as np
def zone_and_linked(ax, axins, zone_left, zone_right, x, y, linked='bottom',
                    x_ratio=0.05, y_ratio=0.05):
    """Zoom an inset axes onto a slice of the curves and link it to the main axes.

    The inset's x/y limits are set to the padded extent of the selected zone,
    the same extent is outlined on the main axes with a black rectangle, and
    two connector lines are drawn between the inset and that rectangle.

    Args:
        ax: Main axes, e.g. from ``fig, ax = plt.subplots(1, 1)``.
        axins: Inset axes, e.g. ``ax.inset_axes((0.4, 0.1, 0.4, 0.3))``.
        zone_left: Index into ``x`` of the left end of the zoomed zone.
        zone_right: Index into ``x`` of the right end of the zoomed zone
            (inclusive).
        x: Sequence of x values shared by all curves.
        y: List of y sequences, one per curve.
        linked: Edge of the outlined rectangle the connectors attach to, one
            of ``'bottom'``, ``'top'``, ``'left'``, ``'right'``. The other end
            of each connector sits on the opposite edge of the inset view
            (e.g. ``'bottom'`` joins the inset's top corners to the
            rectangle's bottom corners).
        x_ratio: Horizontal padding, as a fraction of the zone's x extent.
        y_ratio: Vertical padding, as a fraction of the zone's y extent.

    Raises:
        ValueError: If ``linked`` is not one of the four supported values.
    """
    # Reject a bad side up front, before any axes has been modified.
    if linked not in ('bottom', 'top', 'left', 'right'):
        raise ValueError(
            "linked must be one of 'bottom', 'top', 'left', 'right'; "
            "got {!r}".format(linked))

    zone_width = x[zone_right] - x[zone_left]
    xlim_left = x[zone_left] - zone_width * x_ratio
    xlim_right = x[zone_right] + zone_width * x_ratio

    # The x-limits include x[zone_right], so the y-range must include that
    # point too. Slice ends are exclusive, hence the +1; an index of -1 needs
    # None to reach the end of the sequence.
    stop = None if zone_right == -1 else zone_right + 1
    y_data = np.hstack([yi[zone_left:stop] for yi in y])
    y_min = np.min(y_data)
    y_max = np.max(y_data)
    y_pad = (y_max - y_min) * y_ratio
    ylim_bottom = y_min - y_pad
    ylim_top = y_max + y_pad

    axins.set_xlim(xlim_left, xlim_right)
    axins.set_ylim(ylim_bottom, ylim_top)

    # Outline the zoomed region on the main axes.
    ax.plot([xlim_left, xlim_right, xlim_right, xlim_left, xlim_left],
            [ylim_bottom, ylim_bottom, ylim_top, ylim_top, ylim_bottom],
            "black")

    # For each side: two (point in inset data coords, point in main data
    # coords) pairs, one per connector line.
    connector_ends = {
        'bottom': (((xlim_left, ylim_top), (xlim_left, ylim_bottom)),
                   ((xlim_right, ylim_top), (xlim_right, ylim_bottom))),
        'top': (((xlim_left, ylim_bottom), (xlim_left, ylim_top)),
                ((xlim_right, ylim_bottom), (xlim_right, ylim_top))),
        'left': (((xlim_right, ylim_top), (xlim_left, ylim_top)),
                 ((xlim_right, ylim_bottom), (xlim_left, ylim_bottom))),
        'right': (((xlim_left, ylim_top), (xlim_right, ylim_top)),
                  ((xlim_left, ylim_bottom), (xlim_right, ylim_bottom))),
    }
    for xy_inset, xy_main in connector_ends[linked]:
        con = ConnectionPatch(xyA=xy_inset, xyB=xy_main, coordsA="data",
                              coordsB="data", axesA=axins, axesB=ax)
        axins.add_artist(con)
# Load one accuracy curve per run from ./mnist/<k>/acc_test.csv, k = 0..3.
# Only the first column of every row is used.
data = []
for i in range(1, 5):
    path = './mnist/' + str(i - 1) + '/acc_test.csv'
    # Use a context manager so each file is closed instead of leaking the
    # handle; newline='' is what the csv module expects for its input files.
    with open(path, newline='') as csv_file:
        # Blank lines come back as empty rows; skip them rather than crash
        # on item[0].
        data.append([float(item[0]) for item in csv.reader(csv_file) if item])
# Global font for every text element in the figure.
plt.rc('font', family='Times New Roman')

# NOTE(review): the original also called plt.figure(figsize=(10, 4)) here, but
# plt.subplots() below opens its own figure, so that call only produced an
# empty extra window in plt.show(). It is removed; pass figsize=(10, 4) to
# plt.subplots() if that size was actually intended for the plot.
width = 0.25  # bar width; not used by the visible plotting code

x = list(range(0, 201, 25))    # x-axis tick positions
columns = list(range(0, 200))  # round index, one per data point

# Font properties for the legend.
font1 = {
    'family': 'Times New Roman',
    'weight': 'normal',
    'size': 14,
}

# Main plot: one curve per loaded run, as hollow markers on a coloured line.
fig, ax = plt.subplots()
curve_styles = (
    ("#7DABCF", '20', 'o'),
    ("#AAB083", '15', 'x'),
    ("#FBC1AD", '25', 'v'),
    ("#ABC1AD", '30', 's'),
)
for series, (color, label, marker) in zip(data, curve_styles):
    ax.plot(columns, series, color=color, label=label, marker=marker,
            markersize=3, markerfacecolor='none')
"""
缩放图
"""
# 绘制缩放图
axins1 = ax.inset_axes((0.5, 0.45, 0.3, 0.3))
# 在缩放图中也绘制主图所有内容,然后根据限制横纵坐标来达成局部显示的目的
axins1.plot(columns, data[0],color="#7DABCF",label='20',marker='o',markersize=3,markerfacecolor='none')
axins1.plot(columns, data[1],color="#AAB083",label='15',marker='x',markersize=3,markerfacecolor='none')
axins1.plot(columns, data[2],color="#FBC1AD",label='25',marker='v',markersize=3,markerfacecolor='none')
axins1.plot(columns, data[3],color="#ABC1AD",label='30',marker='s',markersize=3,markerfacecolor='none')
zone_left=180
zone_right=199
# 局部显示并且进行连线
zone_and_linked(ax, axins1, zone_left, zone_right, columns , [data[0],data[1],data[2],data[3]], 'bottom')
# 局部显示并且进行连线
x1=columns
y1=[data[0],data[1],data[2],data[3]]
x_ratio=0.02 # 0.02
y_ratio=0.02 # 0.02
xlim_left = x1[zone_left] - (x1[zone_right] - x1[zone_left]) * x_ratio
xlim_right = x1[zone_right] + (x1[zone_right] - x1[zone_left]) * x_ratio
y_data = np.hstack([yi[zone_left:zone_right] for yi in y1])
ylim_bottom = np.min(y_data)-y_ratio
ylim_top = np.max(y_data)+y_ratio
axins1.set_xlim(xlim_left, xlim_right)
axins1.set_ylim(ylim_bottom, ylim_top)
# ylim_top = np.max(y_data)
# ylim_bottom = np.min(y_data)
# ax.plot([xlim_left, xlim_right, xlim_right, xlim_left, xlim_left],
# [ylim_bottom, ylim_bottom, ylim_top, ylim_top, ylim_bottom], "black")
ylim_top = np.max(y_data)
ylim_bottom = np.min(y_data)-y_ratio
xyA_1, xyB_1 = (xlim_left, ylim_bottom), (xlim_left, ylim_top)
xyA_2, xyB_2 = (xlim_right, ylim_bottom), (xlim_right, ylim_top)
con = ConnectionPatch(xyA=xyA_1, xyB=xyB_1, coordsA="data",
coordsB="data", axesA=axins1, axesB=ax)
axins1.add_artist(con)
con = ConnectionPatch(xyA=xyA_2, xyB=xyB_2, coordsA="data",
coordsB="data", axesA=axins1, axesB=ax)
axins1.add_artist(con)
"""
缩放图
"""
# Fix the aspect ratio of the current axes to 100.
plt.gca().set_aspect(100)

# Tick label font size on the current axes.
tick_font_size = 14
plt.xticks(fontsize=tick_font_size)
plt.yticks(fontsize=tick_font_size)

# y axis runs over [0, 1] with eleven evenly spaced ticks; x ticks come from `x`.
ax.set_ylim(0, 1)
y_ticks = [round(tick, 2) for tick in np.linspace(0, 1, 11)]
ax.set_yticks(y_ticks)
ax.set_xticks(x)

# Axis labels, empty title and legend.
ax.set_ylabel("Accuracy", fontsize=14)
ax.set_xlabel("Communication rounds", fontsize=14)
ax.set_title('')
ax.legend(prop=font1)

# Thin dashed grid on both axes.
plt.grid(axis='both', linestyle='--', linewidth=0.5)

# Write the figure to disk, then display it.
plt.savefig("acc_cluster.pdf")
plt.show()
我是咸鱼。转载博客请征得博主同意Orz