有没有办法在使用matplotlib保存的图表中添加额外信息?

1 投票
3 回答
47 浏览
提问于 2025-04-14 16:21

我用Python的turtle库做了一个模拟,展示了金星、地球、火星和木星围绕太阳的轨道。下面是我的代码:

import math
import matplotlib.pyplot as plt
import turtle
from turtle import *

...

def loop(bodies, earth):
    timestep = 3600 * 24 * 7
    step = 1
    earth_years = -1

    final_data = []
    for step in range(0, max_step + 1):
        update_info(step, bodies)

        if 0.98 <= earth.px / au <= 1.00 and -0.05 <= earth.py / au <= 0.05:
            earth_years += 1
            turtle.clear()
            turtle.penup()
            turtle.goto(-450, 350)
            turtle.pendown()
            turtle.color("white")
            turtle.hideturtle()
            turtle.write(f"Earth Years: {earth_years}", font=("Calibri", 20, "normal"))

        force = {}
        for body in bodies:
            total_fx = total_fy = 0.0
            for other in bodies:
                if body is other:
                    continue
                fx, fy = body.attraction(other)
                total_fx += fx
                total_fy += fy

            force[body] = (total_fx, total_fy)

        for body in bodies:
            fx, fy = force[body]
            body.vx += fx / body.mass * timestep
            body.vy += fy / body.mass * timestep

            if body.name != "Sun":
                body.px += body.vx * timestep
                body.py += body.vy * timestep
                body.goto(body.px * scale, body.py * scale)

        for body in bodies:
            text_values = f"Week: {step:3}  {body.name:4}   Position ={body.px/au:6.2f} {body.py/au:6.2f}  Velocity ={body.vx:10.2f} {body.vy:10.2f}"
            final_data.append(text_values)

        file.write(str(final_data) + "\n")
        final_data = []

    file.close()
    graph(earth_years, save_as_file=False)

...

def graph(earth_years, save_as_file=True):
    sim_data = open("C:/[FolderPath]/planetaryOrbits.txt", "r")
    plt.xlabel("AU")
    plt.ylabel("AU")
    plt.title("Orbits of Venus, Earth, Mars and Jupiter around the Sun")
    plt.axis([-6, 6, -6, 6])
    plt.grid()

    celestial_coordinates = []

    for z in sim_data:
        i = z.split()
        xs = float(i[5])
        ys = float(i[6])
        xv = float(i[16])
        yv = float(i[17])
        xe = float(i[27])
        ye = float(i[28])
        xm = float(i[38])
        ym = float(i[39])
        xj = float(i[49])
        yj = float(i[50])
        total = [xs, ys, xv, yv, xe, ye, xm, ym, xj, yj]
        celestial_coordinates.append(total)

    celestial_coordinates = list(zip(*celestial_coordinates))

    plt.plot(celestial_coordinates[0], celestial_coordinates[1], color = "yellow", marker = "o", label = "Sun")
    plt.plot(celestial_coordinates[2], celestial_coordinates[3], color = "purple", label = "Venus")
    plt.plot(celestial_coordinates[4], celestial_coordinates[5], color = "blue", label = "Earth")
    plt.plot(celestial_coordinates[6], celestial_coordinates[7], color = "red", label = "Mars")
    plt.plot(celestial_coordinates[8], celestial_coordinates[9], color = "green", label = "Jupiter")
    plt.legend(loc="upper right", fontsize=10)

    if save_as_file:
        try:
            # Show the AU values for each planet (only in the saved file)
            planet_au = {
                "Venus": max(celestial_coordinates[2]),
                "Earth": max(celestial_coordinates[4]),
                "Mars": max(celestial_coordinates[6]),
                "Jupiter": max(celestial_coordinates[8])
            }
            for planet, max_planet_au in planet_au.items():
                planet_color = {"Venus": "purple", "Earth": "blue", "Mars": "red", "Jupiter": "green"}
                if planet == "Venus":
                    plt.text(-0.68, +1.8, f"{max_planet_au:.2f} AU", fontsize=10, color=planet_color[planet], bbox=dict(facecolor="white", edgecolor="none"))
                elif planet == "Earth":
                    plt.text(+1.69, -0.15, f"{max_planet_au:.2f} AU", fontsize=10, color=planet_color[planet], bbox=dict(facecolor="white", edgecolor="none"))
                elif planet == "Mars":
                    plt.text(-2.65, -1.6, f"{max_planet_au:.2f} AU", fontsize=10, color=planet_color[planet], bbox=dict(facecolor="white", edgecolor="none"))
                elif planet == "Jupiter":
                    plt.text(-3.6, +3.1, f"{max_planet_au:.2f} AU", fontsize=10, color=planet_color[planet])
            # Show the total Earth years (only in the saved file)
            earth_position = celestial_coordinates[4], celestial_coordinates[5]
            earth_orbit_completion = math.atan2(earth_position[1][-1], earth_position[0][-1]) / (2 * math.pi)
            plt.text(-2.35, -3.21, f"Total Earth Years: {earth_years + 1 + earth_orbit_completion:.2f}", fontsize=12, color="black", bbox=dict(facecolor="white", edgecolor="none"))
            plt.savefig("C:/[FolderPath]/planetaryOrbits.png")
        except Exception as e:
            print(f"Error saving the file: {e}")
    try:
        plt.show()
    except Exception as e:
        print(f"Error showing the graph: {e}")

