如何优化我的BFS实施?

2024-04-25 21:58:46 发布

您现在位置:Python中文网/ 问答频道 /正文

我是一个python初学者,试图在cs50上解决一些项目。 我做了一个项目。 这是来自IMDB的一些电影数据的BFS实现。 我成功地解决了它。 但是我的代码对于一些非常远的输入要花费5分钟以上。对于图中的一些非常远的节点来说,是指非常远的节点。 这是我的密码。 忽略加载数据部分,重点关注最短路径方法

import csv
import sys


class Node():
    def __init__(self, state, parent, action):
        self.state = state
        self.parent = parent
        self.action = action
    def __str__(self):
        print(self.state)


# Maps names to a set of corresponding person_ids
names = {}

# Maps person_ids to a dictionary of: name, birth, movies (a set of movie_ids)
people = {}

# Maps movie_ids to a dictionary of: title, year, stars (a set of person_ids)
movies = {}


def load_data(directory):
    """
    Load data from CSV files into memory.
    """
    # Load people
    with open(f"{directory}/people.csv", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        for row in reader:
            people[row["id"]] = {
                "name": row["name"],
                "birth": row["birth"],
                "movies": set()
            }
            if row["name"].lower() not in names:
                names[row["name"].lower()] = {row["id"]}
            else:
                names[row["name"].lower()].add(row["id"])

    # Load movies
    with open(f"{directory}/movies.csv", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        for row in reader:
            movies[row["id"]] = {
                "title": row["title"],
                "year": row["year"],
                "stars": set()
            }

    # Load stars
    with open(f"{directory}/stars.csv", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        for row in reader:
            try:
                people[row["person_id"]]["movies"].add(row["movie_id"])
                movies[row["movie_id"]]["stars"].add(row["person_id"])
            except KeyError:
                pass


def main():
    if len(sys.argv) > 2:
        sys.exit("Usage: python degrees.py [directory]")
    directory = sys.argv[1] if len(sys.argv) == 2 else "large"

    # Load data from files into memory
    print("Loading data...")
    load_data(directory)
    print("Data loaded.")

    source = person_id_for_name(input("Name: "))
    if source is None:
        sys.exit("Person not found.")
    target = person_id_for_name(input("Name: "))
    if target is None:
        sys.exit("Person not found.")

    path = shortest_path(source, target)

    if path is None:
        print("Not connected.")
    else:
        degrees = len(path)
        print(f"{degrees} degrees of separation.")
        path = [(None, source)] + path
        for i in range(degrees):
            person1 = people[path[i][1]]["name"]
            person2 = people[path[i + 1][1]]["name"]
            movie = movies[path[i + 1][0]]["title"]
            print(f"{i + 1}: {person1} and {person2} starred in {movie}")


def shortest_path(source, target):
    """
    Returns the shortest list of (movie_id, person_id) pairs
    that connect the source to the target.

    If no possible path, returns None.
    """
    source = Node(source,None,None)
    queue = list(((Node(neighbor,source,None) for neighbor in neighbors_for_person(source.state) )))
    explored = list()
    path = list()
    while queue:
        current_node = queue.pop(0)
        if current_node.state not in explored:
            if current_node.state[1] == target:
                print('Goal reached')
                goal = current_node
                while goal:
                    path.append(goal.state)
                    goal = goal.parent
                return path[:-1][::-1]
            explored.append(current_node.state)
        queue.extend((Node(neighbor,current_node,None) for neighbor in neighbors_for_person(current_node.state[1])))





def person_id_for_name(name):
    """
    Returns the IMDB id for a person's name,
    resolving ambiguities as needed.
    """
    person_ids = list(names.get(name.lower(), set()))
    if len(person_ids) == 0:
        return None
    elif len(person_ids) > 1:
        print(f"Which '{name}'?")
        for person_id in person_ids:
            person = people[person_id]
            name = person["name"]
            birth = person["birth"]
            print(f"ID: {person_id}, Name: {name}, Birth: {birth}")
        try:
            person_id = input("Intended Person ID: ")
            if person_id in person_ids:
                return person_id
        except ValueError:
            pass
        return None
    else:
        return person_ids[0]


def neighbors_for_person(person_id):
    """
    Returns (movie_id, person_id) pairs for people
    who starred with a given person.
    """
    movie_ids = people[person_id]["movies"]
    neighbors = set()
    for movie_id in movie_ids:
        for person_id in movies[movie_id]["stars"]:
            neighbors.add((movie_id, person_id))
    return neighbors


if __name__ == "__main__":
    main()

我该怎么做才能让它跑得快


Tags: pathnameinnoneididssourcefor
1条回答
网友
1楼 · 发布于 2024-04-25 21:58:46

一些提示:

  • not in explored的时间复杂度为O(n)(其中nexplored的大小),因为explored是一个列表。而是将其设置为一个集合,以便not in explored以恒定时间执行:

    explored = set()
    # ...
    if current_node.state not in explored:
    # ...
    explored.add(current_node.state)
    
    
  • queue.pop(0)的时间复杂度为O(n)(其中nqueue的大小),因为queue是一个列表。相反,将其设为deque,以便queue.popleft()以恒定时间执行:

    from collections import deque
    queue = deque()
    # ...
    current_node = queue.popleft()
    # ...
    queue.extend((Node(neighbor,current_node,None) for neighbor in neighbors_for_person(current_node.state[1])))
    

相关问题 更多 >