火把模型的包装纸
torchwrapper的Python项目详细描述
火炬包装器
使用fit和predict函数的pytorhc模型的包装类 使用Keras和Sklearn的人很熟悉
减少了为基本模型编写拟合和评估函数的需要。
快速启动
# import the modulefromtorchwrapperimportWrapper# create your module, optimizer, and criterion functionmodel=Model()optimizer=torch.optim.Adam(model.parameters())criterion=torch.nn.MSELos()# wrap the modelmodel=Wrapper(model)# train the networkmodel.fit(dataloader,optimizer,criterion,epochs=50)
使用经过训练的模型,您可以使用pytorch数据加载器进行预测:
preds=model.predict(dataloader)
这将返回一个预测的numpy数组。