使用networkx在两个节点之间绘制多条边
我想画一个有向图,也就是在两个节点之间可以有多条边,而且这些边的权重可以不同。比如说,我有两个节点A和B,边(A,B)的长度是2,而边(B,A)的长度是3。
我试过用G=nx.Digraph和G=nx.Multidigraph这两种方法来画这个图。但是画出来后,我只能看到一条边和一个标签。请问有没有办法做到这一点呢?
6 个回答
在AMangipinto的解决方案中添加以下代码,以便在两个方向上都显示边的标签(查看链接中的图片):
edge_labels = dict([((u, v,), f'{d["length"]}\n\n{G.edges[(v,u)]["length"]}')
for u, v, d in G.edges(data=True) if pos[u][0] > pos[v][0]])
nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_color='red')
这里的“if pos[u][0] > pos[v][0]”只会在一个方向上添加边的标签。我们将两个长度加到一个标签上,否则会把边的第一个标签覆盖掉。注意:如果节点的x坐标相同,标签将不会显示。
你可以直接使用你计算出来的节点位置来使用matplotlib这个库。
G=nx.MultiGraph ([(1,2),(1,2),(1,2),(3,1),(3,2)])
pos = nx.random_layout(G)
nx.draw_networkx_nodes(G, pos, node_color = 'r', node_size = 100, alpha = 1)
ax = plt.gca()
for e in G.edges:
ax.annotate("",
xy=pos[e[0]], xycoords='data',
xytext=pos[e[1]], textcoords='data',
arrowprops=dict(arrowstyle="->", color="0.5",
shrinkA=5, shrinkB=5,
patchA=None, patchB=None,
connectionstyle="arc3,rad=rrr".replace('rrr',str(0.3*e[2])
),
),
)
plt.axis('off')
plt.show()
试试下面这个:
import networkx as nx
import matplotlib.pyplot as plt
G = nx.DiGraph() #or G = nx.MultiDiGraph()
G.add_node('A')
G.add_node('B')
G.add_edge('A', 'B', length = 2)
G.add_edge('B', 'A', length = 3)
pos = nx.spring_layout(G)
nx.draw(G, pos)
edge_labels=dict([((u,v,),d['length'])
for u,v,d in G.edges(data=True)])
nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, label_pos=0.3, font_size=7)
plt.show()
这样做会给你返回一个图,上面有两条边,并且边的长度会显示出来:
对上面的回答进行了一些改进,就是在nx.draw里添加了一个叫做 connectionstyle 的东西,这样在图表中就能看到两条平行的线了:
import networkx as nx
import matplotlib.pyplot as plt
G = nx.DiGraph() #or G = nx.MultiDiGraph()
G.add_node('A')
G.add_node('B')
G.add_edge('A', 'B', length = 2)
G.add_edge('B', 'A', length = 3)
pos = nx.spring_layout(G)
nx.draw(G, pos, with_labels=True, connectionstyle='arc3, rad = 0.1')
edge_labels=dict([((u,v,),d['length'])
for u,v,d in G.edges(data=True)])
plt.show()
下面是如何得到类似于以下效果的方法:
这个图形的一些特点是:
- 当两个节点之间只有一条边时,这条边是直的。
- 标签正好位于边的中间。
- 目前的解决方案只适用于有向图(DiGraphs)。不支持多重图(MultiGraphs)、多重有向图(MultiDiGraphs)和自环。
设置环境
以下代码是开始这个示例的初始代码:
import matplotlib.pyplot as plt
import networkx as nx
G = nx.DiGraph()
edge_list = [(1,2,{'w':'A1'}),(2,1,{'w':'A2'}),(2,3,{'w':'B'}),(3,1,{'w':'C'}),
(3,4,{'w':'D1'}),(4,3,{'w':'D2'}),(1,5,{'w':'E1'}),(5,1,{'w':'E2'}),
(3,5,{'w':'F'}),(5,4,{'w':'G'})]
G.add_edges_from(edge_list)
pos=nx.spring_layout(G,seed=5)
fig, ax = plt.subplots()
nx.draw_networkx_nodes(G, pos, ax=ax)
nx.draw_networkx_labels(G, pos, ax=ax)
fig.savefig("1.png", bbox_inches='tight',pad_inches=0)
这段代码的结果是:
绘制边
NetworkX中的draw_networkx_edges
函数只能用edgelist
参数绘制部分边。为了使用这个功能,我们把边分成两个列表,分别绘制。感谢AMangipinto的回答,使用了connectionstyle='arc3, rad = 0.1'
。
curved_edges = [edge for edge in G.edges() if reversed(edge) in G.edges()]
straight_edges = list(set(G.edges()) - set(curved_edges))
nx.draw_networkx_edges(G, pos, ax=ax, edgelist=straight_edges)
arc_rad = 0.25
nx.draw_networkx_edges(G, pos, ax=ax, edgelist=curved_edges, connectionstyle=f'arc3, rad = {arc_rad}')
fig.savefig("2.png", bbox_inches='tight',pad_inches=0)
这段代码的结果是:
绘制边标签
NetworkX中的draw_networkx_edge_labels
函数假设边是直的,并且没有参数可以改变这一点。由于NetworkX是开源的,我复制了这个函数,并创建了一个修改版的my_draw_networkx_edge_labels
。这个函数在附录中。
假设你把这个函数保存到一个名为my_networkx.py的文件中,你可以这样绘制边标签:
import my_networkx as my_nx
edge_weights = nx.get_edge_attributes(G,'w')
curved_edge_labels = {edge: edge_weights[edge] for edge in curved_edges}
straight_edge_labels = {edge: edge_weights[edge] for edge in straight_edges}
my_nx.my_draw_networkx_edge_labels(G, pos, ax=ax, edge_labels=curved_edge_labels,rotate=False,rad = arc_rad)
nx.draw_networkx_edge_labels(G, pos, ax=ax, edge_labels=straight_edge_labels,rotate=False)
fig.savefig("3.png", bbox_inches='tight',pad_inches=0)
我们再次将曲线和直线分开。结果是这个回答中的第一个图形。
附录
NetworkX中的draw_networkx_edge_labels
函数找到标签的位置时假设边是直线:
(x, y) = (
x1 * label_pos + x2 * (1.0 - label_pos),
y1 * label_pos + y2 * (1.0 - label_pos),
)
为了找到二次贝塞尔曲线的中点,我们可以使用以下代码。首先,我们根据matplotlib中的定义,找到贝塞尔曲线的中间控制点(代码中的ctrl_1
):
这个曲线的创建方式是中间控制点(C1)与起点(C0)和终点(C2)之间的距离相等,并且C1到连接C0-C2的线的距离是C0-C2距离的rad倍。
由于这个定义,my_draw_networkx_edge_labels
函数需要一个额外的参数,叫做rad
。
pos_1 = ax.transData.transform(np.array(pos[n1]))
pos_2 = ax.transData.transform(np.array(pos[n2]))
linear_mid = 0.5*pos_1 + 0.5*pos_2
d_pos = pos_2 - pos_1
rotation_matrix = np.array([(0,1), (-1,0)])
ctrl_1 = linear_mid + rad*rotation_matrix@d_pos
以“ax.transData”开头的函数是必要的,因为坐标轴上的90度角在显示中并不对应90度。因此,我们需要在显示坐标系统和实际坐标系统之间转换坐标。
bezier_mid
可以通过贝塞尔曲线的规则计算:
ctrl_mid_1 = 0.5*pos_1 + 0.5*ctrl_1
ctrl_mid_2 = 0.5*pos_2 + 0.5*ctrl_1
bezier_mid = 0.5*ctrl_mid_1 + 0.5*ctrl_mid_2
(x, y) = ax.transData.inverted().transform(bezier_mid)
完整的my_draw_networkx_edge_labels
代码:
def my_draw_networkx_edge_labels(
G,
pos,
edge_labels=None,
label_pos=0.5,
font_size=10,
font_color="k",
font_family="sans-serif",
font_weight="normal",
alpha=None,
bbox=None,
horizontalalignment="center",
verticalalignment="center",
ax=None,
rotate=True,
clip_on=True,
rad=0
):
"""Draw edge labels.
Parameters
----------
G : graph
A networkx graph
pos : dictionary
A dictionary with nodes as keys and positions as values.
Positions should be sequences of length 2.
edge_labels : dictionary (default={})
Edge labels in a dictionary of labels keyed by edge two-tuple.
Only labels for the keys in the dictionary are drawn.
label_pos : float (default=0.5)
Position of edge label along edge (0=head, 0.5=center, 1=tail)
font_size : int (default=10)
Font size for text labels
font_color : string (default='k' black)
Font color string
font_weight : string (default='normal')
Font weight
font_family : string (default='sans-serif')
Font family
alpha : float or None (default=None)
The text transparency
bbox : Matplotlib bbox, optional
Specify text box properties (e.g. shape, color etc.) for edge labels.
Default is {boxstyle='round', ec=(1.0, 1.0, 1.0), fc=(1.0, 1.0, 1.0)}.
horizontalalignment : string (default='center')
Horizontal alignment {'center', 'right', 'left'}
verticalalignment : string (default='center')
Vertical alignment {'center', 'top', 'bottom', 'baseline', 'center_baseline'}
ax : Matplotlib Axes object, optional
Draw the graph in the specified Matplotlib axes.
rotate : bool (deafult=True)
Rotate edge labels to lie parallel to edges
clip_on : bool (default=True)
Turn on clipping of edge labels at axis boundaries
Returns
-------
dict
`dict` of labels keyed by edge
Examples
--------
>>> G = nx.dodecahedral_graph()
>>> edge_labels = nx.draw_networkx_edge_labels(G, pos=nx.spring_layout(G))
Also see the NetworkX drawing examples at
https://networkx.org/documentation/latest/auto_examples/index.html
See Also
--------
draw
draw_networkx
draw_networkx_nodes
draw_networkx_edges
draw_networkx_labels
"""
import matplotlib.pyplot as plt
import numpy as np
if ax is None:
ax = plt.gca()
if edge_labels is None:
labels = {(u, v): d for u, v, d in G.edges(data=True)}
else:
labels = edge_labels
text_items = {}
for (n1, n2), label in labels.items():
(x1, y1) = pos[n1]
(x2, y2) = pos[n2]
(x, y) = (
x1 * label_pos + x2 * (1.0 - label_pos),
y1 * label_pos + y2 * (1.0 - label_pos),
)
pos_1 = ax.transData.transform(np.array(pos[n1]))
pos_2 = ax.transData.transform(np.array(pos[n2]))
linear_mid = 0.5*pos_1 + 0.5*pos_2
d_pos = pos_2 - pos_1
rotation_matrix = np.array([(0,1), (-1,0)])
ctrl_1 = linear_mid + rad*rotation_matrix@d_pos
ctrl_mid_1 = 0.5*pos_1 + 0.5*ctrl_1
ctrl_mid_2 = 0.5*pos_2 + 0.5*ctrl_1
bezier_mid = 0.5*ctrl_mid_1 + 0.5*ctrl_mid_2
(x, y) = ax.transData.inverted().transform(bezier_mid)
if rotate:
# in degrees
angle = np.arctan2(y2 - y1, x2 - x1) / (2.0 * np.pi) * 360
# make label orientation "right-side-up"
if angle > 90:
angle -= 180
if angle < -90:
angle += 180
# transform data coordinate angle to screen coordinate angle
xy = np.array((x, y))
trans_angle = ax.transData.transform_angles(
np.array((angle,)), xy.reshape((1, 2))
)[0]
else:
trans_angle = 0.0
# use default box of white with white border
if bbox is None:
bbox = dict(boxstyle="round", ec=(1.0, 1.0, 1.0), fc=(1.0, 1.0, 1.0))
if not isinstance(label, str):
label = str(label) # this makes "1" and 1 labeled the same
t = ax.text(
x,
y,
label,
size=font_size,
color=font_color,
family=font_family,
weight=font_weight,
alpha=alpha,
horizontalalignment=horizontalalignment,
verticalalignment=verticalalignment,
rotation=trans_angle,
transform=ax.transData,
bbox=bbox,
zorder=1,
clip_on=clip_on,
)
text_items[(n1, n2)] = t
ax.tick_params(
axis="both",
which="both",
bottom=False,
left=False,
labelbottom=False,
labelleft=False,
)
return text_items