(目前图表没有保存,因为我把“save_as_file”设置为False

我想要的是,当模拟结束后,能显示一个互动图表(这部分是可以做到的),但我还希望这个图表能保存下来,并且包含一些在互动图表上看不到的信息。所以我希望互动图表只显示轨道和图例,而保存的文件里要有每个行星的天文单位(AU)和总的地球年数。

我尝试了不同的方法,但要么两个图表显示的信息完全一样,要么我只能选择显示只有轨道和图例的图表,或者保存包含额外AU和总地球年数的图表。

如果能得到一些帮助,我会非常感激。

3 个回答

0

先画出简化的视图,展示出来,然后在关闭这个互动窗口后,再添加其他所有内容,最后用 savefig 来保存。

In [8]: plt.cla()
   ...: plt.plot()
   ...: plt.show()
   ...: plt.title('Title')
   ...: plt.savefig('Figure2.png', dpi=185)
   ...: !qiv Figure2.png

(qiv 是一个简单的图片查看器,可以根据你的需求和喜好换成其他的查看器)。

0

如果你想保留这个互动图表,并且想在保存的文件里添加更多信息,你可以使用pickle这个工具来保存结果:

import numpy as np
import matplotlib.pyplot as plt
import pickle as pickle

# Plot simple sinus function
fig = plt.figure()
x = np.linspace(0,2*np.pi)
y = np.sin(x)
plt.plot(x,y)

extra_data = [1,2,3]
# Save figure handle to disk
import pickle
with open('figure.pickle', 'wb') as f: # should be 'wb' rather than 'w'
    pickle.dump([fig,extra_data], f)
plt.close(fig)



# Load figure handle from disk    
with open('figure.pickle','rb') as file:
    loaded_fig,loaded_extra_data = pickle.load(file)
    
plt.show()
0

如果你记得自己添加的额外艺术家(比如图表上的一些元素),那么在你调用 savefig 之后,可以简单地把它们去掉:

import matplotlib.pyplot as plt

fig, ax = plt.subplots()

ax.plot([0, 1])

save_as_file = True

if save_as_file:
    extra_artists = []
    extra_artists.append(ax.text(0, 0, 'foo'))
    extra_artists.append(ax.text(1, 0, 'bar'))

    plt.savefig("test.png")

    for artist in extra_artists:
        artist.remove()
        
plt.show()

撰写回答