我正在尝试执行一个函数,该函数是我在另一个脚本中导入并实例化的类的方法-但是,终端返回以下错误
以下是我的代码流程(请忽略缩进错误)。。。你知道吗
我创建了DataProcess类,其中包含一个在第一个文件中加载数据的函数
class dataProcess(object):
def __init__(self, out_rows, out_cols, data_path = "./data/train/img", ... img_type = "png"):
def load_data(self):
mydata = dataProcess(self.img_rows, self.img_cols)
imgs_train, imgs_mask_skulls = mydata.load_train_data()
imgs_test = mydata.load_test_data()
return imgs_train, imgs_mask_skulls, imgs_test
然后在另一个文件中,我尝试实例化这个类并调用load\u data函数。你知道吗
from dataProcess import *
from dataPreperation import *
from myUnet import *
class runUnet():
def __init__(self, img_rows=img_rows, img_cols=img_cols):
self.img_rows = img_rows # set values for these as default in definition arguments or as shape of input data
self.img_cols = img_cols
def train_and_predict(self):
print("loading data")
mydata = dataProcess(self.img_rows, self.img_cols)
imgs_train, imgs_mask_skulls, imgs_test = mydata.load_data()
print("loading data done")
myUnet = myUnet(self.img_rows, self.img_cols)
model = myUnet.get_unet()
print("got unet")
model_checkpoint = ModelCheckpoint('unet.hdf5',
monitor='loss',
verbose=1,
save_best_only=True)
print('Fitting model...')
model.fit(imgs_train, imgs_mask_skulls,
batch_size=4,
nb_epoch=2,
verbose=1,
validation_split=0.2, # validation_split vs validation_data
shuffle=True,
callbacks=[model_checkpoint])
print('predict test data')
imgs_mask_test = model.predict(imgs_test, batch_size=1, verbose=1)
print(imgs_mask_test.shape())
print(imgs_mask_test)
np.save("./results/imgs_mask_test.npy", imgs_mask_test)
目前没有回答
相关问题 更多 >
编程相关推荐