使用TuShare+plotly实现蜡烛K线图的绘制

1,084 阅读3分钟

使用TuShare+plotly实现蜡烛K线图的绘制

个人Id: 521333

对于股票分析来说,K线图是最基础的元素,本文的目标是使用TuShare提供的股票数据加上plotly绘制蜡烛K线图,在这个过程中比较大的坑就是直接使用plotly的API绘制蜡烛K线图会出现K线图不连续的情况,因为plotly的API不会自动处理交易日不连续的情况。

除此之外Tushare官方提供了python版的SDK,可以直接使用。

初始化Tushare的pro接口

首先需要初始化Tushare的Pro接口,此处填入你在tushare的token接口即可,tushare地址:tushare数据
token接口需要自己去申请。

pro = ts.pro_api('此处填入你的tushare的token接口呀~')

获取日线行情

初始化Pro接口之后,可直接使用daily接口获取日线行情数据,对于daily接口,返回如下参数数据:

  • ts_code:股票代码
  • trade_date:交易日期
  • open:开盘价
  • high:最高价
  • low:最低价
  • close:收盘价
  • pre_close:昨收价
  • change:涨跌额
  • pct_chg:涨跌幅
  • vol:成交量 (手)
  • amount:成交额 (千元)

在获取数据之后,需要将trade_date转换成日期时间对象,然后按照日期进行升序排列

ts_code = '002594.SZ'
start_date = '20220601'
end_date = '20220620'
df = pro.daily(ts_code=ts_code, start_date=start_date, end_date=end_date)
# 将trade_date转换为日期时间对象
df['trade_date'] = pd.to_datetime(df['trade_date'])
# 按日期升序
df = df.sort_values(by='trade_date')

获取交易日历

为了解决K线图不连续的问题,还需要拿到交易日历,通过交易日历数据获取交易日。

在Tushare的SDK中,trade_cal接口可以获取交易日历数据,exchange可传入交易所代码(可以为空),start_date是交易日历数据的开始日期,end_date是交易日历数据的截止日期。返回如下参数数据:

  • exchange:交易所
  • cal_dat:日历日期
  • is_open:是否交易 0 休市 1 交易
  • pretrade_date:上一个交易日

因此可以通过is_open来判断今天是不是交易日。

def get_trading_date(_start_date, _end_date):
    trading_date_list = pro.trade_cal(exchange='', start_date=_start_date, end_date=_end_date).values.tolist()
    trading_dates = []
    for i in range(len(trading_date_list)):
        if trading_date_list[i][2] == 1:
            cur_date = trading_date_list[i][1]
            trading_dates.append(datetime.datetime.strptime(cur_date, '%Y%m%d'))
    return trading_dates

在获取交易日之后,可以获取不交易的日期:

dt_all = pd.date_range(start=df['trade_date'].iloc[0], end=df['trade_date'].iloc[-1])
dt_breaks = list(set(dt_all) - set(trade_date))

K线绘制

def plot_cand_volume(data, dt_breaks):

    fig = make_subplots(rows=2, cols=1, shared_xaxes=True,
                        vertical_spacing=0.03, subplot_titles=('', '成交量'),
                        row_width=[0.2, 0.7])
    fig.add_trace(go.Candlestick(x=data['trade_date'],
                                 open=data['open'],
                                 high=data['high'],
                                 low=data['low'],
                                 close=data['close'],
                                 increasing_line_color= '#FF0033',
                                 decreasing_line_color= '#009966'),
                  row=1, col=1
                  )
    fig.add_trace(go.Bar(x=data['trade_date'], y=data['vol'], showlegend=False), row=2, col=1)

    fig.update_xaxes(
        title_text='date',
        rangeslider_visible=True, # 下方滑动条缩放
        rangeselector=dict(
            # 增加固定范围选择
            buttons=list([
                dict(count=1, label='1M', step='month', stepmode='backward'),
                dict(count=6, label='6M', step='month', stepmode='backward'),
                dict(count=1, label='1Y', step='year', stepmode='backward'),
                dict(count=1, label='YTD', step='year', stepmode='todate'),
                dict(step='all')])))
    # Do not show OHLC's rangeslider plot
    fig.update(layout_xaxis_rangeslider_visible=False)

    # 去除休市的日期,保持连续
    fig.update_xaxes(rangebreaks=[dict(values=dt_breaks)])
    return fig

最终效果

结果数据.jpg

完整代码

import tushare as ts
import datetime
import pandas as pd

import plotly.graph_objects as go
from plotly.subplots import make_subplots

pro = ts.pro_api('此处填入你的tushare的token接口呀~')


def get_trading_date(_start_date, _end_date):
    trading_date_list = pro.trade_cal(exchange='', start_date=_start_date, end_date=_end_date).values.tolist()
    trading_dates = []
    for i in range(len(trading_date_list)):
        if trading_date_list[i][2] == 1:
            cur_date = trading_date_list[i][1]
            trading_dates.append(datetime.datetime.strptime(cur_date, '%Y%m%d'))
    return trading_dates


def plot_cand_volume(data, dt_breaks):

    fig = make_subplots(rows=2, cols=1, shared_xaxes=True,
                        vertical_spacing=0.03, subplot_titles=('', '成交量'),
                        row_width=[0.2, 0.7])
    fig.add_trace(go.Candlestick(x=data['trade_date'],
                                 open=data['open'],
                                 high=data['high'],
                                 low=data['low'],
                                 close=data['close'],
                                 increasing_line_color= '#FF0033',
                                 decreasing_line_color= '#009966'),
                  row=1, col=1
                  )
    fig.add_trace(go.Bar(x=data['trade_date'], y=data['vol'], showlegend=False), row=2, col=1)

    fig.update_xaxes(
        title_text='date',
        rangeslider_visible=True, # 下方滑动条缩放
        rangeselector=dict(
            # 增加固定范围选择
            buttons=list([
                dict(count=1, label='1M', step='month', stepmode='backward'),
                dict(count=6, label='6M', step='month', stepmode='backward'),
                dict(count=1, label='1Y', step='year', stepmode='backward'),
                dict(count=1, label='YTD', step='year', stepmode='todate'),
                dict(step='all')])))
    # Do not show OHLC's rangeslider plot
    fig.update(layout_xaxis_rangeslider_visible=False)

    # 去除休市的日期,保持连续
    fig.update_xaxes(rangebreaks=[dict(values=dt_breaks)])
    return fig

if __name__ == '__main__':
    ts_code = '002594.SZ'
    start_date = '20220601'
    end_date = '20220620'
    trade_date = get_trading_date(start_date, end_date)
    df = pro.daily(ts_code=ts_code, start_date=start_date, end_date=end_date)
    # 将trade_date转换为日期时间对象
    df['trade_date'] = pd.to_datetime(df['trade_date'])
    # 按日期升序
    df = df.sort_values(by='trade_date')
    dt_all = pd.date_range(start=df['trade_date'].iloc[0], end=df['trade_date'].iloc[-1])
    dt_breaks = list(set(dt_all) - set(trade_date))
    # 绘制 复杂k线图
    fig = plot_cand_volume(df, dt_breaks)
    fig.show()