• 博客园logo
  • 会员
  • 周边
  • 众包
  • 新闻
  • 博问
  • 闪存
  • 赞助商
  • Chat2DB
    • 搜索
      所有博客
    • 搜索
      当前博客
  • 写随笔 我的博客 短消息 简洁模式
    用户头像
    我的博客 我的园子 账号设置 会员中心 简洁模式 ... 退出登录
    注册 登录
华东 博客
17年国科大博士毕业,曾就职于三星电子,清华博后,目前在某大模型创业公司工作,研究方向大模型、智能体 新浪博客: http://blog.sina.com.cn/u/2463286753
博客园    首页    新随笔    联系   管理    订阅  订阅
Matplotlib的subplot画图, 共享colorbar

Python画图,利用Matplotlib中subplot画3*3的heatmap图,所有热力图共享一个colorbar。

import numpy as np
import matplotlib
matplotlib.use('AGG')
import matplotlib.pyplot as plt
import csv
import pandas as pd  

fig, axs = plt.subplots(3, 3, sharex=True, sharey=True, constrained_layout=True, figsize=(6, 6))
fig.subplots_adjust(hspace=0.1)
datasets = ['envi', 'enfr','enge']
types = ['enc', 'enc-dec' ,'dec']
Datasets = ['En-Vi', 'En-Fr','En-Ge']

def draw_map(ax,xlabels, ylabels, tt, attns, index):   #画热力图
    im = ax.imshow(attns,interpolation='none',  cmap='Blues', vmin=0, vmax=1)
    ax.set_title(tt, fontsize=10)
style='italic' )
    return im
xlabels = '0 1 2 3 4 5'.split()
ylabels = '5 4 3 2 1 0'.split()

for i, data in enumerate(datasets):
    datafile = '{}.csv'.format(data) #读取数据
    data = pd.read_csv(datafile, delimiter='\t', header=None)
    for j, tt in enumerate(types):
        A = data.values[:,j*6:(j+1)*6] #取出所需热力图的数据
        ax = axs[i, j]
        im = draw_map(ax, xlabels,ylabels, '%s (%s)'%(Datasets[i],tt), A, j)

                
plt.xticks(range(len(xlabels)),xlabels)    #设置横坐标label
plt.yticks(range(len(ylabels)),ylabels)    #设置纵坐标label

cb_ax = fig.add_axes([1.0, 0.1, 0.02, 0.8]) #设置colarbar位置
cbar = fig.colorbar(im, cax=cb_ax)     #共享colorbar
plt.tight_layout()
plt.show()
fig.savefig('weights_visualization.pdf', bbox_inches='tight', dpi=500)

 

画图结果:

 

posted on 2021-01-22 23:19  华东博客  阅读(5649)  评论(0)    收藏  举报
刷新页面返回顶部
博客园  ©  2004-2025
浙公网安备 33010602011771号 浙ICP备2021040463号-3