RAFT(光流的循环全对场变换)通过克拉斯特遣部队
tf-raft的Python项目详细描述
tf筏
RAFT(光流的重复全对场变换,Teed等,ECCV2020)通过特斯拉斯在
原始资源
安装
$ pip install tf-raft
或者您可以简单地克隆这个存储库。在
依赖性
- 张量流
- TensorFlow插件
- 专辑
详见pyoroject.toml
光流数据集
MPI-Sintel或{a4}数据集相对较少。在oirignal repository中查看更多数据集
使用
^{pr2}$实际上,您需要准备数据集、优化器、回调等,检查train_sintel.py
或{
通过YAML配置进行列车
train_chairs.py
和{configs
目录中。运行
$ python train_chairs.py /path/to/config.yml
预先培训的模型
我把预先训练过的重量(飞椅和MPI-Sintel上的)公之于众。
您可以通过gsutil
或curl
下载它们。在
在飞行椅上训练重量
$ gsutil cp -r gs://tf-raft-pretrained/2020-09-26T18-38/checkpoints .
或者
$ mkdir checkpoints $ curl -OL https://storage.googleapis.com/tf-raft-pretrained/2020-09-26T18-38/checkpoints/model.data-00000-of-00001 $ curl -OL https://storage.googleapis.com/tf-raft-pretrained/2020-09-26T18-38/checkpoints/model.index $ mv model* checkpoints/
MPI Sintel(清洁路径)上的训练重量
$ gsutil cp -r gs://tf-raft-pretrained/2020-09-26T08-51/checkpoints .
或者
$ mkdir checkpoints $ curl -OL https://storage.googleapis.com/tf-raft-pretrained/2020-09-26T08-51/checkpoints/model.data-00000-of-00001 $ curl -OL https://storage.googleapis.com/tf-raft-pretrained/2020-09-26T08-51/checkpoints/model.index $ mv model* checkpoints/
负载重量
raft=RAFT(iters=iters,iters_pred=iters_pred)raft.load_weights('checkpoints/model')# forward (with dummy inputs)x1=np.random.uniform(0,255,(1,448,512,3)).astype(np.float32)x2=np.random.uniform(0,255,(1,448,512,3)).astype(np.float32)flow_predictions=model([x1,x2],training=False)print(flow_predictions[-1].shape)# >> (1, 448, 512, 2)
注意
虽然我已经尝试忠实地复制原始实现,但是原始实现与我的实现之间存在一些差异(主要是因为使用的框架:PyTorch/TensorFlow)
- 最初的实现提供了基于cuda的关联函数,但我没有。我的基于TF的实现很好,但是基于cuda的实现可能运行得更快。在
- 我在我的私人环境(GCP和P100加速器)中分别训练了我的模型飞椅和MPI Sintel。该模型经过了很好的训练,但没有达到本文报道的最佳成绩(在多个数据集上进行训练)。在
- 原来的方法使用混合精度。这可能会更快地得到训练,但我没有。TensorFlow还支持混合精度,只需几行额外的行,如果感兴趣,请参阅https://www.tensorflow.org/guide/mixed_precision。在
另外,全局梯度削波似乎是稳定训练的必要条件,尽管在原论文中没有强调这一点。这个操作可以通过PyTorch中的torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
来完成,TF中的tf.clip_by_global_norm(grads, clip_norm)
(在tf_raft/model.py
中编码为self.train_step
)。在
参考文献
- https://github.com/princeton-vl/RAFT
- https://github.com/NVIDIA/flownet2-pytorch
- https://github.com/NVlabs/PWC-Net
- 项目
标签: