更高效的交点计数方法?
我有一个包含30万个列表的集合(纤维轨迹),每个轨迹都是一组(x,y,z)坐标的列表:
tracks=
[[(1,2,3),(3,2,4),...]
[(4,2,1),(5,7,3),...]
...
]
我还有一组掩膜,每个掩膜也是由(x,y,z)坐标组成的列表:
mask_coords_list=
[[(1,2,3),(8,13,4),...]
[(6,2,2),(5,7,3),...]
...
]
我想要找出所有可能的掩膜对之间:
- 与每对掩膜相交的轨迹数量(用来创建一个连接矩阵)
- 与每个掩膜相交的轨迹子集,以便对该子集中的每个轨迹的(x,y,z)坐标加1(用来创建一个“密度”图像)
我目前是这样做第一部分的:
def mask_connectivity_matrix(tracks,masks,masks_coords_list):
connect_mat=zeros((len(masks),len(masks)))
for track in tracks:
cur=[]
for count,mask_coords in enumerate(masks_coords_list):
if any(set(track) & set(mask_coords)):
cur.append(count)
for x,y in list(itertools.combinations(cur,2)):
connect_mat[x,y] += 1
第二部分是这样做的:
def mask_tracks(tracks,masks,masks_coords_list):
vox_tracks_img=zeros((xdim,ydim,zdim,len(masks)))
for track in tracks:
for count,mask in enumerate(masks_coords_list):
if any(set(track) & set(mask)):
for x,y,z in track:
vox_tracks_img[x,y,z,count] += 1
使用集合来查找交集大大加快了这个过程,但当我有70个或更多掩膜时,这两个部分仍然需要超过一个小时。有没有比逐个遍历每条轨迹更有效的方法呢?
6 个回答
你可以先把这两个功能合并在一起,这样就能一次性得到两个结果。另外,在循环之前其实不需要先列出所有组合,因为它本身就是一个生成器,这样可以节省一些时间。
def mask_connectivity_matrix_and_tracks(tracks,masks,masks_coords_list):
connect_mat=zeros((len(masks),len(masks)))
vox_tracks_img=zeros((xdim,ydim,zdim,len(masks)))
for track in tracks:
cur=[]
for count,mask_coords in enumerate(masks_coords_list):
if any(set(track) & set(mask_coords)):
cur.append(count)
for x,y,z in track:
vox_tracks_img[x,y,z,count] += 1
for x,y in itertools.combinations(cur,2):
connect_mat[x,y] += 1
另外,这个过程可能永远不会“快”,也就是说不会在我们死之前完成,所以最好的办法是最终用Cython把它编译成Python的C模块。
好的,我觉得我终于有了一个可以简化复杂度的方法。这段代码相比你现在的代码应该会快很多。
首先,你需要知道哪些轨道和哪些掩码是重合的,这就是所谓的关联矩阵。
import numpy
from collections import defaultdict
def by_point(sets):
d = defaultdict(list)
for i, s in enumerate(sets):
for pt in s:
d[pt].append(i)
return d
def calc(xdim, ydim, zdim, mask_coords_list, tracks):
masks_by_point = by_point(mask_coords_list)
tracks_by_point = by_point(tracks)
a = numpy.zeros((len(mask_coords_list), len(tracks)), dtype=int)
for pt, maskids in masks_by_point.iteritems():
for trackid in tracks_by_point.get(pt, ()):
a[maskids, trackid] = 1
m = numpy.matrix(a)
你需要的邻接矩阵是 m * m.T
。
你现在的代码只计算了上三角部分。你可以用 triu
来获取这一半。
am = m * m.T # calculate adjacency matrix
am = numpy.triu(am, 1) # keep only upper triangle
am = am.A # convert matrix back to array
体素计算也可以使用关联矩阵。
vox_tracks_img = numpy.zeros((xdim, ydim, zdim, len(mask_coords_list)), dtype=int)
for trackid, track in enumerate(tracks):
for x, y, z in track:
vox_tracks_img[x, y, z, :] += a[:,trackid]
return am, vox_tracks_img
对我来说,这段代码在处理几百个掩码和轨道的数据集时,运行时间不到一秒。
如果你有很多点出现在掩码中,但不在任何轨道上,建议在进入循环之前,先把这些点在 masks_by_point
中的条目删除掉,这样可能会更有效。
把体素的坐标整理成一维的形式,然后放进两个叫做 scipy.sparse.sparse.csc 的矩阵里。
这里,v 代表体素的数量,m 代表掩膜的数量,t 代表轨迹的数量。
M 是掩膜的 csc 矩阵,大小是 (m x v),在 (i,j) 这个位置上如果是 1,说明掩膜 i 和体素 j 有重叠。
T 是轨迹的 csc 矩阵,大小是 (t x v),在 (k,j) 这个位置上如果是 1,说明轨迹 k 和体素 j 有重叠。
Overlap = (M * T.transpose() > 0) # track T overlaps mask M
Connected = (Overlap * Overlap.tranpose() > 0) # Connected masks
Density[mask_idx] = numpy.take(T, nonzero(Overlap[mask_idx, :])[0], axis=0).sum(axis=0)
我可能在最后一点上说错了,我不太确定 csc 矩阵能否用非零值和取值操作。你可能需要在一个循环里逐列提取出来,然后转换成完整的矩阵。
我做了一些实验,试图模拟我认为合理的数据量。下面的代码在一台两年旧的 MacBook 上大约需要 2 分钟。如果你使用 csr 矩阵,大约需要 4 分钟。根据每条轨迹的长度,可能会有一些权衡。
from numpy import *
from scipy.sparse import csc_matrix
nvox = 1000000
ntracks = 300000
nmask = 100
# create about 100 entries per track
tcoords = random.uniform(0, ntracks, ntracks * 100).astype(int)
vcoords = random.uniform(0, nvox, ntracks * 100).astype(int)
d = ones(ntracks * 100)
T = csc_matrix((d, vstack((tcoords, vcoords))), shape=(ntracks, nvox), dtype=bool)
# create around 10000 entries per mask
mcoords = random.uniform(0, nmask, nmask * 10000).astype(int)
vcoords = random.uniform(0, nvox, nmask * 10000).astype(int)
d = ones(nmask * 10000)
M = csc_matrix((d, vstack((mcoords, vcoords))), shape=(nmask, nvox), dtype=bool)
Overlap = (M * T.transpose()).astype(bool) # mask M overlaps track T
Connected = (Overlap * Overlap.transpose()).astype(bool) # mask M1 and M2 are connected
Density = Overlap * T.astype(float) # number of tracks overlapping mask M summed across voxels