利用matplotlib画用于机器学习的K线图练手任务

前个阶段完成了利用matplotlib画用于机器学习的K线图练手任务,

代码如下:

# -*- coding: utf-8 -*-
__author__='chen shaowu'


import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import os

# Module-level tuning constants (they can be rebound at runtime by set_value):
#   X  - look-back length used by count_sigma
#   Y  - number of following days scanned by judge1 / judge2
#   Z1 - multiplier applied to sigma for the "large move" threshold in judge
#   Z2 - multiplier applied to sigma for the "small move" threshold in judge
X, Y, Z1, Z2 = 14, 20, 6, 3

# Setter for the module-level tuning constants.
def set_value(x=14, y=20, z1=6, z2=3):
    """Rebind the module globals X, Y, Z1 and Z2.

    Each argument replaces the global of the same (upper-cased) name.
    Calling the function without arguments restores 14, 20, 6 and 3.
    """
    global X, Y, Z1, Z2
    X = x
    Y = y
    Z1 = z1
    Z2 = z2

    
# Compute sigma (mean capped daily range ratio) for the window ending at index.
def count_sigma(data1, index):
    """Return the mean capped daily range ratio of the X days ending at index.

    For every day d in the window the ratio is

        max(|high - low|, |high - prev_close|, |prev_close - low|) / prev_close

    capped at 0.191, where prev_close is the close of day d - 1.

    Args:
        data1: sequence of rows; row[1] is read as the close, row[3] as the
            high and row[4] as the low.
        index: position in data1 of the last day of the window (that day is
            included).

    Returns:
        The sum of the X capped ratios divided by X (the module global X).

    Raises:
        IndexError: if fewer than X rows precede ``index``. Without this
            check the negative offsets would wrap around and silently read
            rows from the end of data1.
        ZeroDivisionError: if X is 0 or a previous close is 0.
    """
    if index < X:
        raise IndexError(
            'count_sigma needs %d rows before index %d' % (X, index))

    total = 0  # running sum of the capped ratios
    # Exactly X terms, so that dividing by X below yields a true mean.
    for back in range(X):
        day = index - back
        high = data1[day][3]
        low = data1[day][4]
        prev_close = data1[day - 1][1]  # close of the previous day
        ratio = max(abs(high - low),
                    abs(high - prev_close),
                    abs(prev_close - low)) / prev_close
        total += min(ratio, 0.191)  # cap a single day's contribution
    return total / X

# Detect a large rise or fall after the last day of the plot.
def judge1(data1, index, greaterate, lendata1):
    """Scan the days after index for a large move in the closing price.

    Each of the next Y days (module global Y) is compared with the close at
    ``index`` (row[1]); the relative change is (close - base) / base.

    Args:
        data1: sequence of rows whose element [1] is the closing price.
        index: position of the last day shown in the plot.
        greaterate: threshold the absolute relative change must reach.
        lendata1: position at which the scan runs out of data.

    Returns:
        2 on the first day whose rise reaches the threshold, -2 on the first
        day whose fall reaches it, -3 if position lendata1 is reached before
        either happens, and 0 if all Y days pass without such a move.
    """
    for offset in range(1, Y + 1):
        day = index + offset
        if day == lendata1:
            # Ran out of data before Y days were inspected: do not plot.
            return -3

        base_close = data1[index][1]  # close of the last day in the plot
        change = (data1[day][1] - base_close) / base_close
        if change > 0 and change >= greaterate:
            return 2
        if change < 0 and abs(change) >= greaterate:
            return -2

    return 0


# With no large move present, detect a small rise or fall.
def judge2(data1, index, smallrate, lendata1):
    """Scan the days after index for a small move in the closing price.

    Each of the next Y days (module global Y) is compared with the close at
    ``index`` (row[1]); the relative change is (close - base) / base.

    Args:
        data1: sequence of rows whose element [1] is the closing price.
        index: position of the last day shown in the plot.
        smallrate: threshold the absolute relative change must reach.
        lendata1: position at which the scan runs out of data.

    Returns:
        1 on the first day whose rise reaches the threshold, -1 on the first
        day whose fall reaches it, -3 if position lendata1 is reached before
        either happens, and 0 if all Y days pass without such a move.
    """
    for offset in range(1, Y + 1):
        day = index + offset
        if day == lendata1:
            # Ran out of data before Y days were inspected: do not plot.
            return -3

        base_close = data1[index][1]  # close of the last day of the plot
        change = (data1[day][1] - base_close) / base_close
        if change > 0 and change >= smallrate:
            return 1
        if change < 0 and abs(change) >= smallrate:
            return -1

    return 0
    
    
    
