Best Position for a Service Centre (Gradient Descent)

Link to the original problem

Problem Analysis

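Stated more abstractly: given a point set S, find a point A such that distance_sum, the sum of the Euclidean distances from A to every point in S, is as small as possible, and return that minimum, min_distance_sum.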

Solution

This is essentially an optimization problem. The objective function is

$$f(x, y) = \sum_{i=0}^{n-1}\left[(x-x_i)^2+(y-y_i)^2\right]^{1/2}$$

where $(x_i, y_i)$ is the $i$-th of the $n$ given points.
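As a minimal sketch, the objective can be written directly in Python, assuming the points arrive as `(x, y)` pairs. The function name `distance_sum` is illustrative and not part of the original solution, which computes the same quantity in `getDistSum`.

```python
import math
from typing import Sequence


def distance_sum(x: float, y: float, points: Sequence[Sequence[float]]) -> float:
    """f(x, y): sum of Euclidean distances from (x, y) to every point."""
    return sum(math.hypot(x - xi, y - yi) for xi, yi in points)


# Each of these four points is at distance 1 from (1, 1), so f(1, 1) = 4.
print(distance_sum(1, 1, [(0, 1), (1, 0), (1, 2), (2, 1)]))  # 4.0
```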

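Taking the partial derivatives with respect to $x$ and $y$:

$$\frac{\partial f}{\partial x}=\sum_{i=0}^{n-1}\left[(x-x_i)^2+(y-y_i)^2\right]^{-1/2}(x-x_i)$$

$$\frac{\partial f}{\partial y}=\sum_{i=0}^{n-1}\left[(x-x_i)^2+(y-y_i)^2\right]^{-1/2}(y-y_i)$$

A term is undefined when $(x, y)$ coincides with $(x_i, y_i)$, so the code skips such terms.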
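The derivatives above can be sanity-checked numerically. This sketch uses the illustrative names `distance_sum` and `gradient`, which are not from the original code. It evaluates the analytic gradient, skipping zero-distance terms the same way `getDelta` does, and compares the result with central finite differences at a point that is not one of the inputs. The step `h = 1e-6` is an arbitrary choice.

```python
import math
from typing import Sequence, Tuple

Points = Sequence[Sequence[float]]


def distance_sum(x: float, y: float, points: Points) -> float:
    return sum(math.hypot(x - xi, y - yi) for xi, yi in points)


def gradient(x: float, y: float, points: Points) -> Tuple[float, float]:
    """Analytic partial derivatives (df/dx, df/dy) at (x, y).

    Terms with zero distance are skipped because f is not differentiable
    at an input point (mirrors the `divisor != 0` check in getDelta).
    """
    gx = gy = 0.0
    for xi, yi in points:
        d = math.hypot(x - xi, y - yi)
        if d > 0:
            gx += (x - xi) / d
            gy += (y - yi) / d
    return gx, gy


# Central finite differences at a point away from the inputs.
pts = [(0, 1), (3, 2), (4, 5)]
x0, y0, h = 1.5, 2.5, 1e-6
num_gx = (distance_sum(x0 + h, y0, pts) - distance_sum(x0 - h, y0, pts)) / (2 * h)
num_gy = (distance_sum(x0, y0 + h, pts) - distance_sum(x0, y0 - h, pts)) / (2 * h)
print(gradient(x0, y0, pts), (num_gx, num_gy))
```

The two pairs printed should agree to several decimal places. At the centre of a symmetric configuration such as `(1, 1)` for the first sample input, the analytic gradient is exactly zero.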

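Gradient descent then updates $x$ and $y$:

$$x_{\text{new}} = x - \alpha\, dx, \qquad y_{\text{new}} = y - \alpha\, dy$$

where $dx = \partial f/\partial x$, $dy = \partial f/\partial y$, and $\alpha$ is the learning rate.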
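Here is a sketch of a single update step that reuses the same hypothetical helpers. The name `gd_step` is illustrative. Starting from `(0, 1)` in the first sample input, a step with `alpha = 0.5` lowers the distance sum. With `alpha = 16`, the starting value the solution uses, the step lands far past the minimum instead.

```python
import math
from typing import Sequence, Tuple

Points = Sequence[Sequence[float]]


def distance_sum(x: float, y: float, points: Points) -> float:
    return sum(math.hypot(x - xi, y - yi) for xi, yi in points)


def gradient(x: float, y: float, points: Points) -> Tuple[float, float]:
    gx = gy = 0.0
    for xi, yi in points:
        d = math.hypot(x - xi, y - yi)
        if d > 0:  # skip the undefined term at an input point
            gx += (x - xi) / d
            gy += (y - yi) / d
    return gx, gy


def gd_step(x: float, y: float, points: Points, alpha: float) -> Tuple[float, float]:
    """Plain gradient-descent update: (x, y) - alpha * grad f(x, y)."""
    dx, dy = gradient(x, y, points)
    return x - alpha * dx, y - alpha * dy


pts = [(0, 1), (1, 0), (1, 2), (2, 1)]
start = distance_sum(0.0, 1.0, pts)
for alpha in (16, 0.5):
    nx, ny = gd_step(0.0, 1.0, pts, alpha)
    # Report whether this step actually decreased the distance sum.
    print(alpha, (nx, ny), distance_sum(nx, ny, pts) < start)
```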

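To prevent oscillation, the learning rate is reduced whenever it occurs. If a step increases the distance sum (that is, it overshoots the minimum), $\alpha$ is halved and the step is retried from the same point.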
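The halving rule can be sketched as a standalone step function. `damped_step` and its `min_alpha` floor are hypothetical names introduced here for illustration. The floor is an assumed safeguard so the loop cannot halve forever.

```python
import math
from typing import Sequence, Tuple

Points = Sequence[Sequence[float]]


def distance_sum(x: float, y: float, points: Points) -> float:
    return sum(math.hypot(x - xi, y - yi) for xi, yi in points)


def gradient(x: float, y: float, points: Points) -> Tuple[float, float]:
    gx = gy = 0.0
    for xi, yi in points:
        d = math.hypot(x - xi, y - yi)
        if d > 0:  # skip the undefined term at an input point
            gx += (x - xi) / d
            gy += (y - yi) / d
    return gx, gy


def damped_step(x: float, y: float, points: Points, alpha: float,
                min_alpha: float = 1e-12) -> Tuple[float, float, float]:
    """Take a gradient step, halving alpha while the step would increase f.

    Returns the new point and the learning rate that was finally used.
    If alpha drops below min_alpha, the point is returned unchanged.
    """
    f0 = distance_sum(x, y, points)
    dx, dy = gradient(x, y, points)
    while alpha >= min_alpha:
        nx, ny = x - alpha * dx, y - alpha * dy
        if distance_sum(nx, ny, points) <= f0:
            return nx, ny, alpha
        alpha /= 2  # overshoot detected: shrink the learning rate
    return x, y, alpha


pts = [(0, 1), (1, 0), (1, 2), (2, 1)]
nx, ny, used_alpha = damped_step(0.0, 1.0, pts, alpha=16)
print(used_alpha, distance_sum(nx, ny, pts))
```

From `(0, 1)` the rate is halved five times, from 16 down to 0.5, before the step stops increasing the sum. Returning the reduced rate lets a caller carry it into the next iteration, as the solution does by keeping `alpha` across iterations.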

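The loop terminates once an update changes the distance sum by less than a small threshold compared with the previous position. The code uses precision_error = 1E-10, well below the 1E-5 tolerance used by the checks in main. It also stops as soon as both partial derivatives are smaller than that threshold.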
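Putting the pieces together, here is a compact sketch of the whole loop using the distance-change stopping rule. The names `min_dist_sum`, `eps`, and `max_iter` (a safety cap on iterations) are illustrative assumptions, not taken from the original. Unlike the `Solution` class, this version does not also test the size of the gradient.

```python
import math
from typing import List, Sequence, Tuple

Points = Sequence[Sequence[float]]


def distance_sum(x: float, y: float, points: Points) -> float:
    return sum(math.hypot(x - xi, y - yi) for xi, yi in points)


def gradient(x: float, y: float, points: Points) -> Tuple[float, float]:
    gx = gy = 0.0
    for xi, yi in points:
        d = math.hypot(x - xi, y - yi)
        if d > 0:  # skip the undefined term at an input point
            gx += (x - xi) / d
            gy += (y - yi) / d
    return gx, gy


def min_dist_sum(points: List[List[int]], alpha: float = 16.0,
                 eps: float = 1e-10, max_iter: int = 100_000) -> float:
    """Gradient descent with step halving; stops once an update changes
    the distance sum by less than eps."""
    x, y = float(points[0][0]), float(points[0][1])
    f = distance_sum(x, y, points)
    for _ in range(max_iter):
        dx, dy = gradient(x, y, points)
        # Halve alpha until the step no longer increases the sum
        # (or alpha becomes negligibly small).
        while True:
            nx, ny = x - alpha * dx, y - alpha * dy
            nf = distance_sum(nx, ny, points)
            if nf <= f or alpha < eps:
                break
            alpha /= 2
        # Termination: the update improved the sum by less than eps.
        if f - nf < eps:
            return min(f, nf)
        x, y, f = nx, ny, nf
    return f


print(min_dist_sum([[1, 1], [0, 0], [2, 0]]))
```

For the sample input `[[1, 1], [0, 0], [2, 0]]` from `main`, the result should match the expected 2.73205 within the 1E-5 tolerance.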

Code

from typing import List

class Solution:
    def getMinDistSum(self, positions: List[List[int]]) -> float:
        # Start from the first given point; alpha is the initial learning rate.
        center = [float(positions[0][0]), float(positions[0][1])]
        alpha = 16
        precision_error = 1E-10
        current_res = self.getDistSum(center, positions)
        xs = [p[0] for p in positions]
        ys = [p[1] for p in positions]
        # Partial derivatives of the distance sum at the current center.
        dx = self.getDelta(center[0], xs, center[1], ys)
        dy = self.getDelta(center[1], ys, center[0], xs)
        while abs(dx) > precision_error or abs(dy) > precision_error:
            new_center = [center[0] - alpha * dx, center[1] - alpha * dy]
            tmp_res = self.getDistSum(new_center, positions)
            # The step overshot (oscillation): halve the learning rate and retry.
            while tmp_res > current_res:
                if tmp_res - current_res < precision_error:
                    # No useful step remains: current_res is the best value found.
                    return current_res
                alpha /= 2
                new_center = [center[0] - alpha * dx, center[1] - alpha * dy]
                tmp_res = self.getDistSum(new_center, positions)
            current_res = tmp_res
            center = new_center
            dx = self.getDelta(center[0], xs, center[1], ys)
            dy = self.getDelta(center[1], ys, center[0], xs)
        return current_res

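    def getDelta(self, x, xs, y, ys):
        # Partial derivative of the distance sum with respect to the first
        # coordinate; call with the coordinates swapped to get the other one.
        res = 0
        for x_const, y_const in zip(xs, ys):
            divisor = pow(pow(x - x_const, 2) + pow(y - y_const, 2), 0.5)
            # Skip a point that coincides with (x, y): its term is undefined there.
            if divisor != 0:
                res += (x - x_const) / divisor
        return res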

    def getDistSum(self, center, positions):
        return sum(self.getEuclideDist(center, position) for position in positions)

    def getEuclideDist(self, a, b):
        return pow(pow(a[0]-b[0], 2) + pow(a[1]-b[1], 2), 0.5)


def main():
    inputs = [
        [[0, 1], [1, 0], [1, 2], [2, 1]]
        ,
        [[1, 1], [3, 3]]
        ,
        [[1, 1]]
        ,
        [[1, 1], [0, 0], [2, 0]]
        ,
        [[0, 1], [3, 2], [4, 5], [7, 6], [8, 9], [11, 1], [2, 12]]
        ,
        [[0, 1], [1, 0], [1, 2], [2, 1], [1, 1]]
        ,
        [[44, 23], [18, 45], [6, 73], [0, 76], [10, 50], [30, 7], [92, 59], [44, 59], [79, 45], [69, 37], [66, 63],
         [10, 78], [88, 80], [44, 87]]
    ]
    outputs = [
        4
        ,
        2.82843
        ,
        0.00000
        ,
        2.73205
        ,
        32.94036
        ,
        4
        ,
        499.28078
    ]
    sol = Solution()
    for i, positions in enumerate(inputs):
        actual = sol.getMinDistSum(positions)
        # Pass if the result is within 1E-5 of the expected value, in either direction.
        print(actual, outputs[i], abs(actual - outputs[i]) <= 1E-5)


if __name__ == '__main__':
    main()