matplotlib numpy 绘制心肌模型

110 阅读2分钟

持续创作,加速成长!这是我参与「掘金日新计划 · 10 月更文挑战」的第2天,点击查看活动详情

matplotlib 官网案例

image.png

数据结构

// segment_params:共 3 个大圈(由外到内),每个大圈是「小圈 × 扇区」的二维数组
[
    [
        [],
        ...
    ],
    [
        [],
        ...
    ],
    [
        [],
        ...
    ]
]

mean_data
[] // 长度 16(三个大圈分别有 6、6、4 个分区,每个分区一个数值)

导入依赖

import os
import io
import json
import numpy as np
import base64
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.patheffects as path_effects
from matplotlib.colors import LinearSegmentedColormap
# Switch to the non-interactive 'agg' backend so no GUI is used when this code
# serves a front-end API; when not running behind an API, this line can be
# commented out and the figure displayed directly with plt.show().
plt.switch_backend('agg')

制图

def plot_bullseye(segment_params=None,
                  mean_data=None,
                  img_path='',
                  max=100,
                  min=0,
                  folder=''):
    """Render a bullseye (polar) plot and save it or return it as a data URI.

    Args:
        segment_params: Three rings, ordered from the outermost to the
            innermost. Each ring is a 2-D sequence indexed as
            [sub-ring][angular sector]. Cells equal to -1 are treated as
            bad points and drawn fully transparent.
        mean_data: Optional per-segment values used as text labels
            (6 + 6 + 4 = 16 entries). When empty or None, the segment
            numbers 1..16 are drawn instead.
        img_path: Output file name without extension. When empty, nothing
            is written to disk and the image is returned as a data URI.
        max: Upper bound of the color scale. The name shadows the builtin
            but is kept because it is part of the public signature.
        min: Lower bound of the color scale (same naming caveat as ``max``).
        folder: Directory prepended to ``img_path`` when saving to disk.
            The original code referenced an undefined name ``folder``;
            the default '' saves relative to the working directory.

    Returns:
        A ``data:image/png;base64,...`` string when ``img_path`` is empty,
        otherwise None (the PNG is written to ``folder/img_path.png``).
    """
    format_type = 'png'

    # Canvas: a single 4x4 inch polar axes.
    fig, ax = plt.subplots(figsize=(4, 4),
                           nrows=1,
                           ncols=1,
                           subplot_kw=dict(projection='polar'))
    try:
        # Custom color map; masked (bad) cells are rendered transparent.
        colors = ['#0014aa', '#00fffa', '#01fe08', '#fcec01', '#fa1e01']
        cmap = LinearSegmentedColormap.from_list('chaos', colors, N=256)
        cmap.set_bad(color='black', alpha=0)
        norm = mpl.colors.Normalize(vmin=min, vmax=max)

        # Mask bad points (-1) ring by ring so that rings with different
        # shapes do not have to fit into one rectangular array.
        rings = []
        for ring in ([] if segment_params is None else segment_params):
            ring = np.asarray(ring)
            rings.append(np.ma.masked_where(ring == -1, ring))

        # Pass labels as a plain list (or None): truth-testing a NumPy
        # array with more than one element raises ValueError.
        labels = None
        if mean_data is not None and len(mean_data) > 0:
            labels = list(mean_data)

        ax.grid(False)
        bullseye_plot(ax, rings, cmap=cmap, norm=norm, mean_data=labels)

        if not img_path:
            # No target file: return the PNG as a base64 data URI.
            buffer = io.BytesIO()
            fig.savefig(buffer, transparent=True,
                        bbox_inches='tight', format=format_type)
            encoded = base64.b64encode(buffer.getvalue()).decode('utf-8')
            return 'data:image/' + format_type + ';base64,' + encoded

        out_path = os.path.join(folder, img_path + '.' + format_type)
        fig.savefig(out_path, transparent=True,
                    bbox_inches='tight', format=format_type)
        return None
    finally:
        # Always release this specific figure, even if plotting failed.
        plt.close(fig)

绘制方法

