Python seaborn库阴影图绘制【强化学习算法收敛】

一直没搞清楚这个库到底是怎么工作的

贴个链接在这里,及时整理!

import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
tips_data = sns.load_dataset("tips")
sns.scatterplot(x='total_bill', y='tip',data=tips_data)
sns.lineplot(x='total_bill', y='tip',data=tips_data)
sns.boxplot(x='total_bill', y='tip',data=tips_data)
sns.violinplot(x='total_bill', y='tip',data=tips_data)
# sns.heatmap(tips_data)
plt.title('Title of the Plot')

上述这些命令是画一条数据图的指令(只要写对数据就可以画出来)
散点图,曲线图,箱线图,提琴图
再贴一条绘制带有方差的图的代码

import SaveAndLoadData

import os
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

sns.set()  # 因为sns.set()一般不用改,可以在导入模块时顺便设置好

"""
该文件用于绘制强化学习奖励阴影图
"""
def smooth(data, sm=3):
    smooth_data = []
    for i in range(len(data)):
        reward_branch = data[max(0, i - sm):min(len(data) - 1, i + sm)]
        smooth_data.append(sum(reward_branch) / len(reward_branch))
    return smooth_data


seed1 = 4
plt.rcParams['axes.unicode_minus'] = False
figure_save_path = 'results'
m1 = SaveAndLoadData.LoadData(f'episodeNum_Rewards_{seed1}')
reward1 = smooth(smooth(m1[1, 10:255]))
for i in range(len(reward1)):
    reward1[i] = reward1[i] + 40
plt.figure()
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.title('奖励曲线')
plt.plot(m1[0, 10:255], reward1)
plt.savefig(os.path.join(figure_save_path, f'训练奖励曲线_{seed1}.png'))
plt.show(block=True)

seed2 = 2
plt.rcParams['axes.unicode_minus'] = False
m2 = SaveAndLoadData.LoadData(f'episodeNum_Rewards_{seed2}')
reward2 = smooth(smooth(m2[1, :]))
plt.figure()
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.title('奖励曲线')
plt.plot(m2[0, :], reward2)
plt.savefig(os.path.join(figure_save_path, f'训练奖励曲线_{seed2}.png'))
plt.show(block=True)

rewards1 = np.array(reward1)
rewards2 = np.array(reward2)
rewards = np.concatenate((rewards1, rewards2))  # 合并数组
episode1 = range(0, len(rewards1)*100, 100)
episode2 = range(0, len(rewards2)*100, 100)
episode = np.concatenate((episode1, episode2))
sns.lineplot(x=episode, y=rewards)
plt.xlabel("episode")
plt.ylabel("reward")
plt.savefig(os.path.join(figure_save_path, f'最终训练奖励曲线_{seed1}_{seed2}.png'))
plt.show(block=True)

上述代码中没见过的是我自己写的函数,宗旨就是读取数据拼贴然后调用lineplot函数进行绘制。
大致效果如下:

参考链接

  1. https://zhuanlan.zhihu.com/p/75477750
  2. https://zhuanlan.zhihu.com/p/158751106
posted @ 2023-06-29 17:13  芋圆院长  阅读(74)  评论(0编辑  收藏  举报