题意解析
题目更抽象的描述就是,给定点集S,找到一个点A,使得A到S中所有点的欧几里得距离和distance_sum最小,求min_distance_sum。
解法
本质上是一个无约束优化问题。设点集为 $S = \{(x_i, y_i)\}_{i=1}^{n}$,目标函数为 $f(x, y) = \sum_{i=1}^{n} \sqrt{(x - x_i)^2 + (y - y_i)^2}$。
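对 $x$、$y$ 分别求偏导:$\frac{\partial f}{\partial x} = \sum_{i=1}^{n} \frac{x - x_i}{\sqrt{(x - x_i)^2 + (y - y_i)^2}}$,$\frac{\partial f}{\partial y} = \sum_{i=1}^{n} \frac{y - y_i}{\sqrt{(x - x_i)^2 + (y - y_i)^2}}$。当 $(x, y)$ 与某个 $(x_i, y_i)$ 重合时,该项分母为 0(此处不可导),代码中直接跳过该项。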
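用梯度下降更新 $x, y$:$x \leftarrow x - \alpha \frac{\partial f}{\partial x}$,$y \leftarrow y - \alpha \frac{\partial f}{\partial y}$,其中 $\alpha$ 是学习率。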
同时,为了防止在最优点附近来回振荡:如果按当前学习率更新后距离和反而变大,就把学习率减半后重新计算,直到距离和不再变大为止。
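终止条件:当两个偏导数的绝对值都不超过阈值(代码中取 1e-10,远小于题目要求的 1e-5 精度)时停止迭代。此外,若在减半学习率的过程中,更新后距离和相对当前值的增量已小于该阈值,说明当前点已足够接近最优点,也直接返回结果。最优点恰好是某个输入点、目标函数在该处不可导时,常由这一条件结束。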
代码
import math
from typing import List
class Solution:
    """Find the point minimising the total Euclidean distance to a point set.

    The objective f(x, y) = sum_i sqrt((x - x_i)^2 + (y - y_i)^2) (the
    geometric-median objective) is convex; it is minimised here with gradient
    descent and a step-halving (backtracking) learning rate.
    """

    def getMinDistSum(
        self,
        positions: List[List[int]],
        precision_error: float = 1e-10,
        alpha: float = 16.0,
        max_iterations: int = 100_000,
    ) -> float:
        """Return the minimum sum of Euclidean distances from one point to all *positions*.

        Args:
            positions: points as ``[x, y]`` pairs; descent starts at ``positions[0]``.
            precision_error: stop when both partial derivatives are at most this
                in absolute value, or when shrinking the step can no longer
                avoid an increase of at least this size.
            alpha: initial learning rate. It is halved whenever a step would
                increase the distance sum and is never raised again.
            max_iterations: safety cap on the number of accepted steps.

        Returns:
            The smallest distance sum reached; ``0.0`` for an empty list.
        """
        if not positions:
            return 0.0
        xs = [p[0] for p in positions]
        ys = [p[1] for p in positions]
        cx, cy = float(xs[0]), float(ys[0])
        current_res = self.getDistSum([cx, cy], positions)
        for _ in range(max_iterations):
            # Gradient at the *current* point (the stopping test uses it too).
            dx = self.getDelta(cx, xs, cy, ys)
            dy = self.getDelta(cy, ys, cx, xs)
            if abs(dx) <= precision_error and abs(dy) <= precision_error:
                break
            # Backtracking: the gradient at (cx, cy) is fixed, so only alpha
            # changes while searching for a step that does not increase f.
            while True:
                nx, ny = cx - alpha * dx, cy - alpha * dy
                if nx == cx and ny == cy:
                    # The step is below floating-point resolution; accepting it
                    # would loop forever without moving.
                    return current_res
                tmp_res = self.getDistSum([nx, ny], positions)
                if tmp_res <= current_res:
                    break
                if tmp_res - current_res < precision_error:
                    # Only a negligible overshoot is left (typical when the
                    # optimum is an input point, where f has a kink). The
                    # current point is the better of the two, so return it.
                    return current_res
                alpha /= 2
            cx, cy, current_res = nx, ny, tmp_res
        return current_res

    def getDelta(self, x: float, xs: List[int], y: float, ys: List[int]) -> float:
        """Return the partial derivative of the distance sum w.r.t. the first coordinate.

        Call as ``getDelta(cx, xs, cy, ys)`` for d/dx and as
        ``getDelta(cy, ys, cx, xs)`` for d/dy. Points that coincide with
        ``(x, y)`` are skipped, because the distance is not differentiable there
        (their term would divide by zero).
        """
        res = 0.0
        for x_const, y_const in zip(xs, ys):
            divisor = math.hypot(x - x_const, y - y_const)
            if divisor != 0:
                res += (x - x_const) / divisor
        return res

    def getDistSum(self, center, positions) -> float:
        """Return the sum of Euclidean distances from *center* to every point in *positions*."""
        return sum(self.getEuclideDist(center, position) for position in positions)

    def getEuclideDist(self, a, b) -> float:
        """Return the Euclidean distance between 2-D points *a* and *b*."""
        return math.hypot(a[0] - b[0], a[1] - b[1])
def main():
    """Run ``Solution.getMinDistSum`` on sample inputs and print actual, expected, pass/fail."""
    inputs = [
        [[0, 1], [1, 0], [1, 2], [2, 1]],
        [[1, 1], [3, 3]],
        [[1, 1]],
        [[1, 1], [0, 0], [2, 0]],
        [[0, 1], [3, 2], [4, 5], [7, 6], [8, 9], [11, 1], [2, 12]],
        [[0, 1], [1, 0], [1, 2], [2, 1], [1, 1]],
        [[44, 23], [18, 45], [6, 73], [0, 76], [10, 50], [30, 7], [92, 59], [44, 59], [79, 45], [69, 37], [66, 63],
         [10, 78], [88, 80], [44, 87]],
    ]
    # Expected minimum distance sums, one per entry of ``inputs``.
    outputs = [4, 2.82843, 0.00000, 2.73205, 32.94036, 4, 499.28078]
    sol = Solution()
    for positions, expected in zip(inputs, outputs):
        actual = sol.getMinDistSum(positions)
        # abs(): an error beyond 1e-5 in either direction counts as a failure.
        print(actual, expected, abs(actual - expected) <= 1E-5)


if __name__ == '__main__':
    main()