# Classification function.
def judge(data1, index, lendata1):
    """Label the window whose final day is at ``index``.

    Label meaning: 0 flat, 1 small rise, 2 large rise, -1 small fall,
    -2 large fall, -3 not enough data after the window.

    The large-move threshold is sigma * Z1 and the small-move threshold is
    sigma * Z2, where sigma = count_sigma(data1, index). Large moves are
    checked first (judge1); small moves (judge2) only when none was found.

    Args:
        data1: sequence of price rows.
        index: position of the final day of the window.
        lendata1: length of the data, used to detect that too few days follow.

    Returns:
        One of 2, 1, 0, -1, -2 or -3.
    """
    if index == lendata1 - 1:
        # The window ends on the very last row: nothing follows it.
        return -3

    sigma = count_sigma(data1, index)

    large = judge1(data1, index, sigma * Z1, lendata1)
    if large == -3:
        return -3  # too few days after the window
    if abs(large) == 2:
        return large  # a large move settles the label
    if large == 0:
        # No large move: the result is 1, 0, -1 (or -3) from the small check.
        return judge2(data1, index, sigma * Z2, lendata1)


def scaler(data1, x, y, N):
    """Return the [ymin, ymax] price-axis limits for the window data1[x:y].

    The limits start as centre -/+ N * sigma and are widened where needed so
    that the lowest low and the highest high of the window stay inside.
    ``centre`` is the mean of (high + low) / 2 over the window and ``sigma``
    is count_sigma(data1, y - 1).

    NOTE: this file defines ``scaler`` a second time further down; the later
    definition is the one in effect once the module has been loaded.

    Args:
        data1: sequence of rows; row[3] is read as the high, row[4] as the low.
        x: start of the window (inclusive).
        y: end of the window (exclusive).
        N: stretch factor applied to sigma.

    Returns:
        A two-element list [ymin, ymax].

    Raises:
        ValueError: if the window data1[x:y] is empty.
    """
    window = data1[x:y]
    sigma = count_sigma(data1, y - 1)
    lowest = min(row[4] for row in window)
    highest = max(row[3] for row in window)

    mid_total = 0
    for row in window:
        mid_total += (row[3] + row[4]) / 2
    # Divide by the real window length instead of a hard-coded 60 so the
    # centre stays a true mean for any window size (identical for 60 rows).
    centre = mid_total / len(window)

    return [min(centre - N * sigma, lowest), max(centre + N * sigma, highest)]


# Find the y-axis stretch factor N.
def find_N(data1, lim=0.85):
    """Return the smallest integer N whose band covers enough windows.

    Windows are the 60-row slices data1[s:s + 60] for s = 0, 5, 10, ... while
    s + 60 < len(data1). For each window, ``centre`` is the mean of
    (high + low) / 2 and ``sigma`` is count_sigma(data1, s + 59). A window is
    covered by N when centre - N * sigma <= lowest low and
    centre + N * sigma >= highest high. N starts at 0 and grows by 1 until
    the covered share of windows is at least ``lim``. The result is printed
    and returned.

    The covered windows are counted afresh for every N and divided by the
    number of windows actually examined, so the share is a true fraction.

    Args:
        data1: sequence of rows; row[3] is read as the high, row[4] as the low.
        lim: required share of covered windows (default 0.85). For lim <= 0
            the loop never runs and -1 is returned.

    Returns:
        The stretch factor N (an int).

    Raises:
        ValueError: if data1 yields no window (len(data1) <= 60), or if so
            many windows can never be covered (sigma is 0 while prices vary)
            that ``lim`` is unreachable; both cases would otherwise divide
            by zero or loop forever.
    """
    # Per-window statistics do not depend on N, so compute them only once.
    windows = []
    start, stop = 0, 60
    while stop < len(data1):
        rows = data1[start:stop]
        sigma = count_sigma(data1, stop - 1)

        # Centre price of the window.
        mid_total = 0
        for row in rows:
            mid_total += (row[3] + row[4]) / 2
        centre = mid_total / len(rows)

        lowest = min(row[4] for row in rows)
        highest = max(row[3] for row in rows)
        windows.append((centre, sigma, lowest, highest))

        start += 5
        stop += 5

    if not windows:
        raise ValueError('find_N needs more than 60 rows, got %d' % len(data1))
    total = len(windows)

    # A window with sigma == 0 is covered only if it is completely flat; if
    # too many windows are unreachable, no N can ever satisfy lim.
    reachable = sum(
        1 for centre, sigma, lowest, highest in windows
        if sigma > 0 or (centre <= lowest and centre >= highest))
    if reachable / total < lim:
        raise ValueError(
            'only %d of %d windows can ever be covered; lim=%r is unreachable'
            % (reachable, total, lim))

    N = -1
    covered = 0
    while covered / total < lim:
        N += 1
        covered = sum(
            1 for centre, sigma, lowest, highest in windows
            if centre - N * sigma <= lowest and centre + N * sigma >= highest)
    print(N)
    return N


