在matplotlib中显示树图的变化

0 投票
1 回答
75 浏览
提问于 2025-04-14 18:11

我想制作这个:

使用matplotlib制作的树图

这个图表的数据是:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


data = {
    "year": [2004, 2022, 2004, 2022, 2004, 2022],
    "countries" : [ "Denmark", "Denmark", "Norway", "Norway","Sweden", "Sweden",],
    "sites": [4,10,5,8,13,15]
}
df= pd.DataFrame(data)
df['diff'] = df.groupby(['countries'])['sites'].diff()
df['diff'].fillna(df.sites, inplace=True)

df

我知道有一些库可以制作树图,比如squarify和plotly,但我还没搞明白怎么做上面这个图,特别是年份的值是如何相加的(或者说准确点是差值)。如果能用纯matplotlib来实现,那就太好了,前提是这不是太复杂。

有没有人能给点建议?我在谷歌上没找到很多关于树图的信息。

1 个回答

3

这个任务分为两个部分。

  1. 计算矩形的布局。
  2. 绘制矩形。

第一部分可能会比较复杂:有人专门为这个主题发表了科学论文。在这里重新发明轮子并不明智。不过,第二部分就简单多了,可以用matplotlib来完成。

下面的解决方案使用了squarify来计算布局,利用每对数值中的较大值,然后用matplotlib绘制两个重叠的矩形。

在这里输入图片描述

import numpy as np
import matplotlib.pyplot as plt
import squarify

from matplotlib import colormaps
from matplotlib.colors import to_rgba

DEFAULT_COLORS = list(zip(colormaps["tab20"].colors[::2],
                          colormaps["tab20"].colors[1::2]))


def color_to_grayscale(color):
    # Adapted from: https://stackoverflow.com/a/689547/2912349
    r, g, b, a = to_rgba(color)
    return (0.299 * r + 0.587 * g + 0.114 * b) * a


class PairedTreeMap:

    def __init__(self, values, colors=DEFAULT_COLORS, labels=None, ax=None, bbox=(0, 0, 200, 100)):
        """
        Draw a treemap of value pairs.

        values : list[tuple[float, float]]
            A list of value pairs.

        colors : list[tuple[RGBA, RGBA]]
            The corresponding color pairs. Defaults to light/dark tab20 matplotlib color pairs.

        labels : list[str]
            The labels, one for each pair.

        ax : matplotlib.axes._axes.Axes
            The matplotlib axis instance to draw on.

        bbox : tuple[float, float, float, float]
            The (x, y) origin and (width, height) extent of the treemap.

        """

        self.ax = self.initialize_axis(ax)
        self.rects = self.get_layout(values, bbox)
        self.artists = list(self.draw(self.rects, values, colors, self.ax))

        if labels:
            self.labels = list(self.add_labels(self.rects, labels, values, colors, self.ax))


    def get_layout(self, values, bbox):
        maxima = np.max(values, axis=1)
        order = np.argsort(maxima)[::-1]
        normalized_maxima = squarify.normalize_sizes(maxima[order], *bbox[2:])
        rects = squarify.padded_squarify(normalized_maxima, *bbox)
        reorder = np.argsort(order)
        return [rects[ii] for ii in reorder]


    def initialize_axis(self, ax=None):
        if ax is None:
            fig, ax = plt.subplots()
        ax.set_aspect("equal")
        ax.axis("off")
        return ax


    def _get_artist_pair(self, rect, value_pair, color_pair):
        x, y, w, h = rect["x"], rect["y"], rect["dx"], rect["dy"]
        (small, large), (color_small, color_large) = zip(*sorted(zip(value_pair, color_pair)))
        ratio = np.sqrt(small / large)
        return (plt.Rectangle((x, y), w,         h,         color=color_large, zorder=1),
                plt.Rectangle((x, y), w * ratio, h * ratio, color=color_small, zorder=2))


    def draw(self, rects, values, colors, ax):
        for rect, value_pair, color_pair in zip(rects, values, colors):
            large_patch, small_patch = self._get_artist_pair(rect, value_pair, color_pair)
            ax.add_patch(large_patch)
            ax.add_patch(small_patch)
            yield(large_patch, small_patch)
        ax.autoscale_view()


    def add_labels(self, rects, labels, values, colors, ax):
        for rect, label, value_pair, color_pair in zip(rects, labels, values, colors):
            x, y, w, h = rect["x"], rect["y"], rect["dx"], rect["dy"]
            # decide a fontcolor based on background brightness
            (small, large), (color_small, color_large) = zip(*sorted(zip(value_pair, color_pair)))
            ratio = small / large
            background_brightness = color_to_grayscale(color_large) if ratio < 0.33 else color_to_grayscale(color_small) # i.e. 0.25 + some fudge
            fontcolor = "white" if background_brightness < 0.5 else "black"
            yield ax.text(x + w/2, y + h/2, label, va="center", ha="center", color=fontcolor)


if __name__ == "__main__":

    values = [
        (4, 10),
        (13, 15),
        (5, 8),
    ]

    colors = [
        ("red", "coral"),
        ("royalblue", "cornflowerblue"),
        ("darkslategrey", "gray"),
    ]

    labels = [
        "Denmark",
        "Sweden",
        "Norway"
    ]

    PairedTreeMap(values, colors=colors, labels=labels, bbox=(0, 0, 100, 100))
    plt.show()


撰写回答