用于模型转换、序列化、加载等的机器学习实用程序
ml2rt的Python项目详细描述
用于模型转换、序列化、加载等的机器学习实用程序
- 免费软件:apache软件许可证2.0
安装
pip install ml2rt
文档
ml2rt提供了一些方便的功能来转换、保存和加载机器学习模型。它目前支持tensorflow、pytorch、sklearn、spark和onnx,但xgboost、coreml等框架正在开发中。
保存tensorflow模型
importtensorflowastffromml2rtimportsave_tensorflow# train your model heresess=tf.Session()save_tensorflow(sess,path,output=['output'])
保存pytorch模型
# it has to be a torchscript graph made by tracing / scriptingfromml2rtimportsave_torchsave_torch(torch_script_graph,path)
在NX型号上保存
fromml2rtimportsave_onnxsave_onnx(onnx_model,path)
保存sklearn模型
fromml2rtimportsave_sklearnprototype=np.array(some_shape,dtype=some_dtype)# Equivalent to the input of the modelsave_sklearn(sklearn_model,path,prototype=prototype)# or# some_shape has to be a tuple and some_dtype has to be a np.dtype, np.dtype.type or str objectsave_sklearn(sklearn_model,path,shape=some_shape,dtype=some_dtype)# or# some_shape has to be a tuple and some_dtype has to be a np.dtype, np.dtype.type or str objectinital_types=utils.guess_onnx_tensortype(shape=shape,dtype=dtype)save_sklearn(sklearn_model,path,initial_types=initial_types)
保存SPARKML模型
fromml2rtimportsave_sparkmlprototype=np.array(some_shape,dtype=some_dtype)# Equivalent to the input of the modelsave_sparkml(spark_model,path,prototype=prototype)# or# some_shape has to be a tuple and some_dtype has to be a np.dtype, np.dtype.type or str objectsave_sparkml(spark_model,path,shape=some_shape,dtype=some_dtype)# or# some_shape has to be a tuple and some_dtype has to be a np.dtype, np.dtype.type or str objectinital_types=utils.guess_onnx_tensortype(shape=shape,dtype=dtype)save_sparkml(spark_model,path,initial_types=initial_types)
sklearn和sparkml模型将首先转换为onnx,然后保存到磁盘。这些模型可以使用onnxruntime、redisai等执行。onnx转换需要知道输入节点的类型,因此我们必须传递shape&dtype或原型,从中实用程序可以推断shape&dtype或转换实用程序可以理解的初始类型对象。像sparkml这样的框架允许用户拥有具有多个类型的异构输入。在这种情况下,请使用guess onnx_tensortypes并创建一个以上的初始类型,这些类型可以传递给save function as a list
加载模型和脚本
model=ml2rt.load_model(path)script=ml2rt.load_script(script)