def bullseye_plot(ax,
                  data,
                  seg_bold=None,
                  cmap=None,
                  norm=None,
                  mean_data=None,
                  num_sector=768):
    """Draw a three-ring bullseye with 16 labelled segments on a polar axes.

    Args:
        ax: Axes to draw on; it is used with polar coordinates
            (theta, r), so it should be created with projection='polar'.
        data: Sequence of exactly three rings, outermost first. Each ring
            is a 2-D array indexed as [sub-ring][angular sector]; masked
            arrays are accepted.
        seg_bold: Unused; kept only so existing call signatures still work.
        cmap: Colormap for the cells. Defaults to ``plt.cm.jet``.
        norm: Color normalization. Defaults to the min/max over all rings.
        mean_data: Optional sequence of 16 values drawn as segment labels
            (rounded to 2 decimals). When None or empty, the segment
            numbers 1..16 are drawn instead.
        num_sector: Minimum number of points used for the circular ring
            boundaries. The angular resolution of the colored cells always
            comes from ``data``. (Previously this argument was overwritten
            inside the function and had no effect.)

    Raises:
        ValueError: If ``data`` does not contain exactly three rings.
    """
    if len(data) != 3:
        raise ValueError(
            'data must contain exactly 3 rings, got %d' % len(data))

    linewidth = 1.5
    inside_angle = 45    # start angle (degrees) of the innermost ring
    outside_angle = 60   # start angle (degrees) of the two outer rings
    segments_per_ring = (6, 6, 4)

    # asanyarray keeps masked arrays masked while accepting nested lists.
    rings = [np.asanyarray(ring) for ring in data]

    if cmap is None:
        cmap = plt.cm.jet

    if norm is None:
        values = np.ma.concatenate([ring.ravel() for ring in rings])
        norm = mpl.colors.Normalize(vmin=values.min(), vmax=values.max())

    # Ring boundary radii, from the inner edge (0.2) to the outer edge (1).
    r = np.linspace(0.2, 1, 4)

    # Fill the colored cells, one ring at a time.
    boundary_points = num_sector
    for idx, ring in enumerate(rings):
        ring_count, sector_count = ring.shape
        boundary_points = max(boundary_points, sector_count)

        outer_r = r[3 - idx]
        inner_r = r[2 - idx]
        band_width = (outer_r - inner_r) / ring_count
        start_angle = inside_angle if idx == 2 else outside_angle

        # Angular cell edges are the same for every sub-ring of this ring.
        theta = np.linspace(0, 2 * np.pi, sector_count + 1)
        theta_grid = np.repeat(
            (theta + np.deg2rad(start_angle))[:, np.newaxis], 2, axis=1)

        for i in range(ring_count):
            # Sub-rings are laid out from the outer edge inwards.
            band = np.array([outer_r - (i + 1) * band_width,
                             outer_r - i * band_width])
            r_grid = np.repeat(
                band[:, np.newaxis], sector_count + 1, axis=1).T
            cells = ring[i][:, np.newaxis]
            ax.pcolormesh(theta_grid, r_grid, cells, cmap=cmap, norm=norm)

    # Circular boundaries between the rings.
    circle_theta = np.linspace(0, 2 * np.pi, boundary_points + 1)
    for radius in r:
        ax.plot(circle_theta, np.full_like(circle_theta, radius),
                'gray', lw=linewidth)

    # Radial boundaries and labels for the 6 + 6 + 4 segments.
    # Decide once whether to label with values or with segment numbers;
    # avoid `not mean_data`, which raises for multi-element NumPy arrays.
    show_index = mean_data is None or len(mean_data) == 0
    count = 0
    for i, num in enumerate(segments_per_ring):
        origin_angle = inside_angle if i == 2 else outside_angle
        unit = 360 / num
        outer_r = r[3 - i]
        inner_r = r[2 - i]

        for key in range(num):
            now_theta = np.deg2rad(key * unit + origin_angle)
            next_theta = np.deg2rad((key + 1) * unit + origin_angle)

            ax.plot([now_theta, now_theta], [outer_r, inner_r],
                    'gray', lw=linewidth)

            text = (count + 1) if show_index else round(mean_data[count], 2)
            count += 1

            # Label at the angular and radial centre of the segment.
            text_box = ax.text((now_theta + next_theta) / 2,
                               (outer_r + inner_r) / 2,
                               text,
                               color='white',
                               va="center",
                               ha='center',
                               fontsize=12)
            # Black outline keeps the white text readable on light cells.
            text_box.set_path_effects(
                [path_effects.Stroke(linewidth=linewidth, foreground='black'),
                 path_effects.Normal()])

    ax.set_ylim([0, 1])
    ax.set_yticklabels([])
    ax.set_xticklabels([])

最终结果:图片尺寸较大时,色块之间看起来会有细小的间隔;把图缩小后就看不出这种间隔了。

image.png