% plotroc.m

function out1 = plotroc(varargin)
%PLOTROC Plot receiver operating characteristic.
%
% <a href="matlab:doc plotroc">plotroc</a>(targets,outputs) takes target data in 1-of-N form (each column
% vector is all zeros with a single 1 indicating the class number), and
% output data and generates a receiver operating characteristic plot.
%
% The best classifications will show the receiver operating line hugging
% the left and top sides of the plot's axes.
%
% <a href="matlab:doc plotroc">plotroc</a>(targets1,outputs1,'name1',targets2,outputs2,'name2',...)
% generates a variable number of ROC plots in one figure.
%
% Here a pattern recognition network is trained and its ROC plotted:
%
%   [x,t] = <a href="matlab:doc simpleclass_dataset">simpleclass_dataset</a>;
%   net = <a href="matlab:doc patternnet">patternnet</a>(10);
%   net = <a href="matlab:doc train">train</a>(net,x,t);
%   y = net(x);
%   <a href="matlab:doc plotroc">plotroc</a>(t,y);
%
% See also roc, plotconfusion, ploterrhist, plotregression.

% Copyright 2010-2012 The MathWorks, Inc.

%% =======================================================
%  BOILERPLATE_START
%  This code is the same for all Transfer Functions.

  % Plot metadata is built once per MATLAB session and cached.
  persistent INFO;
  if isempty(INFO), INFO = get_info; end

  % No arguments: return (or bring forward) the training-time figure, if any.
  if nargin == 0
    fig = nnplots.find_training_plot(mfilename);
    if nargout > 0
      out1 = fig;
    elseif ~isempty(fig)
      figure(fig);
    end
    return;
  end

  in1 = varargin{1};
  if ischar(in1)
    % Command-style call: dispatch on the leading string argument.
    switch in1
      case 'info'
        out1 = INFO;
      case 'suitable'
        [args,param] = nnparam.extract_param(varargin,INFO.defaultParam);
        [net,tr,signals] = deal(args{2:end});
        update_args = standard_args(net,tr,signals);
        unsuitable = unsuitable_to_plot(param,update_args{:});
        if nargout > 0
          out1 = unsuitable;
        elseif ~isempty(unsuitable)
          display_unsuitable(unsuitable);
        end
      case 'training_suitable'
        [net,tr,signals,param] = deal(varargin{2:end});
        % training_args takes (net,tr,data) only; param is used below.
        update_args = training_args(net,tr,signals);
        unsuitable = unsuitable_to_plot(param,update_args{:});
        if nargout > 0
          out1 = unsuitable;
        elseif ~isempty(unsuitable)
          display_unsuitable(unsuitable);
        end
      case 'training'
        [net,tr,signals,param] = deal(varargin{2:end});
        update_args = training_args(net,tr,signals);
        fig = nnplots.find_training_plot(mfilename);
        if isempty(fig)
          fig = figure('visible','off','tag',['TRAINING_' upper(mfilename)]);
          % setup_figure only installs the close handler on figures it
          % creates itself, so install it on this one explicitly. It routes
          % window closes through 'close_request', which defers deletion
          % while an update is in progress.
          set(fig,'CloseRequestFcn',[mfilename '(''close_request'')']);
          plotData = setup_figure(fig,INFO,true);
        else
          plotData = get(fig,'userdata');
        end
        set_busy(fig);
        unsuitable = unsuitable_to_plot(param,update_args{:});
        if isempty(unsuitable)
          set(0,'CurrentFigure',fig);
          plotData = update_plot(param,fig,plotData,update_args{:});
          update_training_title(fig,INFO,tr)
          nnplots.enable_plot(plotData);
        else
          nnplots.disable_plot(plotData,unsuitable);
        end
        % unset_busy returns [] if a close was requested while busy.
        fig = unset_busy(fig,plotData);
        if nargout > 0, out1 = fig; end
      case 'close_request'
        fig = nnplots.find_training_plot(mfilename);
        if ~isempty(fig),close_request(fig); end
      case 'check_param'
        out1 = ''; % TODO
      otherwise
        % Treat the string as a field path into INFO (e.g. 'name').
        % NOTE(review): eval on a caller-supplied string; INFO.(in1) would
        % be safer if nested field paths are not required.
        try
          out1 = eval(['INFO.' in1]);
        catch
          nnerr.throw(['Unrecognized first argument: ''' in1 ''''])
        end
    end
  else
    % Standard call: plotroc(t,y) or plotroc(t1,y1,'name1',...).
    [args,param] = nnparam.extract_param(varargin,INFO.defaultParam);
    update_args = standard_args(args{:});
    if ischar(update_args)
      nnerr.throw(update_args);
    end
    [plotData,fig] = setup_figure([],INFO,false);
    unsuitable = unsuitable_to_plot(param,update_args{:});
    if isempty(unsuitable)
      plotData = update_plot(param,fig,plotData,update_args{:});
      nnplots.enable_plot(plotData);
    else
      nnplots.disable_plot(plotData,unsuitable);
    end
    set(fig,'visible','on');
    drawnow;
    if nargout > 0, out1 = fig; end
  end
end

function display_unsuitable(unsuitable)
% Print each "unsuitable to plot" message on its own line.
% unsuitable_to_plot returns a single char row, which cannot be
% brace-indexed, so wrap char input in a cell before iterating.
  if ischar(unsuitable)
    unsuitable = {unsuitable};
  end
  for i = 1:numel(unsuitable)
    disp(unsuitable{i});
  end
end

function set_busy(fig)
% Flag fig as busy by storing the string 'BUSY' in its userdata.
% While userdata is a char, close_request defers the close (it writes
% 'CLOSE'), and unset_busy later puts the plotData struct back.
  busyFlag = 'BUSY';
  set(fig,'userdata',busyFlag);
end

function close_request(fig)
% Handle a close request for fig.
% If userdata is a struct (idle), delete the figure now. If it is a
% char (busy), record 'CLOSE' so unset_busy deletes it once the update
% finishes. Either way, flush pending graphics events.
  if ~ischar(get(fig,'userdata'))
    delete(fig);
  else
    set(fig,'userdata','CLOSE');
  end
  drawnow;
end

function fig = unset_busy(fig,plotData)
% Leave the busy state for fig.
% If a close was requested while busy (userdata == 'CLOSE'), delete the
% figure and return []. Otherwise restore plotData as the figure's
% userdata and return fig unchanged.
  state = get(fig,'userdata');
  closeRequested = ischar(state) && strcmp(state,'CLOSE');
  if ~closeRequested
    set(fig,'userdata',plotData);
  else
    delete(fig);
    fig = [];
  end
  drawnow;
end

function tag = new_tag
% Return the first tag of the form <FILENAME>1, <FILENAME>2, ... for which
% nnplots.find_plot finds no existing figure.
  prefix = upper(mfilename);
  index = 1;
  tag = [prefix num2str(index)];
  while ~isempty(nnplots.find_plot(tag))
    index = index + 1;
    tag = [prefix num2str(index)];
  end
end

function [plotData,fig] = setup_figure(fig,info,isTraining)
% Prepare a figure for this plot and attach fresh plotData to it.
%   fig        - figure to use, or [] to choose one (see acquire_figure).
%   info       - plot info struct from get_info; info.name goes in titles.
%   isTraining - true for the training-time figure.
% Returns plotData (also stored as the figure's userdata) and the figure.
  titleFontSize = nnplots.title_font_size;
  if isempty(fig)
    fig = acquire_figure(isTraining);
  end
  set(0,'CurrentFigure',fig);

  % Suppress the SetPosition warning only while the plot is being set up.
  prevWarning = warning('off','MATLAB:Figure:SetPosition');
  plotData = setup_plot(fig);
  warning(prevWarning);

  if ~isTraining
    set(fig,'nextplot','replace');
    set(fig,'name',[info.name ' (' mfilename ')']);
  else
    set(fig,'nextplot','new');
    update_training_title(fig,info,[]);
  end
  set(fig,'NumberTitle','off','toolbar','none');

  % Full-window text control, stored in plotData.CONTROL.text.
  textProps = {'parent',fig,'style','text', ...
    'units','normalized','position',[0 0 1 1],'fontsize',titleFontSize, ...
    'fontweight','bold','foreground',[0.7 0 0]};
  plotData.CONTROL.text = uicontrol(textProps{:});
  set(fig,'userdata',plotData);
end

function fig = acquire_figure(isTraining)
% Choose a figure when setup_figure was given none.
% Reuses the current figure (cleared and re-tagged) unless there is no
% current figure or its 'nextplot' is 'new'. In that case it creates a
% hidden figure, tagged TRAINING_<FILENAME> for training plots or with a
% fresh new_tag otherwise. New training figures get the close_request
% handler.
  fig = get(0,'CurrentFigure');
  createNew = isempty(fig) || strcmp(get(fig,'nextplot'),'new');
  if ~createNew
    clf(fig);
    % Blank the old tag first so new_tag (presumably a tag lookup) does
    % not count this figure as already using its tag.
    set(fig,'tag','');
    set(fig,'tag',new_tag);
    return;
  end
  if isTraining
    tag = ['TRAINING_' upper(mfilename)];
  else
    tag = new_tag;
  end
  fig = figure('visible','off','tag',tag);
  if isTraining
    set(fig,'CloseRequestFcn',[mfilename '(''close_request'')']);
  end
end

function update_training_title(fig,info,tr)
% Set the window name of the training figure, e.g.
%   'Neural Network Training <info.name> (<file>), Epoch <n>[, <stop>]'.
% With tr empty the epoch is '0' and no stop reason is shown.
  epochText = '0';
  stopText = '';
  if ~isempty(tr)
    epochText = num2str(tr.num_epochs);
    if ~isempty(tr.stop)
      stopText = [', ' tr.stop];
    end
  end
  windowName = ['Neural Network Training ' info.name ' (' mfilename ...
    '), Epoch ' epochText stopText];
  set(fig,'name',windowName);
end

%  BOILERPLATE_END
%% =======================================================

function info = get_info
% Build the plot-function info object for this file via nnfcnPlot,
% naming it 'Receiver Operating Characteristic' (version argument 7.0).
  plotTitle = 'Receiver Operating Characteristic';
  info = nnfcnPlot(mfilename,plotTitle,7.0,[]);
end

function args = training_args(net,tr,data)
% Build {targets, outputs, names} for the training-time plot.
% Network outputs are computed once for all data (nncalc.y). Each enabled
% division ('Training', then 'Validation' and 'Test' when enabled) gets the
% targets combined with that division's mask via gmultiply. An 'All' entry
% with the unmasked targets is appended when at least two divisions are
% present. tr is not used here.
  yall = nncalc.y(net,data.X,data.Xi,data.Ai);
  y = {yall};
  t = {gmultiply(data.train.mask,data.T)};
  names = {'Training'};
  if division_enabled(data.val)
    y = [y {yall}];
    t = [t {gmultiply(data.val.mask,data.T)}];
    names = [names {'Validation'}];
  end
  if division_enabled(data.test)
    y = [y {yall}];
    t = [t {gmultiply(data.test.mask,data.T)}];
    names = [names {'Test'}];
  end
  if length(t) >= 2
    t = [t {data.T}];
    y = [y {yall}];
    names = [names {'All'}];
  end
  args = {t y names};
end

function tf = division_enabled(division)
% True when division.enabled is non-empty and true. Testing only
% ~isempty(...) treated a logical false flag as "enabled", which plotted
% disabled divisions.
  tf = ~isempty(division.enabled) && all(division.enabled(:));
end

function args = standard_args(varargin)
% Normalize user-facing arguments into {t y names} for update_plot.
%   (t,y)                   -> one unnamed signal, names = {''}.
%   (t1,y1,name1,t2,y2,...) -> one signal per (target,output,name) triple.
% Each target/output is converted with nntype.data('format',...).
% For a bad argument count, an error message string is returned instead
% of a cell array (the standard plotroc call checks ischar(args)).
  if nargin < 2
    args = 'Not enough input arguments.';
  elseif (nargin > 2) && (rem(nargin,3) ~= 0)
    args = 'Incorrect number of input arguments.';
  elseif nargin == 2
    % (t,y)
    t = { nntype.data('format',varargin{1}) };
    y = { nntype.data('format',varargin{2}) };
    names = {''};
    args = {t y names};
  else
    % (t1,y1,name1,...)
    % TODO - Check data is consistent
    count = nargin/3;
    t = cell(1,count);
    y = cell(1,count);
    names = cell(1,count);
    for i=1:count
      t{i} = nntype.data('format',varargin{i*3-2});
      y{i} = nntype.data('format',varargin{i*3-1});
      names{i} = varargin{i*3};
    end
    args = {t y names};
  end
end

function plotData = setup_plot(fig)
% Create the initial plotData for a figure. update_plot compares both
% numSignals and numClasses against the incoming data to decide whether
% to rebuild its axes, so both start at 0 (forcing a rebuild on the first
% update). fig is not used here.
  plotData.numSignals = 0;
  plotData.numClasses = 0;
end

function fail = unsuitable_to_plot(param,t,y,names)
% Return '' when the data can be plotted, otherwise a message saying why not.
% Only the first target signal t{1} is examined, in this order: samples,
% then timesteps, then total elements. param, y and names are not used.
  target = t{1};
  if numsamples(target) == 0
    fail = 'The target data has no samples to plot.';
    return;
  end
  if numtimesteps(target) == 0
    fail = 'The target data has no timesteps to plot.';
    return;
  end
  if sum(numelements(target)) == 0
    fail = 'The target data has no elements to plot.';
    return;
  end
  fail = '';
end

function plotData = update_plot(param,fig,plotData,tt,yy,names)
% Draw or refresh one ROC axes per signal in fig.
%   tt, yy - cell arrays (one entry per signal) of targets / outputs, as
%            produced by standard_args or training_args.
%   names  - signal names; an empty names{1} means axes are titled 'ROC'.
% The axes grid is rebuilt only when the number of signals or classes
% differs from what plotData records; otherwise only the curves change.
% param is not used here.

  t = tt{1};
  numSignals = length(names);
  numClasses = size(t{1},1);

  % Guard the numClasses field: plotData may come from a setup_plot that
  % only initialised numSignals.
  needsRebuild = (plotData.numSignals ~= numSignals) || ...
    ~isfield(plotData,'numClasses') || (plotData.numClasses ~= numClasses);
  if needsRebuild
    plotData = build_roc_axes(fig,plotData,numSignals,numClasses,names);
  end

  % Refresh each ROC curve from the current targets/outputs.
  for i=1:numSignals
    y = yy{i}; if iscell(y), y = cell2mat(y); end
    t = tt{i}; if iscell(t), t = cell2mat(t); end
    [tpr,fpr] = roc(t,y);
    % roc may return non-cell results (presumably the single-class case);
    % wrap them so the per-class loop below can index uniformly.
    if ~iscell(tpr)
      tpr = {tpr};
      fpr = {fpr};
    end
    a = plotData.axes(i);
    axisdata = get(a,'userdata');
    for j=1:numClasses
      % Anchor every curve at (0,0) and (1,1).
      set(axisdata.lines(j),'xdata',[0 fpr{j} 1],'ydata',[0 tpr{j} 1]);
    end
  end
end

function plotData = build_roc_axes(fig,plotData,numSignals,numClasses,names)
% Lay out a near-square subplot grid with one ROC axes per signal, create
% one placeholder line per class in each, and center/resize fig.
  set(fig,'nextplot','replace');
  plotData.numSignals = numSignals;
  plotData.numClasses = numClasses;
  plotData.axes = zeros(1,numSignals);
  colors = nncolor.ncolors(numClasses);
  plotcols = ceil(sqrt(numSignals));
  plotrows = ceil(numSignals/plotcols);
  for i=1:numSignals
    a = subplot(plotrows,plotcols,i);
    cla(a)
    set(a,'dataaspectratio',[1 1 1],'xlim',[0 1],'ylim',[0 1]);
    hold(a,'on');
    % Gray diagonal reference line, drawn once per axes (it was previously
    % redrawn once per class).
    line([0 1],[0 1],'linewidth',2,'color',[1 1 1]*0.8,'parent',a);
    axisdata = struct('lines',zeros(1,numClasses));
    for j=1:numClasses
      axisdata.lines(j) = line([0 1],[0 1],'linewidth',2, ...
        'Color',colors(j,:),'parent',a);
    end
    if ~isempty(names{1})
      titleStr = [names{i} ' ROC'];
    else
      titleStr = 'ROC';
    end
    title(a,titleStr);
    xlabel(a,'False Positive Rate');
    ylabel(a,'True Positive Rate');
    plotData.axes(i) = a;
    set(a,'userdata',axisdata);
    % A class legend appears only on the first axes, and only when there
    % is more than one class.
    if (i==1) && (numClasses > 1)
      labels = cell(1,numClasses);
      for k=1:numClasses, labels{k} = ['Class ' num2str(k)]; end
      legend(axisdata.lines,labels{:})
    end
  end
  % Center a 700-pixel-wide window whose height follows the grid's aspect.
  screenSize = get(0,'ScreenSize');
  screenSize = screenSize(3:4);
  windowSize = 700 * [1 (plotrows/plotcols)];
  pos = [(screenSize-windowSize)/2 windowSize];
  set(fig,'position',pos);
end

  

% Posted 2018-07-23 22:07 by 西瓜刀刀刀 (446 views, 0 comments).