我有一个简单的单元测试,我检查是否可以用稍微不同的参数实例化我的Tensorflow类。这似乎是@pytest.mark.parametrize
的一个很好的用例。在
但是,我发现如果我的单元测试是tf.test.TestCase
的方法,那么{
例如,当我对以下代码运行pytest
时:
class TestBasicRewardNet(tf.test.TestCase):
@pytest.mark.parametrize("env", ['FrozenLake-v0', 'CartPole-v1',
'CarRacing-v0', 'LunarLander-v2'])
def test_init_no_crash(self, env):
for i in range(3):
x = BasicRewardNet(env)
我得到错误TypeError: test_init_no_crash() missing 1 required positional argument: 'env'
。在
为了解决这个问题,我试图去掉类包装器,但这使我错过了一些自动的Tensorflow测试初始化。特别是,现在每个BasicRewardNet
都构建在同一张TensorFlow图中,因此我需要做一些事情,比如添加一个变量范围来避免
冲突。在这个变量范围中添加似乎很难。在
我想知道有没有人知道一种方法可以让我两全其美?我希望能够使用parametrize
,同时获得{
目前没有回答
相关问题 更多 >
编程相关推荐