每个查询的未流通股份数

1 投票
2 回答
63 浏览
提问于 2025-04-14 15:49

我遇到了一个问题,虽然我解决了它,但效率不高。我的方法是简单的o(m*n)算法,其中m是查询的数量,n是订单的数量。

你会得到一个订单日志(std:vector),这是我们交易系统在一天内发送的订单记录。每个订单有以下几个属性:

  • order_token:一个唯一的整数,用来标识这个订单
  • shares:要买或卖的股票数量
  • price:每股买入或卖出的价格
  • side:false表示卖出,true表示买入
  • created at:订单创建的时间戳
  • cancelled_or_executed_at:订单被取消或执行(完成)的时间戳

一个订单在时间区间[created_at, cancelled_or_executed_at)内是有效的。也就是说,created_at是包含在内的,而cancelled_or_executed_at是不包含在内的。时间戳用整数表示,比如从午夜开始的毫秒数。你可以假设每个订单都是完全被取消或执行的。

除了订单,你还会得到一个查询的向量。每个查询有一个字段:query_time,一个时间戳。查询的答案是所有在查询时间有效的订单中,未完成的股票总数。未完成的股票数量是累加的,比如一个买入10股和一个卖出10股的订单,加起来就是20股有效。

我在想有没有人能用其他数据结构或算法来优化我下面的解决方案。我相信是有的。这是一个C++的问题,但为了方便我用Python做了我的解决方案。

def calculate_outstanding_shares(orders, queries):
    result = {}

    for query in queries:
        live_trades = 0
        for order in orders:
            if query > order[4] and query < order[5]:
                live_trades += order[1]
                
        result[query] = live_trades

    return result


# Example usage
orders = [
    [3, 15, 200, True, 2000, 4000],
    [1,10,100,True,0,5000],
    [4, 25, 250, False, 2500, 6000],
    [2,20,150,False,1000,3000],
]

queries = [
    500,  # Before any order
    1500,  # Inside the duration of the first buy order
    2500,  # Inside the duration of both buy orders and the start of the second sell order
    3500,  # Inside the duration of the second sell order and after the first sell order ends
    5500  # After all orders have ended
]

result = calculate_outstanding_shares(orders, queries)
print(result)

2 个回答

1

我觉得用Ruby写个答案可能会很有用,这个答案可以看作伪代码,方便大家用任何语言来改编。

orders = [
  [3, 15, 200,  'True', 2000, 4000],
  [1, 10, 100,  'True',    0, 5000],
  [4, 25, 250, 'False', 2500, 6000],
  [2, 20, 150, 'False', 1000, 3000],
]

第一步:保存起始余额,并创建一个哈希表,哈希表的键是日期(比如下单日期或清算日期),值是对应的股票数量

def add_to_epochs(epochs, date, amount)
  if epochs.has_key?(date) == false
    epochs[date] = 0
  end
  epochs[date] = epochs[date] + amount
end
balance = 0
epochs = {}
orders.each do |_, shares, _, _, place_date, clear_date|
  if place_date == 0
    balance = shares
  else
    add_to_epochs(epochs, place_date, shares)
  end
  add_to_epochs(epochs, clear_date, -shares)
end
balance
  #=> 10
epochs
  #=> {2000=> 15, 4000=>-15, 5000=>-10, 2500=>25,
  #    6000=>-25, 1000=> 20, 3000=>-20}

第二步:创建一个排序后的日期数组,并把epoch值转换成在该时间段内的活跃订单数量

sorted_dates = epochs.keys.sort
  #=> [1000, 2000, 2500, 3000, 4000, 5000, 6000]
sorted_dates.map do |d|
  new_balance = balance + epochs[d]
  epochs[d] = balance
  balance = new_balance
end
epochs
  #=> {2000=>30, 4000=>50, 5000=>35, 2500=>45, 6000=>25, 1000=>10, 3000=>70}

第三步:对于每个查询,使用二分查找来计算相关时间段的结束日期,并返回该日期的epoch

举个例子,如果查询是3500,那么二分查找会返回sorted_dates中最小的值d,使得d >= 3500

sorted_dates.bsearch { |d| d >= 3500 }
  #=> 4000
epochs[4000]
  #=> 50 live orders at time 3500

几乎所有编程语言都有方法或函数可以进行二分查找,编写一个二分查找的代码也很简单。二分查找的时间效率是(log n)。

现在我们可以计算每个queries元素的活跃订单数量:

queries.each do |q|
  alive = epochs[sorted_dates.bsearch { |d| d >= q }]
  puts "#{q}:  #{alive}"
end

这将显示以下内容。

 500:  10
1500:  30
2500:  45
3500:  50
5500:  25
2

你可以先处理一下 orders 这个列表,然后在每个时间点(开始和结束)计算一下实际的未完成股票数量。

接着,对于每个查询,可以使用 bisect 模块来查找正确的时间位置。这样你只需要进行 m 乘以 log n 次查询。

from bisect import bisect


def calculate(orders, queries):
    outstanding_shares = []

    for a, b, c, d, e, f in orders:
        outstanding_shares.append((e, b))
        outstanding_shares.append((f, -b))

    outstanding_shares.sort()

    c = 0
    outstanding_shares = [(t, (c := c + amt)) for t, amt in outstanding_shares]

    return {
        q: outstanding_shares[bisect(outstanding_shares, (q,)) - 1][1] for q in queries
    }


# Example usage
orders = [
    [3, 15, 200, True, 2000, 4000],
    [1, 10, 100, True, 0, 5000],
    [4, 25, 250, False, 2500, 6000],
    [2, 20, 150, False, 1000, 3000],
]

queries = [
    500,  # Before any order
    1500,  # Inside the duration of the first buy order
    2500,  # Inside the duration of both buy orders and the start of the second sell order
    3500,  # Inside the duration of the second sell order and after the first sell order ends
    5500,  # After all orders have ended
]

print(calculate(orders, queries))

输出结果:

{500: 10, 1500: 30, 2500: 45, 3500: 50, 5500: 25}

撰写回答