寻找最佳剪裁圆圈
给定一个 NxN
的整数格子,我想找到一个被裁剪的圆,使得它内部的格点值的总和最大。
每个格点 (i,j)
都有一个值 V(i,j)
,这些值存储在下面的矩阵 V
中:
[[ 1, 1, -3, 0, 0, 3, -1, 3, -3, 2],
[-2, -1, 0, 1, 0, -2, 0, 0, 1, -3],
[ 2, 2, -3, 2, -2, -1, 2, 2, -2, 0],
[-2, 0, -3, 3, 0, 2, -1, 1, 3, 3],
[-1, -2, -1, 2, 3, 3, -3, -3, 2, 0],
[-3, 3, 2, 0, -3, -2, -1, -3, 0, -3],
[ 3, 2, 2, -1, 0, -3, 1, 1, -2, 2],
[-3, 1, 3, 3, 0, -3, -3, 2, -2, 1],
[ 0, -3, 0, 3, 2, -2, 3, -2, 3, 3],
[-1, 3, -3, -2, 0, -1, -2, -1, -1, 2]]
我们的目标是最大化在一个半径为 R
的(裁剪的)圆内和边界上的格点值 V(i,j)
的总和,前提和条件如下:
- 圆心在 (0,0) 这个点
- 圆的半径可以是任何正数(不一定是整数半径,可以是有理数)。
- 圆可能会在两个格点处被裁剪,形成一条对角线,如图所示。这条对角线的斜率是 -45 度。
一些额外的细节:
裁剪圆的得分是所有在圆内(或在边界上)并且在包括 (0,0) 的对角线一侧的整数值的总和。边界上的值(或附近的值)有 -3, 1, 3, -1, -3, 3, -1, 2, 0, 3。
虽然圆的半径可以是任何值,但我们只需要考虑与格点精确相交的圆,因此有 n^2 种不同的相关半径。此外,我们只需记录圆与对角线相交的一个位置,就可以完全确定这个裁剪的圆。注意,这个与对角线的交点不需要是整数坐标。
如果最优解的圆根本没有被对角线裁剪,那么我们只需要返回圆的半径。
到目前为止我发现的:
如果我们只想找到最优的圆,可以通过以下方法快速完成,时间与输入大小成正比:
import numpy as np
from math import sqrt
np.random.seed(40)
def find_max(A):
n = A.shape[0]
sum_dist = np.zeros(2 * n * n, dtype=np.int32)
for i in range(n):
for j in range(n):
dist = i**2 + j**2
sum_dist[dist] += A[i, j]
cusum = np.cumsum(sum_dist)
# returns optimal radius with its score
return sqrt(np.argmax(cusum)), np.max(cusum)
A = np.random.randint(-3, 4, (10, 10))
print(find_max(A))
那么,找到最优裁剪圆的速度有多快呢?
2 个回答
显而易见的解决办法是 O(N^2 log N)
,也就是说我们需要计算从 (0, 0)
到每个网格点的距离,然后把这些距离按从小到大的顺序排列,再遍历这个排好序的列表,记录累积的总和,最后保存最大的总和(这就像是逐渐“扩展”一个圆圈)。蓝色的对角线只是一个额外的限制条件。
不过,我有一种强烈的感觉,这个问题可以在 O(N^2)
的时间内解决,而不需要排序这一步(排序会增加 log N
的复杂度)。从 (0,0)
出发,应该有一种明显的遍历顺序,可以让我们按照距离原点的 L2 距离来访问网格点,而且是线性时间完成,而不需要排序。
首先,创建一个累积频率表,或者叫做芬威克树。你需要为每个圆的半径记录一个数据,值对应于从原点到这个距离的探索权重。接下来,从原点开始进行广度优先搜索(BFS)。
对于每一个对角线的“边界”,你需要用半径和权重的键值对来更新你的表或树(把权重加到已有的值上)。然后,你还需要查询表或树,获取刚刚添加的每个半径的当前累积和,注意最大值,并相应地更新一个全局的最大值。
一旦你的搜索结束,你就会得到被裁剪圆的最大和。如果你想重建这个圆,只需存储最大半径和BFS的深度,以及全局最大和本身。
这样做的时间复杂度是 O(N^2 log N)
,因为会有 N^2 次更新和查询,每次都是 O(log N)
。
这个解决方案的直觉是,通过沿着这个对角线的“边界”向外探索,你隐式地裁剪了所有你查询的圆,因为它右边和上面的权重还没有被添加。通过计算每个搜索深度刚刚更新的半径的最大值,你也确保了这些圆在裁剪线与整数坐标相交。
更新
这里有一段Python代码展示了这个过程。代码需要整理一下,但至少它展示了这个过程。我选择使用累积频率/最大数组,而不是树,因为这可能更适合用numpy进行向量化处理。
def solve(matrix):
n = len(matrix)
max_radius_sqr = 2 * (n - 1) ** 2
num_bins = max_radius_sqr.bit_length() + 1
frontier = [(0, 0)]
csum_arr = [[0] * 2 ** i for i in range(num_bins)[::-1]]
cmax_arr = [[0] * 2 ** i for i in range(num_bins)[::-1]]
max_csum = -float("inf")
max_csum_depth = None
max_csum_radius_sqr = None
depth = 0
while frontier:
next_frontier = []
if depth + 1 < n: # BFS up
next_frontier.append((0, depth + 1))
# explore frontier, updating csums and maximums per each
for x, y in frontier:
if x + 1 < n: # BFS right
next_frontier.append((x + 1, y))
index = x ** 2 + y ** 2 # index is initially the radius squared
for i in range(num_bins):
csum_arr[i][index] += matrix[y][x] # update csums
if i != 0: # skip first, since no children to take max of
sum_left = csum_arr[i-1][index << 1] # left/right is tree notation of the array
max_left = cmax_arr[i-1][index << 1]
max_right = cmax_arr[i-1][index << 1 | 1]
cmax_arr[i][index] = max(max_left, sum_left + max_right) # update csum maximums
index >>= 1 # shift off last bit, update sums/maxs again, log2 times
# after entire frontier is explored, query for overall max csum over all radii
# update running global max and associated values
if cmax_arr[-1][0] > max_csum:
max_csum = cmax_arr[-1][0]
max_csum_depth = depth
index = 0
for i in range(num_bins-1)[::-1]: # reconstruct max radius (this could just as well be stored)
sum_left = csum_arr[i][index << 1]
max_left = cmax_arr[i][index << 1]
max_right = cmax_arr[i][index << 1 | 1]
index <<= 1
if sum_left + max_right > max_left:
index |= 1
max_csum_radius_sqr = index
depth += 1
frontier = next_frontier
# total max sum, dx + dy of diagonal cut, radius ** 2
return max_csum, max_csum_depth, max_csum_radius_sqr
用给定的测试案例调用这个代码会产生预期的输出:
matrix = [
[-1, 3, -3, -2, 0, -1, -2, -1, -1, 2],
[ 0, -3, 0, 3, 2, -2, 3, -2, 3, 3],
[-3, 1, 3, 3, 0, -3, -3, 2, -2, 1],
[ 3, 2, 2, -1, 0, -3, 1, 1, -2, 2],
[-3, 3, 2, 0, -3, -2, -1, -3, 0, -3],
[-1, -2, -1, 2, 3, 3, -3, -3, 2, 0],
[-2, 0, -3, 3, 0, 2, -1, 1, 3, 3],
[ 2, 2, -3, 2, -2, -1, 2, 2, -2, 0],
[-2, -1, 0, 1, 0, -2, 0, 0, 1, -3],
[ 1, 1, -3, 0, 0, 3, -1, 3, -3, 2],
][::-1]
print(solve(matrix))
# output: 13 9 54
换句话说,它表示最大总和是 13
,对角线切割的偏移量(dx + dy)是 9
,半径的平方是 54
。
如果今晚或这个周末有时间,我会把代码整理得更好一些。