绘制一种热图,其颜色由函数x,y->r,g,b决定

2 投票
1 回答
958 浏览
提问于 2025-04-18 07:17

我有一个字典,它把XY坐标的组合映射到RGB颜色的组合。比如说,

d = {
    (0, 0): (0, 0, 0),
    (0, 1): (0, 0, 200),
    }

我想画一个热力图,在特定的XY坐标上,颜色是根据字典里的颜色计算出来的平均值,这个平均值是根据它们到这个点的距离来加权的;就好像它们是“光源”一样。

在这个例子中,坐标(0, 0.5)应该显示成(0, 0, 100)的颜色,而坐标(0, 0.1)应该显示成(0, 0, 20)的颜色。

我想问的其实是个技术问题:我该如何让pyplot用一个函数f(x, y) -> (r, g, b)来绘制一个像素图像,颜色是从这个函数中获取的呢?

1 个回答

4

如果你有一个X-Y坐标网格:

import numpy
from matplotlib import pyplot as plt

width, height = 300, 500

xs = numpy.arange(width)
ys = numpy.arange(height)
data = numpy.dstack(numpy.meshgrid(xs, ys))

你只需要把这些点映射到 (r, g, b) 的元组上。虽然下面的方法比较慢,但加快速度的方法取决于你的函数具体做了什么。

from colorsys import hsv_to_rgb

import random
def data_to_color(x, y):
    return (
        (x/width)**(0.5+random.random()*2),
        (y/height)**3,
        (x/width*y/height)*0.6 + random.random()*0.4
    )

colors = [[data_to_color(x, y) for x, y in row] for row in data]
colors = numpy.array(colors)
colors.shape
#>>> (500, 300, 3)

然后使用 imshow 就能得到想要的输出:

plt.imshow(colors, origin='lower')
plt.show()

plt.show() 结果

现在,如果你想从你的点进行插值,就可以使用 scipy.interpolate。我会创建一个字典来从上面的函数中推断:

from scipy.interpolate import griddata

gridpoints = data.reshape(width*height, 2)
d = {(x, y): data_to_color(x, y) for x, y in gridpoints if not random.randint(0, 1000)}

len(d)
#>>> 142

把字典提取到 numpy 数组中,并分开颜色(可能可以不分开,但你可以自己测试一下):

points, values = zip(*d.items())

points = numpy.array(points)
values = numpy.array(values)

reds   = values[:, 0]
greens = values[:, 1]
blues  = values[:, 2]

然后在这些点上运行 griddata

new_reds   = griddata(points, reds,   (data[:, :, 0], data[:, :, 1]), method='linear')
new_greens = griddata(points, greens, (data[:, :, 0], data[:, :, 1]), method='linear')
new_blues  = griddata(points, blues,  (data[:, :, 0], data[:, :, 1]), method='linear')

new_colors = numpy.dstack([new_reds, new_greens, new_blues])
new_colors[numpy.isnan(new_colors)] = 0.5

然后绘制图形:

plt.triplot(points[:,0], points[:,1], 'k-', linewidth=1, alpha=0.5)

plt.imshow(new_colors, extent=(0, width, 0, height), origin='lower')
plt.show()

plt.show() 插值输出

最后,如果你也想要 外推,我从 这里复制了一些代码:

import scipy

def extrapolate_nans(x, y, v):
    '''  
    Extrapolate the NaNs or masked values in a grid INPLACE using nearest
    value.

    .. warning:: Replaces the NaN or masked values of the original array!

    Parameters:

    * x, y : 1D arrays
        Arrays with the x and y coordinates of the data points.
    * v : 1D array
        Array with the scalar value assigned to the data points.

    Returns:

    * v : 1D array
        The array with NaNs or masked values extrapolated.
    '''

    if numpy.ma.is_masked(v):
        nans = v.mask
    else:
        nans = numpy.isnan(v)
    notnans = numpy.logical_not(nans)
    v[nans] = scipy.interpolate.griddata((x[notnans], y[notnans]), v[notnans],
        (x[nans], y[nans]), method='nearest').ravel()
    return v

new_reds   = extrapolate_nans(data[:, :, 0], data[:, :, 1], new_reds)
new_greens = extrapolate_nans(data[:, :, 0], data[:, :, 1], new_greens)
new_blues  = extrapolate_nans(data[:, :, 0], data[:, :, 1], new_blues)

new_colors = numpy.dstack([new_reds, new_greens, new_blues])

plt.imshow(new_colors, extent=(0, width, 0, height), origin='lower')
plt.show()

ply.show() 外推插值输出


编辑:也许可以更像这样:

import numpy
from matplotlib import pyplot as plt
from numpy.core.umath_tests import inner1d

width, height = 300, 500

xs, ys = numpy.mgrid[:width, :height]
coordinates = numpy.dstack([xs, ys])

light_sources = {
    (0, 0): (0, 0, 0),
    (300, 0): (0, 0, 0),
    (0, 0): (0, 0, 0),
    (300, 500): (0, 0, 0),
    (100, 0): (0, 0, 200),
    (200, 150): (100, 70, 0),
    (220, 400): (255, 255, 255),
    (80, 220): (255, 0, 0),
}

weights = numpy.zeros([width, height])
values = numpy.zeros([width, height, 3])

对于每个光源:

for coordinate, value in light_sources.items():

计算(反向)距离。使用 +1e9 来防止出现无穷大,尽管这会导致一些奇怪的错误,所以稍后需要更严格的修复:

    shifted_coordinates = coordinates - coordinate + 1e-9
    inverse_distances = (shifted_coordinates ** 2).sum(axis=-1) ** (-1/2)

把它加到总和和总权重中:

    weights += inverse_distances
    values  += inverse_distances[:, :, numpy.newaxis].repeat(3, axis=-1) * value / 255

然后用权重除以总和,得到平均值:

values /= weights[..., numpy.newaxis]

并显示...

plt.imshow(values, origin='lower')
plt.show()

对于这个:

plt.show() 光源变体

我最开始没有选择这个方法的原因是,因为在你的例子中 (0, 0.1) 的值不是 (0, 0, 20),而是:

distances = [0.9, 0.1]
inverse_distances = [10/9, 10]
sum_weighting = 100 / 9
blue_levels = 200 / (109/90) = 18

所以根据这个定义,它应该是 (0, 0, 18)

撰写回答