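Recursion: how to avoid Python's "Set changed size during iteration" RuntimeError
Background and problem description:
I have some code that solves the graph coloring problem (broadly, assigning a "color" to every vertex of an undirected graph so that no two adjacent vertices share a color). I am trying to use constraint propagation to make a standard recursive backtracking algorithm more efficient, but I ran into the following error: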
  File "C:\Users\danisg\Desktop\coloring\Solver.py", line 99, in solve
    for color in self.domains[var]:
RuntimeError: Set changed size during iteration
For each vertex, I keep a set of the values it can still take:
self.domains = { var: set(self.colors) for var in self.vars }
After making an assignment, I propagate the constraint to the adjacent vertices to narrow the search space:
for key in node.neighbors:  # list of keys corresponding to adjacent vertices
    if color in self.domains[key]:  # remove now to prune possible choices
        self.domains[key].remove(color)
This is not where the error is actually raised (in my code I pinpointed the failing spot with a try-except block), but it may be the root cause.
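For illustration, the pruning step above can be written as a standalone helper that also records which domains it changed, so the pruning can be undone when the search backtracks. This is only a sketch under assumptions: `prune_neighbors`, `undo_prune`, and the sample `domains` dictionary are hypothetical names and data, not part of the original solver.

```python
def prune_neighbors(domains, neighbors, color):
    """Remove `color` from each neighbor's domain; return the keys actually changed."""
    removed = []
    for key in neighbors:
        if color in domains[key]:
            domains[key].remove(color)
            removed.append(key)
    return removed


def undo_prune(domains, removed, color):
    """Put `color` back into every domain that prune_neighbors changed."""
    for key in removed:
        domains[key].add(color)


# Hypothetical sample data: vertex 2 has already lost color 0.
domains = {0: {0, 1, 2}, 1: {0, 1, 2}, 2: {1, 2}}
removed = prune_neighbors(domains, neighbors=[1, 2], color=0)
pruned_state = {key: set(values) for key, values in domains.items()}
undo_prune(domains, removed, color=0)
```

Recording only the changed keys keeps the undo step proportional to the number of neighbors, instead of copying every domain before each assignment.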
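My question:
Is this the right approach? If the implementation is wrong, how should I fix it? Also, is it necessary to keep a separate domains dictionary, or could domain be an attribute of each node in the graph?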
My code:
Here is the solve function that calls this code:
def solve(self):
    uncolored = [var for var in self.vars if self.map[var].color is None]
    if not uncolored:
        return True
    # choose the uncolored variable with the fewest remaining values
    var = min(uncolored, key=lambda x: len(self.domains[x]))
    node = self.map[var]
    # snapshot of every domain, used to undo the pruning on backtrack
    old = {v: set(self.domains[v]) for v in self.vars}
    for color in self.domains[var]:  # line 99 in the traceback
        if not self._valid(var, color):
            continue
        self.map[var].color = color
        for key in node.neighbors:
            if color in self.domains[key]:
                self.domains[key].remove(color)
        try:
            if self.solve():
                return True
        except:
            print('happening now')
        self.map[var].color = None
        self.domains = old
    return False
My implementation uses a Node object:
class Solver:
    class Node:
        def __init__(self, var, neighbors, color = None, domain = set()):
            self.var = var
            self.neighbors = neighbors
            self.color = color
            self.domain = domain
        def __str__(self):
            return str((self.var, self.color))
    def __init__(self, graph, K):
        self.vars = sorted( graph.keys(), key = lambda x: len(graph[x]), reverse = True ) # sort by number of links; start with most constrained
        self.colors = range(K)
        self.map = { var: self.Node(var, graph[var]) for var in self.vars }
        self.domains = { var: set(self.colors) for var in self.vars }
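One caveat if domain does become a node attribute: the Node constructor above declares `domain = set()` as a default argument. Python evaluates a default value once, when the function is defined, so every node created without an explicit domain would share a single set. Below is a minimal sketch of the pitfall and a common workaround; the class names `SharedDefaultNode` and `OwnDomainNode` are hypothetical and exist only for this illustration.

```python
class SharedDefaultNode:
    # Pitfall: the default set is created once, when the function is defined.
    def __init__(self, var, domain=set()):
        self.var = var
        self.domain = domain


class OwnDomainNode:
    # Using None as the sentinel gives every node its own fresh set.
    def __init__(self, var, domain=None):
        self.var = var
        self.domain = set() if domain is None else set(domain)


a, b = SharedDefaultNode("a"), SharedDefaultNode("b")
a.domain.add(0)
shared = a.domain is b.domain  # both nodes refer to the same set object

c, d = OwnDomainNode("c", range(3)), OwnDomainNode("d", range(3))
c.domain.discard(0)  # only c's domain shrinks
```

With per-node sets built this way, pruning one node's domain cannot leak into another node's domain.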
Two other helper functions that are useful here:
def validate(self):
    for var in self.vars:
        node = self.map[var]
        for key in node.neighbors:
            if node.color == self.map[key].color:
                return False
    return True
def _valid(self, var, color):
    node = self.map[var]
    for key in node.neighbors:
        if self.map[key].color is None:
            continue
        if self.map[key].color == color:
            return False
    return True
Data and example where the code fails:
The example graph I am using can be found here.
The function that reads the data:
def read_and_make_graph(input_data):
    lines = input_data.split('\n')
    first_line = lines[0].split()
    node_count = int(first_line[0])
    edge_count = int(first_line[1])
    graph = {}
    for i in range(1, edge_count + 1):
        line = lines[i]
        parts = line.split()
        node, edge = int(parts[0]), int(parts[1])
        if node in graph:
            graph[node].add(edge)
        if edge in graph:
            graph[edge].add(node)
        if node not in graph:
            graph[node] = {edge}
        if edge not in graph:
            graph[edge] = {node}
    return graph
It should be called like this:
file_location = 'C:\\Users\\danisg\\Desktop\\coloring\\data\\gc_50_3'
input_data_file = open(file_location, 'r')
input_data = ''.join(input_data_file.readlines())
input_data_file.close()
graph = read_and_make_graph(input_data)
solver = Solver(graph, 6) # a 6 coloring IS possible
print(solver.solve()) # True if we solved; False if we didn't
2 Answers
1
Use the copy() method of the set object to solve this problem.
class Door():
    def __init__(self, id):
        self.id = id

if __name__ == '__main__':
    cache_door = set()
    cache_door.add(Door(1))
    cache_door.add(Door(2))
    cache_door.add(Door(3))
    cache_door.add(Door(4))
    print(cache_door)
    # iterate over a copy so the original set can be modified safely
    for door in cache_door.copy():
        if door.id == 1:
            cache_door.remove(door)
    print(cache_door)
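As a complementary sketch (not part of the answer above): when the goal is only to filter a set, a snapshot made with `list()` works like `copy()`, and a set comprehension avoids mutating during iteration altogether. The sets `doors` and `more_doors` are hypothetical sample data.

```python
doors = {1, 2, 3, 4}

# Option 1: iterate over a snapshot and mutate the original set.
for door in list(doors):
    if door % 2 == 0:
        doors.remove(door)

# Option 2: build the filtered set in one step, with no mutation during iteration.
more_doors = {1, 2, 3, 4}
more_doors = {door for door in more_doors if door % 2 != 0}
```

Option 2 rebinds the name to a new set, so it only fits when no other code holds a reference to the old set object.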
70
I think the problem is here:
for color in self.domains[var]:
    if not self._valid(var, color):
        continue
    self.map[var].color = color
    for key in node.neighbors:
        if color in self.domains[key]:
            self.domains[key].remove(color)  # This is potentially bad.
When you call self.domains[key].remove(color) with key == var, you change the size of the set you are currently iterating over. To avoid this, you can use
for color in self.domains[var].copy():
Iterating over copy() lets you remove items from the original set while you loop over the copy.
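In the code from the question a vertex is normally not its own neighbor, so the set may also be changed indirectly: a deeper recursive call can prune the very domain that an outer call is still iterating. The following is a minimal sketch of that situation, not the original solver; the function `prune` and the sample `domain` set are hypothetical stand-ins.

```python
def prune(domain, color):
    """Stand-in for a deeper recursive call that prunes a shared domain."""
    domain.discard(color)


# Iterating the live set: the nested call changes its size mid-loop.
domain = {0, 1, 2}
error = None
try:
    for color in domain:
        prune(domain, (color + 1) % 3)
except RuntimeError as exc:
    error = str(exc)

# Iterating a copy: the nested call can prune the original freely.
domain = {0, 1, 2}
visited = []
for color in domain.copy():
    visited.append(color)
    prune(domain, (color + 1) % 3)
```

In the first loop the iterator notices the size change on its next step and raises; in the second loop every value of the snapshot is visited, even though the original set ends up empty.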