给定一个多维数组 A,如何高效地统计其中相同子数组的数量?
例如,考虑以下 2x4 数组 A:
A = [[2, 3, 5, 7],
[2, 3, 5, 7],
[1, 7, 1, 4],
[5, 8, 6, 0],
[2, 3, 5, 7]]
第一行、第二行和最后一行是相同的。我们需要一个算法来统计每种不同行的相同行数(即每种元素的重复次数)。如果该算法还能够轻松修改以统计相同列的数量,那就更好了。
目前,我们使用了一种低效的朴素算法来实现该功能:
import numpy
A = numpy.array([[2, 3, 5, 7],
[2, 3, 5, 7],
[1, 7, 1, 4],
[5, 8, 6, 0],
[2, 3, 5, 7]])
i = 0
end = len(A)
while i < end:
print(i, end=" ")
j = i + 1
numberID = 1
while j < end:
print(j, end=" ")
if numpy.array_equal(A[i, :], A[j, :]):
numberID += 1
j += 1
i += 1
print(A, len(A))
这种算法看起来像是使用 numpy 中的原生 Python,因此效率低下。
2. 解决方案
2.1 使用 NumPy 的 unique 函数
在 NumPy >= 1.9.0 中,np.unique 函数具有一个 return_counts 关键字参数,可以与此处提供的解决方案结合使用,以获取子数组的出现次数:
import numpy as np
b = np.ascontiguousarray(A).view(np.dtype((np.void, A.dtype.itemsize * A.shape[1])))
unq_a, unq_cnt = np.unique(b, return_counts=True)
unq_a = unq_a.view(A.dtype).reshape(-1, A.shape[1])
print(unq_a, unq_cnt)
这将输出:
[[1 7 1 4]
[2 3 5 7]
[5 8 6 0]] [1 3 1]
其中 unq_a 是唯一子数组,unq_cnt 是每个子数组出现的次数。
如果使用较旧版本的 NumPy,可以复制 np.unique 的功能,如下所示:
a_view = np.array(A, copy=True)
a_view = a_view.view(np.dtype((np.void,
a_view.dtype.itemsize * a_view.shape[1]))).ravel()
a_view.sort()
a_flag = np.concatenate(([True], a_view[1:] != a_view[:-1]))
a_unq = A[a_flag]
a_idx = np.concatenate(np.nonzero(a_flag) + ([a_view.size],))
a_cnt = np.diff(a_idx)
print(a_unq, a_cnt)
这将输出与上述相同的结果。
2.2 使用 NumPy 的 lexsort 函数
可以使用 NumPy 的 lexsort 函数对行进行排序,从而将搜索复杂度从 O(n^2) 降低到 O(n)。需要注意的是,默认情况下,最后一列中的元素最后排序,即行的“按字母顺序排列”是从右到左,而不是从左到右。
import numpy as np
print(a)
sorted_indices = np.lexsort(a.T)
sorted_a = a[sorted_indices]
print(sorted_a)
这将输出:
[[5 8 6 0]
[1 7 1 4]
[2 3 5 7]
[2 3 5 7]
[2 3 5 7]]
现在,我们可以遍历排序后的数组并统计相同子数组的数量。
2.3 使用 collections.Counter 类
也可以使用 collections.Counter 类来统计相同子数组的数量。它可以像这样工作:
from collections import Counter
x = [(2, 3, 5, 7), (2, 3, 5, 7), (1, 7, 1, 4), (5, 8, 6, 0), (2, 3, 5, 7)]
c = Counter(x)
print(c)
这将输出:
Counter({(2, 3, 5, 7): 3, (5, 8, 6, 0): 1, (1, 7, 1, 4): 1})
需要注意的是,x 中的每个值本身都是一个列表,它是不可散列的数据结构。如果可以将 x 中的每个值转换为元组,那么它应该可以正常工作:
from collections import Counter
x = [(2, 3, 5, 7), (2, 3, 5, 7), (1, 7, 1, 4), (5, 8, 6, 0), (2, 3, 5, 7)]
c = Counter(x)
print(c)
这将输出:
Counter({(2, 3, 5, 7): 3, (5, 8, 6, 0): 1, (1, 7, 1, 4): 1})