# 螺丝螺母 (Nuts and Bolts matching problem)
#
# 112 阅读 · 1 分钟 (112 views · 1-minute read)
def partition(arr, pivot, compare_func, low=0, high=None):
    """Partition ``arr[low:high + 1]`` in place around ``pivot``.

    Every element ``x`` in the range for which ``compare_func(x, pivot)`` is
    true is swapped to the front of the range (Lomuto-style). Then the first
    element equal to ``pivot`` at or after that boundary is swapped onto the
    boundary. When such an element exists, it ends up with every "smaller"
    element on its left and everything else on its right.

    Args:
        arr: Mutable sequence, modified in place.
        pivot: Value to partition around. It need not come from ``arr``
            (the nuts-and-bolts trick partitions nuts around a bolt).
        compare_func: Predicate called as ``compare_func(element, pivot)``.
            It must be false for elements equal to ``pivot`` (e.g. ``<``);
            otherwise the equal element is moved left of the boundary and
            is not found by the second scan.
        low: First index of the range (default 0).
        high: Last index of the range, inclusive (default ``len(arr) - 1``).

    Returns:
        The boundary index: ``low`` plus the number of elements for which
        ``compare_func`` was true. If an element equal to ``pivot`` was found,
        it now sits at this index. Otherwise ``arr[boundary]`` is unrelated
        to ``pivot``, and the index may be ``high + 1``.
    """
    if high is None:
        high = len(arr) - 1
    boundary = low
    for j in range(low, high + 1):
        if compare_func(arr[j], pivot):
            # "Smaller" elements go to the front; the others can stay put.
            arr[boundary], arr[j] = arr[j], arr[boundary]
            boundary += 1
    # Move the element equal to the pivot (if any) onto the boundary.
    for j in range(boundary, high + 1):
        if arr[j] == pivot:
            arr[boundary], arr[j] = arr[j], arr[boundary]
            break
    return boundary


def match_nuts_and_bolts(nuts, bolts):
    """Pair up nuts and bolts by reordering both lists in place.

    Classic nuts-and-bolts quicksort: a bolt chooses the pivot for the nuts,
    and the matching nut then chooses the pivot for the bolts. Each range is
    partitioned only within its own bounds. Afterwards ``nuts[k] == bolts[k]``
    for every ``k``, and both lists are in ascending order. A nut "matches" a
    bolt when they compare equal.

    Args:
        nuts: Mutable list of nut sizes; reordered in place.
        bolts: Mutable list of bolt sizes; reordered in place.

    Returns:
        The same ``(nuts, bolts)`` list objects, now aligned.

    Raises:
        ValueError: If the lists differ in length, or if they do not contain
            the same sizes (some bolt has no matching nut).
    """
    import operator

    if len(nuts) != len(bolts):
        raise ValueError("nuts and bolts must have the same length")

    # Explicit stack of inclusive (left, right) ranges still to be matched.
    # This avoids the recursion limit on large inputs.
    pending = [(0, len(nuts) - 1)] if nuts else []
    while pending:
        left, right = pending.pop()
        # Use the middle bolt as the pivot, so an already-sorted bolts list
        # splits evenly instead of degrading to O(n^2).
        pivot_bolt = bolts[(left + right) // 2]
        index = partition(nuts, pivot_bolt, operator.lt, left, right)
        if index > right or nuts[index] != pivot_bolt:
            raise ValueError(f"no nut matches bolt {pivot_bolt!r}")
        # When the ranges hold the same sizes, the bolts' boundary lands on
        # the same index as the nuts' boundary.
        if partition(bolts, nuts[index], operator.lt, left, right) != index:
            raise ValueError("nuts and bolts do not match one-to-one")
        # Single-element ranges are still processed, so a lone mismatched
        # pair is reported rather than skipped.
        if left < index:
            pending.append((left, index - 1))
        if index < right:
            pending.append((index + 1, right))
    return nuts, bolts

if __name__ == "__main__":
    # Demo: run only when executed as a script, not on import.
    nuts = [1, 3, 2, 4, 5]
    bolts = [4, 2, 3, 1, 5]

    matched_nuts, matched_bolts = match_nuts_and_bolts(nuts, bolts)
    print(matched_nuts)
    print(matched_bolts)