张量流图
tensorflow-plot的Python项目详细描述
张量流图(tfplot)
用于提供基于matplotlib的plot操作的TensorFlow实用程序 -TensorBoard第{a6}页。
在建-API可能会改变!
它允许我们将任何matplotlib图或图形绘制成图像, 作为张量流计算图的一部分。 特别是,我们可以很容易地绘制和查看结果图像 作为TensorBoard中的图像摘要。
快速概述
使用tfplot
有两种主要方法:(i)用作tf op,和(ii)手动添加摘要proto。
用法:decorator
通过使用^{
@tfplot.autowrap(figsize=(2,2))defplot_scatter(x:np.ndarray,y:np.ndarray,*,ax,color='red'):ax.scatter(x,y,color=color)x=tf.constant([1,2,3],dtype=tf.float32)# tf.Tensory=tf.constant([1,4,9],dtype=tf.float32)# tf.Tensorplot_op=plot_scatter(x,y)# tf.Tensor shape=(?, ?, 4) dtype=uint8
用法:包装成tf ops
我们可以wrap任何纯python函数,用于绘制TensorFlow操作,例如:
- (i)创建并返回matplotlib
Figure
(见下文)的python函数 - (ii)具有
fig
或ax
关键字参数的python函数(将被自动注入); 例如^{} - (iii)方法实例matplotlib ^{
} ; 例如^{}
(i)的示例:可以定义一个python函数,它将numpy.ndarray
值作为输入(作为张量输入的参数)。
并绘制一个图作为matplotlib.figure.Figure
的返回值。
生成的tensorflow plot op将是包含生成的plot的形状[height, width, 4]
的rgba图像张量。
deffigure_heatmap(heatmap,cmap='jet'):# draw a heatmap with a colorbarfig,ax=tfplot.subplots(figsize=(4,3))# DON'T USE plt.subplots() !!!!im=ax.imshow(heatmap,cmap=cmap)fig.colorbar(im)returnfigheatmap_tensor=...# tf.Tensor shape=(16, 16) dtype=float32# (a) wrap function as a Tensor factoryplot_op=tfplot.autowrap(figure_heatmap)(heatmap_tensor)# tf.Tensor shape=(?, ?, 4) dtype=uint8# (b) direct invocation similar to tf.py_funcplot_op=tfplot.plot(figure_heatmap,[heatmap_tensor],cmap='jet')# (c) or just directly add an image summary with the plottfplot.summary.plot("heatmap_summary",figure_heatmap,[heatmap_tensor])
(ii)示例:
importtfplotimportseaborn.apionlyassnstf_heatmap=tfplot.autowrap(sns.heatmap,figsize=(4,4),batch=True)# function: Tensor -> Tensorplot_op=tf_heatmap(attention_maps)# tf.Tensor shape=(?, 400, 400, 4) dtype=uint8tf.summary.image("attention_maps",plot_op)
请查看the showcase或examples directory以获取更多示例和用例。
The full documentation包括api文档可以在readthedocs找到。
用法:手动添加摘要协议
importtensorboardastbfig,ax=...# Get RGB image manually or by executing plot ops.embedding_plot=sess.run(plot_op)# ndarray [H, W, 3] uint8embedding_plot=tfplot.figure_to_array(fig)# ndarray [H, W, 3] uint8summary_pb=tb.summary.image_pb('plot_embedding',[embedding_plot])summary_writer.write_add_summary(summary_pb,global_step=global_step)
安装
pip install tensorflow-plot
获取最新的开发版本:
pip install git+https://github.com/wookayin/tensorflow-plot.git@master
注
关于速度
的一些评论matplotlib操作可能非常慢,因为matplotlib是在python中运行的,而不是在本机代码中运行的, 所以请注意运行速度。 还有改进的余地,这将在不久的将来得到解决。
此外,从主代码中绘制绘图(而不是使用tf op)并将其作为图像摘要添加也是一个好主意。 请尽你所能使用这个图书馆。
线程安全问题
请使用面向对象的matplotlib api(例如Figure
,AxesSubplot
)
而不是pyplotapi(即matplotlib.pyplot
或plt.XXX()
)
创建和绘制绘图时。
这是因为pyplotapi不是线程安全的,
而tensorflow绘图操作通常以多线程方式执行。
例如,避免使用pyplot
(或plt
):
# DON'T DO LIKE THIS !!!deffigure_heatmap(heatmap):fig=plt.figure()# <--- NO!plt.imshow(heatmap)returnfig
然后像这样做:
deffigure_heatmap(heatmap):fig=matplotlib.figure.Figure()# or just `fig = tfplot.Figure()`ax=fig.add_subplot(1,1,1)# ax: AxesSubplot# or, just `fig, ax = tfplot.subplots()`ax.imshow(heatmap)returnfig# fig: Figure
例如,tfplot.subplots()
是plt.subplots()
的一个很好的替代品。
使用内部绘图功能。
或者,您可以利用自动注入fig
和/或ax
的优势。
tensorflow兼容性
目前,tfplot
与tensorflow 1.x系列兼容。
支持紧急执行和tf 2.0即将到来!
许可证
MIT License崔宗武