回溯找到元素加起来小于K的元素向量

2024-04-19 08:16:27 发布

您现在位置:Python中文网/ 问答频道 /正文

我对下面的问题感兴趣主要是为了获得关于回溯算法的直觉,所以我不是在寻找不使用回溯的替代解决方案。你知道吗

问题:找到所有n个元素向量,使它们的元素之和小于或等于某个数字K。向量中的每个元素都是整数。

示例:如果n=3,K=10,则[9,0,0]和[5,0,5]是解,而[3,1,8]不是。你知道吗

this site开始,我修改了python代码,试图实现一个解决方案。你知道吗

以下是一般的“回溯引擎”功能:

def solve(values, safe_up_to, size):

    solution = [None] * size

    def extend_solution(position):
        for value in values:
            solution[position] = value
            if safe_up_to(solution, position):
                if position >= size-1 or extend_solution(position+1):
                    return solution
        return None

    return extend_solution(0)

下面是检查解决方案是否“安全”的函数:

def safe_up_to(partial_solution, target = 100): 
   partial_solution = np.array(partial_solution)  # convert to np array 

   # replace None with NaN
   partial_solution = np.where(partial_solution == None, np.nan, partial_solution)

   if np.nansum(partial_solution) <= target: 
       return True
   else: 
       return False 

然而,当我同时运行这两个函数时,我只得到一个全为零的向量。你知道吗

solve(values=range(10), safe_up_to=safe_up_to, size=5)

我应该如何修改此代码以获得所有可行的解决方案?你知道吗


Tags: tonone元素sizereturndefnpposition
1条回答
网友
1楼 · 发布于 2024-04-19 08:16:27

下面是您的代码的一个稍微修改过的版本。我试着让它尽可能少的改变:

import numpy as np
from functools import partial

def solve(values, safe_up_to, size):

    solution = [None] * size

    def extend_solution(position):
        for value in values:
            solution[position] = value
            if safe_up_to(solution):
                if position >= size-1:
                    yield np.array(solution)
                else:
                    yield from extend_solution(position+1)
        solution[position] = None

    return extend_solution(0)

def safe_up_to(target, partial_solution): 
   partial_solution = np.array(partial_solution)  # convert to np array 

   # replace None with NaN
   partial_solution = np.where(partial_solution == None, np.nan, partial_solution)

   if np.nansum(partial_solution) <= target: 
       return True
   else: 
       return False 

for sol in solve(values=range(10), safe_up_to=partial(safe_up_to,4), size=2):
    print(sol,sol.sum())

印刷品:

[0 0] 0
[0 1] 1
[0 2] 2
[0 3] 3
[0 4] 4
[1 0] 1
[1 1] 2
[1 2] 3
[1 3] 4
[2 0] 2
[2 1] 3
[2 2] 4
[3 0] 3
[3 1] 4
[4 0] 4

相关问题 更多 >