高效统计多维数组中相同子数组的数量

69 阅读3分钟

给定一个多维数组 A,如何高效地统计其中相同子数组的数量?

例如,考虑以下 2x4 数组 A:

A = [[2,  3,  5,  7],
     [2,  3,  5,  7],
     [1,  7,  1,  4],
     [5,  8,  6,  0],
     [2,  3,  5,  7]]

第一行、第二行和最后一行是相同的。我们需要一个算法来统计每种不同行的相同行数(即每种元素的重复次数)。如果该算法还能够轻松修改以统计相同列的数量,那就更好了。

目前,我们使用了一种低效的朴素算法来实现该功能:

import numpy

A = numpy.array([[2,  3,  5,  7],
                 [2,  3,  5,  7],
                 [1,  7,  1,  4],
                 [5,  8,  6,  0],
                 [2,  3,  5,  7]])

i = 0
end = len(A)

while i < end:
    print(i, end=" ")
    j = i + 1
    numberID = 1

    while j < end:
        print(j, end=" ")

        if numpy.array_equal(A[i, :], A[j, :]):
            numberID += 1

        j += 1

    i += 1

print(A, len(A))

这种算法看起来像是使用 numpy 中的原生 Python,因此效率低下。

2. 解决方案

2.1 使用 NumPy 的 unique 函数

在 NumPy >= 1.9.0 中,np.unique 函数具有一个 return_counts 关键字参数,可以与此处提供的解决方案结合使用,以获取子数组的出现次数:

import numpy as np

b = np.ascontiguousarray(A).view(np.dtype((np.void, A.dtype.itemsize * A.shape[1])))
unq_a, unq_cnt = np.unique(b, return_counts=True)
unq_a = unq_a.view(A.dtype).reshape(-1, A.shape[1])

print(unq_a, unq_cnt)

这将输出:

[[1 7 1 4]
 [2 3 5 7]
 [5 8 6 0]] [1 3 1]

其中 unq_a 是唯一子数组,unq_cnt 是每个子数组出现的次数。

如果使用较旧版本的 NumPy,可以复制 np.unique 的功能,如下所示:

a_view = np.array(A, copy=True)
a_view = a_view.view(np.dtype((np.void,
                               a_view.dtype.itemsize * a_view.shape[1]))).ravel()
a_view.sort()
a_flag = np.concatenate(([True], a_view[1:] != a_view[:-1]))
a_unq = A[a_flag]
a_idx = np.concatenate(np.nonzero(a_flag) + ([a_view.size],))
a_cnt = np.diff(a_idx)

print(a_unq, a_cnt)

这将输出与上述相同的结果。

2.2 使用 NumPy 的 lexsort 函数

可以使用 NumPy 的 lexsort 函数对行进行排序,从而将搜索复杂度从 O(n^2) 降低到 O(n)。需要注意的是,默认情况下,最后一列中的元素最后排序,即行的“按字母顺序排列”是从右到左,而不是从左到右。

import numpy as np

print(a)
sorted_indices = np.lexsort(a.T)
sorted_a = a[sorted_indices]

print(sorted_a)

这将输出:

[[5 8 6 0]
 [1 7 1 4]
 [2 3 5 7]
 [2 3 5 7]
 [2 3 5 7]]

现在,我们可以遍历排序后的数组并统计相同子数组的数量。

2.3 使用 collections.Counter

也可以使用 collections.Counter 类来统计相同子数组的数量。它可以像这样工作:

from collections import Counter

x = [(2, 3, 5, 7), (2, 3, 5, 7), (1, 7, 1, 4), (5, 8, 6, 0), (2, 3, 5, 7)]
c = Counter(x)

print(c)

这将输出:

Counter({(2, 3, 5, 7): 3, (5, 8, 6, 0): 1, (1, 7, 1, 4): 1})

需要注意的是,x 中的每个值本身都是一个列表,它是不可散列的数据结构。如果可以将 x 中的每个值转换为元组,那么它应该可以正常工作:

from collections import Counter

x = [(2, 3, 5, 7), (2, 3, 5, 7), (1, 7, 1, 4), (5, 8, 6, 0), (2, 3, 5, 7)]
c = Counter(x)

print(c)

这将输出:

Counter({(2, 3, 5, 7): 3, (5, 8, 6, 0): 1, (1, 7, 1, 4): 1})