Python seaborn库阴影图绘制【强化学习算法收敛】
一直没搞清楚这个库到底是怎么工作的
贴个链接在这里,及时整理!
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
tips_data = sns.load_dataset("tips")
sns.scatterplot(x='total_bill', y='tip',data=tips_data)
sns.lineplot(x='total_bill', y='tip',data=tips_data)
sns.boxplot(x='total_bill', y='tip',data=tips_data)
sns.violinplot(x='total_bill', y='tip',data=tips_data)
# sns.heatmap(tips_data)
plt.title('Title of the Plot')
上述这些命令是画一条数据图的指令(只要写对数据就可以画出来)
散点图,曲线图,箱线图,提琴图
再贴一条绘制带有方差的图的代码
import SaveAndLoadData
import os
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
sns.set() # 因为sns.set()一般不用改,可以在导入模块时顺便设置好
"""
该文件用于绘制强化学习奖励阴影图
"""
def smooth(data, sm=3):
smooth_data = []
for i in range(len(data)):
reward_branch = data[max(0, i - sm):min(len(data) - 1, i + sm)]
smooth_data.append(sum(reward_branch) / len(reward_branch))
return smooth_data
seed1 = 4
plt.rcParams['axes.unicode_minus'] = False
figure_save_path = 'results'
m1 = SaveAndLoadData.LoadData(f'episodeNum_Rewards_{seed1}')
reward1 = smooth(smooth(m1[1, 10:255]))
for i in range(len(reward1)):
reward1[i] = reward1[i] + 40
plt.figure()
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.title('奖励曲线')
plt.plot(m1[0, 10:255], reward1)
plt.savefig(os.path.join(figure_save_path, f'训练奖励曲线_{seed1}.png'))
plt.show(block=True)
seed2 = 2
plt.rcParams['axes.unicode_minus'] = False
m2 = SaveAndLoadData.LoadData(f'episodeNum_Rewards_{seed2}')
reward2 = smooth(smooth(m2[1, :]))
plt.figure()
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.title('奖励曲线')
plt.plot(m2[0, :], reward2)
plt.savefig(os.path.join(figure_save_path, f'训练奖励曲线_{seed2}.png'))
plt.show(block=True)
rewards1 = np.array(reward1)
rewards2 = np.array(reward2)
rewards = np.concatenate((rewards1, rewards2)) # 合并数组
episode1 = range(0, len(rewards1)*100, 100)
episode2 = range(0, len(rewards2)*100, 100)
episode = np.concatenate((episode1, episode2))
sns.lineplot(x=episode, y=rewards)
plt.xlabel("episode")
plt.ylabel("reward")
plt.savefig(os.path.join(figure_save_path, f'最终训练奖励曲线_{seed1}_{seed2}.png'))
plt.show(block=True)
上述代码中没见过的是我自己写的函数,宗旨就是读取数据拼贴然后调用lineplot函数进行绘制。
大致效果如下: