绘制双子图含图例完整代码

import matplotlib.pyplot as plt
from matplotlib import ticker
import seaborn as sns
import pandas as pd
import numpy as np

# Plot reward curves from two result files side by side: one subplot per CSV,
# each showing the mean "avg_reward" curve with a shaded min-max band.

# Line/band colours, one per subplot (left, right).
COLORS = ['orange', 'purple']

# Result files, one per subplot (left, right).
CSV_FILES = ['PongNoFrameskip-v0_0_bcq.csv', 'PongNoFrameskip-v0_0_bcq_0.csv']

FONT_SIZE = 10        # title and tick-label font size
AXIS_LABEL_SIZE = 20  # x/y axis-label font size


def load_rewards(csv_path):
    """Load one result CSV and aggregate its reward columns point-wise.

    Every column whose name contains "avg_reward" is treated as one run; the
    runs are combined element-wise along the rows of the file.

    Args:
        csv_path: path of the CSV file; it must contain a "step" column.

    Returns:
        A tuple ``(steps, mean, lo, hi)`` of 1-D numpy arrays: the "step"
        column, and the per-row mean, minimum and maximum over the reward
        columns.

    Raises:
        ValueError: if no column name contains "avg_reward".
        KeyError: if the file has no "step" column.
    """
    frame = pd.read_csv(csv_path)  # read once (the original read each file twice)
    reward_cols = [col for col in frame.columns if "avg_reward" in col]
    if not reward_cols:
        raise ValueError(f"no column containing 'avg_reward' in {csv_path!r}")
    # Shape (n_runs, n_rows): one row per reward column, reduced over axis 0.
    # Built fresh per call, so runs from different files never get mixed.
    runs = frame[reward_cols].to_numpy().T
    steps = frame["step"].to_numpy()
    return steps, runs.mean(axis=0), runs.min(axis=0), runs.max(axis=0)


def plot_panel(ax, csv_path, *, title, color, line_label='RCSCA_Net',
               band_label='bcq', show_ylabel=False):
    """Draw one subplot: the mean reward line plus a shaded min-max band.

    Args:
        ax: matplotlib Axes to draw on.
        csv_path: result file passed to ``load_rewards``.
        title: subplot title.
        color: colour shared by the mean line and its band.
        line_label: legend label of the mean line.
        band_label: legend label of the min-max band.
        show_ylabel: whether to label the y axis (only the left panel does).
    """
    steps, mean, lo, hi = load_rewards(csv_path)

    # Keyword x/y: newer seaborn releases reject positional data vectors.
    sns.lineplot(x=steps, y=mean, ax=ax, legend=True, label=line_label,
                 lw=3, color=color)
    # Draw on `ax` explicitly rather than the pyplot "current axes", and give
    # the band the same colour as its line.
    ax.fill_between(steps, hi, lo, alpha=0.2, color=color, label=band_label)

    ax.set_title(title, fontsize=FONT_SIZE)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_xlabel('Time steps', fontdict={'size': AXIS_LABEL_SIZE})
    ax.tick_params(labelsize=FONT_SIZE)
    if show_ylabel:
        ax.set_ylabel('Average Return', fontdict={'size': AXIS_LABEL_SIZE})
    ax.grid(True)
    ax.legend()


def main():
    """Build the two-panel reward figure and display it."""
    plt.style.use('ggplot')
    # One row, two columns. The original used subplot(121) for BOTH panels,
    # which stacked the second curve onto the first axes.
    _, (ax_left, ax_right) = plt.subplots(1, 2)
    plot_panel(ax_left, CSV_FILES[0], title="BCQ", color=COLORS[0],
               show_ylabel=True)
    plot_panel(ax_right, CSV_FILES[1], title="bcq", color=COLORS[1])
    plt.show()


if __name__ == "__main__":
    main()
 

 

 
 
posted @ 2022-02-27 19:48  呦呦南山  阅读(96)  评论(0)  编辑  收藏  举报