单例Python生成器?或者,序列化Python生成器?

4 投票
6 回答
4691 浏览
提问于 2025-04-15 17:14

我正在使用以下代码,通过嵌套生成器来遍历一个文本文件,并利用 get_train_minibatch() 返回训练示例。我想要保存(序列化)这些生成器,以便能在文本文件中回到之前的位置。不过,生成器是不能被序列化的。

  • 有没有简单的方法可以让我保存当前位置,并从停止的地方继续?也许我可以把 get_train_example() 变成单例,这样就不会有多个生成器在那儿了。然后,我可以在这个模块里创建一个全局变量,来跟踪 get_train_example() 进行到哪儿了。

  • 你有没有更好的(更简洁的)建议,来让我能够保存这个生成器的状态?

[编辑:还有两个想法:

  • 我可以给生成器添加一个成员变量/方法吗?这样我就可以调用 generator.tell() 来找到文件的位置?因为这样,下次我创建生成器时,可以让它跳转到那个位置。这个想法听起来是所有想法中最简单的。

  • 我可以创建一个类,把文件位置作为一个成员变量,然后在类里创建生成器,并在每次生成时更新文件位置的成员变量吗?这样我就能知道它在文件中走了多远。

]

这是代码:

def get_train_example():
    for l in open(HYPERPARAMETERS["TRAIN_SENTENCES"]):
        prevwords = []
        for w in string.split(l):
            w = string.strip(w)
            id = None
            prevwords.append(wordmap.id(w))
            if len(prevwords) >= HYPERPARAMETERS["WINDOW_SIZE"]:
                yield prevwords[-HYPERPARAMETERS["WINDOW_SIZE"]:]

def get_train_minibatch():
    minibatch = []
    for e in get_train_example():
        minibatch.append(e)
        if len(minibatch) >= HYPERPARAMETERS["MINIBATCH SIZE"]:
            assert len(minibatch) == HYPERPARAMETERS["MINIBATCH SIZE"]
            yield minibatch
            minibatch = []

6 个回答

2

你可以创建一个标准的迭代器对象,不过它没有生成器那么方便。你需要在实例中保存迭代器的状态(这样它才能被序列化),并定义一个 next() 函数来返回下一个对象:

class TrainExampleIterator (object):
    def __init__(self):
        # set up internal state here
        pass
    def next(self):
        # return next item here
        pass

迭代器协议就是这么简单,只需要在一个对象上定义 .next() 方法,就可以把它用在 for 循环等地方。

在 Python 3 中,迭代器协议使用 __next__ 方法(这样更一致一些)。

2

下面的代码大致上可以满足你的需求。第一个类定义了一种像文件一样的东西,但它可以被“腌制”(也就是保存成一种特殊的格式)。当你把它“腌制”后再“解腌制”,它会重新打开文件,并回到你腌制时的位置。第二个类是一个迭代器,用来生成单词窗口。

class PickleableFile(object):
    def __init__(self, filename, mode='rb'):
        self.filename = filename
        self.mode = mode
        self.file = open(filename, mode)
    def __getstate__(self):
        state = dict(filename=self.filename, mode=self.mode,
                     closed=self.file.closed)
        if not self.file.closed:
            state['filepos'] = self.file.tell()
        return state
    def __setstate__(self, state):
        self.filename = state['filename']
        self.mode = state['mode']
        self.file = open(self.filename, self.mode)
        if state['closed']: self.file.close()
        else: self.file.seek(state['filepos'])
    def __getattr__(self, attr):
        return getattr(self.file, attr)

class WordWindowReader:
    def __init__(self, filenames, window_size):
        self.filenames = filenames
        self.window_size = window_size
        self.filenum = 0
        self.stream = None
        self.filepos = 0
        self.prevwords = []
        self.current_line = []

    def __iter__(self):
        return self

    def next(self):
        # Read through files until we have a non-empty current line.
        while not self.current_line:
            if self.stream is None:
                if self.filenum >= len(self.filenames):
                    raise StopIteration
                else:
                    self.stream = PickleableFile(self.filenames[self.filenum])
                    self.stream.seek(self.filepos)
                    self.prevwords = []
            line = self.stream.readline()
            self.filepos = self.stream.tell()
            if line == '':
                # End of file.
                self.stream = None
                self.filenum += 1
                self.filepos = 0
            else:
                # Reverse line so we can pop off words.
                self.current_line = line.split()[::-1]

        # Get the first word of the current line, and add it to
        # prevwords.  Truncate prevwords when necessary.
        word = self.current_line.pop()
        self.prevwords.append(word)
        if len(self.prevwords) > self.window_size:
            self.prevwords = self.prevwords[-self.window_size:]

        # If we have enough words, then return a word window;
        # otherwise, go on to the next word.
        if len(self.prevwords) == self.window_size:
            return self.prevwords
        else:
            return self.next()

撰写回答