def scaler(data1, x, y, N):
    """Return the [ymin, ymax] price-axis limits for the window data1[x:y].

    The limits start as centre -/+ N * sigma and are widened where needed so
    that the lowest low and the highest high of the window stay inside.
    ``centre`` is the mean of (high + low) / 2 over the window and ``sigma``
    is count_sigma(data1, y - 1).

    NOTE: a function with this name is also defined earlier in this file;
    this later definition replaces it when the module is loaded.

    Args:
        data1: sequence of rows; row[3] is read as the high, row[4] as the low.
        x: start of the window (inclusive).
        y: end of the window (exclusive).
        N: stretch factor applied to sigma.

    Returns:
        A two-element list [ymin, ymax].

    Raises:
        ValueError: if the window data1[x:y] is empty.
    """
    window = data1[x:y]
    sigma = count_sigma(data1, y - 1)
    lowest = min(row[4] for row in window)
    highest = max(row[3] for row in window)

    mid_total = 0
    for row in window:
        mid_total += (row[3] + row[4]) / 2
    # Divide by the real window length instead of a hard-coded 60 so the
    # centre stays a true mean for any window size (identical for 60 rows).
    centre = mid_total / len(window)

    return [min(centre - N * sigma, lowest), max(centre + N * sigma, highest)]
    
               
   



'''def scaler(data1,x,y,index):
    global X
    N=3
    data=data1[x:y]
    
    sigma=count_sigma(data1,index)
   
    min0=min(x[4] for x in data)
    max0=max(x[3] for x in data)

    sum0=0
    for i in data:
        sum0+=(i[3]+i[4])/2
    m=sum0/60

    ymin=min(m-N*sigma,min0)
    ymax=max(m+N*sigma,max0)

    return [min0,max0,ymin,ymax]#只能出现两种结果:1、刚刚好铺满;2、上层或下层留白'''

   

# Main routine and its private helpers.
def _load_series(path):
    """Parse the CSV at path into three row lists (one per stock).

    A data line holds a date followed by one, two or three groups of five
    numbers (6, 11 or 16 fields once trailing commas are stripped). Every
    output row is [date, v1, v2, v3, v4, v5] with the numbers as floats; the
    drawing code reads v1..v4 as close, open, high, low and v5 as the series
    plotted in the background. Lines containing the field '日期' (header
    lines) and lines of any other width are skipped.
    """
    dataC, dataD, dataS = [], [], []
    # NOTE(review): the file is opened with the platform default encoding,
    # exactly as before - confirm the CSV's real encoding.
    with open(path) as file:
        for line in file:
            f = line.strip('\n').strip(',').split(',')
            # Reject header lines for EVERY accepted width; the former
            # "a or b or c and d" test only guarded the 16-field case.
            if '日期' in f:
                continue
            width = len(f)
            if width not in (6, 11, 16):
                continue
            dataC.append([f[0]] + [float(v) for v in f[1:6]])
            if width >= 11:
                dataD.append([f[0]] + [float(v) for v in f[6:11]])
            if width == 16:
                dataS.append([f[0]] + [float(v) for v in f[11:16]])
    return dataC, dataD, dataS


