随机游走算法详解,每一步都是随机的叠加

1,724 阅读2分钟

本文已参与「新人创作礼」活动,一起开启掘金创作之路。

机器学习中的随机游走

随机游走(random walk)

接近于布朗运动,是布朗运动的理想数学状态。

任何无规则行走者所带的守恒量都各自对应着一个扩散运输定律。

随机游走算法的基本思想是: 从一个或一系列顶点开始遍历一张图。在任意一个顶点,遍历者将以概率1a1-a游走到这个顶点的邻居顶点,以概率aa随机跳跃到图中的任何一个顶点,称aa为跳转发生概率,每次游走后得出一个概率分布,该概率分布刻画了图中每一个顶点被访问到的概率。用这个概率分布作为下一次游走的输入并反复迭代这一过程。当满足一定前提条件时,这个概率分布会趋于收敛。收敛后,即可以得到一个平稳的概率分布。

关键代码:

x_direction = random.choice([1, -1])
x_distance = random.choice([0, 1, 2, 3])
x_step = x_direction * x_distance

随机代码选择方向随机,选择的步长(距离)随机,使用random.choice()可以对在列表中选择这两个随机数值。

在原有位置的基础上,加上当前的随机意图(方向和步长的乘积)

next_x_value = self.x_values[-1] + x_step

那么其在x方向的移动是这样子的,假设中间的实点是当前位置,其他虚点是可能移动的轨迹

image.png

y轴上也按同样的方法进行移动

next_y_value = self.y_values[-1] + y_step

那么其在y方向上的移动就是这样的,假设中间的实点是当前位置,其他虚点是可能移动的轨迹

image.png

将x方向和y方向的游走进行叠加,总的游走效果可能出现的任意一个虚点:

image.png

全部模拟代码:

from random import choice
from matplotlib import pyplot as plt

class RandomWalk():
    def __init__(self, num_points):
        self.num_points = num_points
        self.x_values = [0]
        self.y_values = [0]
    def fill_walk(self):
        while self.num_points > len(self.x_values):
            x_direction = choice([1, -1])
            x_distance = choice([0, 1, 2, 3, 4])
            x_step = x_direction * x_distance
            
            y_direction = choice([1, -1])
            y_distance = choice([0, 1, 2, 3, 4])
            y_step = y_direction * y_distance
            
            # 拒绝原地踏步
            if x_step == 0 and y_step == 0:
                continue
            
            next_x_value = self.x_values[-1] + x_step
            next_y_value = self.y_values[-1] + y_step
            
            self.x_values.append(next_x_value)
            self.y_values.append(next_y_value)
        

i = 1
m = 100
plt.figure(dpi=128, figsize=(10, 6))

while True:
    rw = RandomWalk(10000)
    rw.fill_walk()
    
    num_points = list(range(rw.num_points))
    plt.subplot(10, 10, i)
    
    plt.xticks([])
    plt.yticks([])
    
    plt.scatter(rw.x_values, rw.y_values, c=num_points, cmap=plt.cm.Blues, s=1)
    # 突出起点
    plt.scatter(rw.x_values[0], rw.y_values[0], c='orange', s=2)
    # 重点突出终点
    plt.scatter(rw.x_values[-1], rw.y_values[-1], c='red', s=2)
    
    result_list.append([rw.x_values[-1], rw.y_values[-1]])
    
    if i == m:
        break
    i = i + 1
    

plt.show()

运行截图: