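Counting pairs of non-overlapping intervals
I recently took part in a coding challenge where the task was to count the unique pairs of non-overlapping intervals, given a list of start points and a list of end points. I came up with an O(n^2) solution that removes duplicates by storing each pair of (start, end) intervals in a set. Is there a more efficient way to solve this, or is this the best I can do?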
def paperCuttings(starting, ending):
    # Pair each start with its corresponding end and sort by end point
    intervals = sorted(zip(starting, ending), key=lambda x: x[1])
    non_overlaps = set()
    print(intervals)
    # Store valid combinations
    for i in range(len(intervals)):
        for j in range(i+1, len(intervals)):
            # If the ending of the first is less than the starting of the second, they do not overlap
            if intervals[i][1] < intervals[j][0]:
                non_overlaps.add((intervals[i], intervals[j]))
    return len(non_overlaps)

starting = [1,1,6,7]
ending = [5,3,8,10]
print(paperCuttings(starting, ending)) # should return 4

starting2 = [3,1,2,8,8]
ending2 = [5, 3, 7, 10, 10]
print(paperCuttings(starting2, ending2)) # should return 3
I ask because my code timed out on some of the hidden test cases.
2 Answers
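Idea: we can process the arrays in O(max(range, n)) time, where range is the span from the smallest start point to the largest end point. For each interval's start point, we count the intervals that have already ended. Whether intervals that meet at a single point count as overlapping changes only two lines of code, so I added a parameter to control it.
If the range is much larger than n, this approach isn't worthwhile. But if the range is around n log n or smaller, it may be useful.
Ruby code. Note: this does not remove duplicate intervals; see Cary's answer for that (sorting isn't needed), or use any other linear-time method.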
# When count_single_point_overlaps is true, two intervals that only touch
# at a single point are still counted as a good (non-overlapping) pair.
def count_non_overlapping_intervals(start_points, end_points, count_single_point_overlaps=true)
  raise "Arrays must be of equal length" unless start_points.length == end_points.length

  n = start_points.length
  range = end_points.max - start_points.min + 1

  # Number of intervals starting / ending at each point
  startpoint_to_count = Hash.new(0)
  endpoint_to_count = Hash.new(0)

  start_points.each { |point| startpoint_to_count[point] += 1 }
  end_points.each { |point| endpoint_to_count[point] += 1 }

  good_pairs = 0
  cum_starts = 0 # cumulative range starts
  cum_ends = 0 # cumulative range ends
  (start_points.min..end_points.max).each do |i|
    new_starts = startpoint_to_count[i] || 0
    new_ends = endpoint_to_count[i] || 0
    cum_starts += new_starts
    cum_ends += new_ends if count_single_point_overlaps
    # Every interval starting here pairs with every interval already ended
    good_pairs += new_starts * cum_ends
    cum_ends += new_ends unless count_single_point_overlaps
  end

  return good_pairs
end

# Example
start_points = [3, 1, 2, 8]
end_points = [5, 3, 7, 10]

> count_non_overlapping_intervals(start_points, end_points, true)
=> 4
> count_non_overlapping_intervals(start_points, end_points, false)
=> 3
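For the example above, name the intervals a=[1,3], b=[3,5], c=[2,7], d=[8,10].
The valid pairs are then ad, bd, cd, and possibly ab, depending on whether intervals that share a single point count as overlapping.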
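Since the question is in Python, here is a rough Python sketch of the same counting sweep. It is not the answer's code: the function name `count_disjoint_pairs_sweep` and the flag `touching_ok` are made up for illustration, it assumes integer endpoints with start <= end, and unlike the Ruby above it removes duplicate intervals first with a set.

```python
from collections import Counter


def count_disjoint_pairs_sweep(starting, ending, touching_ok=False):
    """Count unordered pairs of distinct intervals that do not overlap.

    Hypothetical helper: assumes integer endpoints and start <= end.
    With touching_ok=True, intervals sharing a single point count as a pair.
    """
    unique = set(zip(starting, ending))  # drop duplicate intervals
    if not unique:
        return 0

    starts_at = Counter(s for s, _ in unique)  # intervals starting at each point
    ends_at = Counter(e for _, e in unique)    # intervals ending at each point

    good_pairs = 0
    ended = 0  # intervals that have ended so far
    for point in range(min(starts_at), max(ends_at) + 1):
        if touching_ok:
            ended += ends_at[point]
        # each interval starting here pairs with every interval already ended
        good_pairs += starts_at[point] * ended
        if not touching_ok:
            ended += ends_at[point]

    if touching_ok:
        # a zero-length interval [p, p] would otherwise be paired with itself
        good_pairs -= sum(1 for s, e in unique if s == e)
    return good_pairs


print(count_disjoint_pairs_sweep([3, 1, 2, 8], [5, 3, 7, 10], touching_ok=True))   # 4
print(count_disjoint_pairs_sweep([3, 1, 2, 8], [5, 3, 7, 10], touching_ok=False))  # 3
```

The loop still walks every integer between the smallest start and the largest end, so the caveat about a range much larger than n applies to this sketch as well.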
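Here is an O(n*log n) solution in Ruby (n being the number of intervals). I'll give a detailed explanation so that the code can easily be converted to Python.
I assume that non-overlapping intervals have no points in common, not even an endpoint.¹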
def paperCuttings(starting, ending)
  # Compute an array of unique intervals sorted by the beginning
  # of each interval
  intervals = starting.zip(ending).uniq.sort
  n = intervals.size
  count = 0
  # Loop over the indices of all but the last interval.
  # The interval at index i of intervals will be referred to
  # below as "interval i"
  (0..n-2).each do |i|
    # intervals[i] is interval i, an array containing its two
    # endpoints. Extract the second endpoint to the variable endpoint
    _, endpoint = intervals[i]
    # Employ a binary search to find the index of the first
    # interval j > i for which intervals[j].first > endpoint,
    # where intervals[j].first is the beginning of interval j
    k = (i+1..n-1).bsearch { |j| intervals[j].first > endpoint }
    # k equals nil if no such interval is found, in which case
    # continue the loop with the next interval i
    next if k.nil?
    # As intervals i and k are non-overlapping, interval i is
    # non-overlapping with all intervals l, k <= l <= n-1, of which
    # there are n-k, so add n-k to count
    count += n - k
  end
  # return count
  count
end
Try it.
starting = [1, 1, 6, 7]
ending = [5, 3, 8, 10]
paperCuttings(starting, ending)
#=> 4
starting = [3, 1, 2, 8, 8]
ending = [5, 3, 7, 10, 10]
paperCuttings(starting, ending)
#=> 3
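For readers who want the Python version, one possible port is sketched below. It is my reading of the Ruby code, not the answerer's own: the name `paper_cuttings` is illustrative, and the standard-library `bisect_right` stands in for Ruby's `bsearch`, which needs a separate list of start points to search in.

```python
from bisect import bisect_right


def paper_cuttings(starting, ending):
    """Count unique pairs of intervals that share no point (endpoints included)."""
    # unique intervals, sorted by start (ties broken by end)
    intervals = sorted(set(zip(starting, ending)))
    starts = [s for s, _ in intervals]
    n = len(intervals)

    count = 0
    for i, (_, endpoint) in enumerate(intervals):
        # index of the first interval after i whose start is > endpoint;
        # equals n when there is none, so n - k is then 0
        k = bisect_right(starts, endpoint, i + 1)
        count += n - k
    return count


print(paper_cuttings([1, 1, 6, 7], [5, 3, 8, 10]))         # 4
print(paper_cuttings([3, 1, 2, 8, 8], [5, 3, 7, 10, 10]))  # 3
```

Because `bisect_right` returns `n` rather than `nil` when nothing qualifies, the explicit skip from the Ruby version is not needed here.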
The steps of the calculation are explained next.
intervals = starting.zip(ending).uniq.sort
For example, with
starting = [3, 1, 2, 8, 8]
ending = [5, 3, 7, 10, 10]
a = starting.zip(ending)
#=> [[3, 5], [1, 3], [2, 7], [8, 10], [8, 10]]
b = a.uniq
#=> [[3, 5], [1, 3], [2, 7], [8, 10]]
b.sort
#=> [[1, 3], [2, 7], [3, 5], [8, 10]]
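Removing duplicates is required by the problem statement.
The elements of b are sorted by their first elements. When two arrays have the same first element, the second element breaks the tie, though that doesn't matter here.
See the documentation for Ruby's Range#bsearch, the binary search method used here. Each binary search takes O(log n) time, which is where the log factor in the overall O(n*log n) time complexity comes from (the initial sort is O(n*log n) as well).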
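To make the per-interval step concrete, the following Python sketch traces the search on the example input. The helper `trace_pair_counts` is hypothetical and exists only to show, for each interval i, the index k that the binary search finds and how many pairs it contributes; `bisect_right` is used in place of Ruby's `bsearch`.

```python
from bisect import bisect_right


def trace_pair_counts(starting, ending):
    """Return (interval, k, pairs_added) for every interval except the last."""
    intervals = sorted(set(zip(starting, ending)))
    starts = [s for s, _ in intervals]
    n = len(intervals)

    rows = []
    for i, (_, endpoint) in enumerate(intervals[:-1]):
        k = bisect_right(starts, endpoint, i + 1)
        # report None when no later interval starts after endpoint (Ruby's nil)
        rows.append((intervals[i], k if k < n else None, n - k))
    return rows


for interval, k, added in trace_pair_counts([3, 1, 2, 8, 8], [5, 3, 7, 10, 10]):
    print(interval, k, added)
# (1, 3) 3 1
# (2, 7) 3 1
# (3, 5) 3 1
```

Each of the first three intervals finds k = 3, the index of [8, 10], and adds n - k = 1 pair, giving the total of 3.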
1. If intervals that share only an endpoint are to be regarded as non-overlapping, change intervals[j].first > endpoint to intervals[j].first >= endpoint.
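In a Python port, the footnote's change amounts to swapping `bisect_right` (first start > endpoint) for `bisect_left` (first start >= endpoint). A compact sketch under that assumption, with the illustrative names `count_disjoint_pairs` and `touching_ok`:

```python
from bisect import bisect_left, bisect_right


def count_disjoint_pairs(starting, ending, touching_ok=False):
    """Count unique non-overlapping interval pairs.

    touching_ok=True also counts pairs that share only an endpoint.
    """
    intervals = sorted(set(zip(starting, ending)))
    starts = [s for s, _ in intervals]
    n = len(intervals)
    # bisect_left finds the first start >= endpoint, bisect_right the first > endpoint
    find = bisect_left if touching_ok else bisect_right
    return sum(n - find(starts, endpoint, i + 1)
               for i, (_, endpoint) in enumerate(intervals))


print(count_disjoint_pairs([3, 1, 2, 8], [5, 3, 7, 10]))                    # 3
print(count_disjoint_pairs([3, 1, 2, 8], [5, 3, 7, 10], touching_ok=True))  # 4
```

With `touching_ok=True` the extra pair is [1, 3] and [3, 5], which meet only at the point 3.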