def _draw_window(ax, data1, x, y, N):
    """Draw rows data1[x:y] as a 60-column candlestick chart.

    ``ax`` carries the clipped background series (row[5]); a twin axis
    carries the candles. The pyplot state machine is used on purpose so the
    limit and axis calls hit the same axes as they always did.
    """
    window = data1[x:y]
    ax.fill_between(range(60), np.clip([a[5] for a in window], 0, 15), 0,
                    color='blue')
    plt.ylim(0, 15)
    plt.xlim(0.5, 59.5)  # together with subplots_adjust: 2 + 60 + 2 pixels across
    plt.axis('off')

    ax2 = ax.twinx()
    ymin, ymax = scaler(data1, x, y, N)

    for i, a in enumerate(window):
        c, o, h, l = a[1:5]
        if c >= o:  # rising day: dark red wick, bright red body
            low, high = o, c
            color1, color2 = '#800000', '#ff0000'
        else:  # falling day: dark green wick, bright green body
            low, high = c, o
            color1, color2 = '#008000', '#00ff00'

        ax2.bar(i, h - l, bottom=l, width=1, color=color1)
        ax2.bar(i, high - low, bottom=low, width=1, color=color2)

        if c == o == h == l:
            # All four prices equal: draw a bar 1/60 of the axis high so the
            # day is still visible.
            ax2.bar(i, (ymax - ymin) / 60, bottom=l, width=1, color=color1)

    plt.axis('off')
    plt.ylim((ymin, ymax))
    # Vertical 2 + 60 + 2 pixel layout is achieved here.
    plt.subplots_adjust(right=62/64, left=2/64, bottom=2/64, top=62/64)


def main():
    """Read the stock CSV and save one labelled 64x64 K-line image per window.

    For each of the three series a 60-row window slides forward 5 rows at a
    time. Windows whose last row has row[5] == 0, or where more than two
    thirds of the rows have row[5] == 0, are skipped. Every other window is
    labelled by judge() and saved as '<k><label>.png' inside
    '<series name>\\<label>' under the figure directory. The first window
    that judge() marks -3 (not enough following data) ends the series.

    Side effects: changes the process working directory, creates folders and
    writes PNG files.
    """
    # Go to the directory that holds the data. Raw strings keep the exact
    # same path while avoiding invalid escape sequences such as '\S'.
    os.chdir(r'C:\Users\Steve\Desktop\python程序\量化.K线')
    dataC, dataD, dataS = _load_series('股票数据.csv')

    figure_dir = r'C:\Users\Steve\Desktop\python程序\量化.K线\figure'
    labels = {2: '++', 1: '+', 0: 'flat', -1: '-', -2: '--'}
    dic = {0: 'ChinaBanktest', 1: 'DaHuatest', 2: 'ShangRongtest'}
    dpi = 1  # a 64 x 64 inch figure at 1 dpi gives a 64 x 64 pixel image

    for j, data1 in enumerate((dataC, dataD, dataS)):
        if len(data1) <= 60:
            # No window can be followed by data to label it, so nothing
            # would ever be saved; skipping also keeps find_N away from an
            # empty or too-short series.
            continue

        N = find_N(data1)  # y-axis stretch factor

        x, y, k = 0, 60, 0
        while y <= len(data1):
            window = data1[x:y]
            idle = sum(1 for row in window if row[5] == 0)
            if not (window[-1][5] == 0 or idle > len(window) * 2 / 3):
                # y - 1 is the true index of the window's last row;
                # len(data1) lets judge detect that too few rows follow.
                mark = judge(data1, y - 1, len(data1))
                if mark == -3:
                    break  # not enough data after this window: stop here

                label = labels.get(mark)
                if label is not None:
                    fig, ax = plt.subplots(figsize=(64, 64))
                    try:
                        _draw_window(ax, data1, x, y, N)
                        os.chdir(figure_dir)
                        folder = '%s\\%s' % (dic[j], label)
                        os.makedirs(folder, exist_ok=True)
                        os.chdir(folder)
                        plt.savefig('%d%s.png' % (k, label), format='png',
                                    dpi=dpi)
                    finally:
                        # Always release the figure, even if saving fails.
                        plt.close(fig)

            x += 5
            y += 5
            k += 1

# Run the pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
    

'''疑问:
scaler:用 min/max 做 y 轴的底或顶?那么要 N 何用?
'''

 

posted @ 2017-11-30 15:24  笨鸟不走  阅读(1028)  评论(0)  编辑  收藏  举报