如何使该方法返回一组Point数据类?

0 投票
1 回答
15 浏览
提问于 2025-04-14 17:46

我希望我的 visited() 方法返回一个包含 Point 的集合,其中 Point 是一个数据类,里面有三个属性,分别是 x、y 和 z。我写的代码试图实现这个功能,但出现了这个错误:TypeError: unhashable type: 'Point'。我该如何解决这个错误?把 Point 转换成一个元组,提取它的各个元素,这样做是否符合 visited() 应该返回一个 set[Point] 的严格类型检查规则?

这是我的代码:

from dataclasses import dataclass

@dataclass
class Point:
  x:int
  y:int
  z:int
  

#regular class
class Movements:
  def __init__(self, p: Point):
    self.m = p  
    all_points_visited = set()
    all_points_visited.add(Point(self.m.x, self.m.y, self.m.z))
    
  def move_to(self, dir: str) -> None:
    if dir[1] == "x":
      if dir[0] == "-":
        self.m.x -= 1
        
      elif dir[0] == "+":
        self.m.x += 1
    
    elif dir[1] == "y":
      if dir[0] == "-":
        self.m.y -= 1
        
      elif dir[0] == "+":
        self.m.y += 1
    
    elif dir[1] == "z":
      if dir[0] == "-":
        self.m.z -= 1
        
      elif dir[0] == "+":
        self.m.z += 1
        
        
    all_points_visited.add(Point(self.m.x, self.m.y, self.m.z))
        
        
  def teleport_to(self, point: Point) -> None:
    self.m.x = point.x
    self.m.y = point.y 
    self.m.z = point.z  
    
    all_points_visited.add(Point(self.m.x, self.m.y, self.m.z))
    
  def visited() -> set[Point]:
    return all_points_visited
    
  def visited_count() -> int:
    return len(all_points_visited)


#test
m = Movements(Point(10, -30, 0))

1 个回答

1

把你的 Point 变成不可变的,只需设置 frozen=True,这样它们就可以被哈希了:

import dataclasses
from dataclasses import dataclass


@dataclass(frozen=True)
class Point:
    x: int
    y: int
    z: int

    def move_to(self, dir: str):
        x, y, z = self.x, self.y, self.z
        if dir == "+x":
            x += 1
        elif dir == "-x":
            x -= 1
        elif dir == "+y":
            y += 1
        elif dir == "-y":
            y -= 1
        elif dir == "+z":
            z += 1
        elif dir == "-z":
            z -= 1
        else:
            raise NotImplementedError("Unknown direction")
        return dataclasses.replace(self, x=x, y=y, z=z)


# regular class
class Movements:
    def __init__(self, p: Point):
        self.current = p
        self.all_points_visited = set()
        self.all_points_visited.add(p)

    def move_to(self, dir: str) -> None:
        self.teleport_to(self.current.move_to(dir))

    def teleport_to(self, point: Point) -> None:
        self.current = point
        self.all_points_visited.add(point)

    def visited(self) -> set[Point]:
        return self.all_points_visited

    def visited_count(self) -> int:
        return len(self.all_points_visited)


# test
point = Point(10, -30, 0)
m = Movements(point)
m.move_to("+x")
m.teleport_to(Point(11, 30, 0))
print(m.visited())

撰写回答