最优二叉搜索树详解(动态规划)

1,387 阅读4分钟

最优二叉搜索树

定义

最优二叉搜索树是搜索成本最低的二叉搜索树,为了衡量二叉搜索树的搜索成本,需要设置合适的代价函数。

可以用比较次数作为二叉搜索树的代价函数。

image.png

比较次数(代价函数):1+2*2+3*4=17

image.png

比较次数(代价函数) : 1+2*2+3*1+4*2+5*1=21

平均比较次数作为二叉搜索树的代价函数。即乘以对应的概率即可。

image.png

图一的比较次数为:1*0.04+2*(0.4+0.1)+3*(0.05+0.08+0.10+0.23)=2.42

构建二叉搜索树

二叉搜索树的叶结点是形如 (xi,xi+1) 的开区间。 在表示S的二叉搜索树中搜索一个元素x,返回的结果有两种情形:

  1. 在二叉搜索树的内结点中找到x=xi。

  2. 在二叉搜索树的叶结点中确定x∈(xi,xi+1),约定x0=-∞,xn+1=+∞。

例子:求节点1到5构成的最优二叉搜索树的平均代价

设有n=5个关键字的集合,每个k i 的概率p i 和d i 的概率q i 如表所示: image.png

平均代价

假设我们以xk节点为根

  1. 那找到xk的代价就是xk自身的概率

  2. 找到xk左边的概率就是 【 xk左子树的平均路长 加上 1(比较根节点消耗了1)】乘以命中xk左边的概率

  3. 找到xk右边的概率就是 【 xk右子树的平均路长 加上 1(比较根节点消耗了1)】乘以命中xk右边的概率

(xi~xk-1 构成的子树的平均路长 + 1) * 命中这个范围的概率 +

命中 xk 节点的概率 * 1 +

(xk+1 ~xj 构成的子树的平均路长 + 1) * 命中这个范围的概率

由构成的最优二叉树的平均代价

xi~xj 这个范围里所有节点构成的二叉搜索树中代价最小的那个, 并且每次左右子树的平均路长都用最优的二叉搜索树的平均路长

定义w:w(i,j) 用于子树保存。增加的期望搜索代价,即i到j的概率和

w[i, j] = qi-1 + pi + qi + pi+1 + ... + pj + qj

w[1,1]=q0+p1+q1

w[1,2]=q0+p1+q1+p2+q2

则:w[i,j]=w[i,j-1]+pj+qj

定义p:p[i, j]代表由 xi 到 xj 节点构成的最优二叉搜索树的平均路长

定义e: e[i, j] 代表由 xi 到 xj 节点构成的最优二叉搜索树的平均代价 e[i, j] = p[i, j] * w[i, j]

k节点为根节点:

(左子树的平均路长+1)*左边的概率+命中节点的概率+(右子树的平均路长+1)*右边的概率

e[i, j] = Min((p[i, k-1]+1)*w[i, k-1] + w[k, k] + (p[k+1, j]+1)*w[k+1,j]) // k -> [i, j]

展开:e[i, j] = Min(p[i, k-1] * w[i, k- 1] + w[i, k-1] + w[k, k] + w[k+1, j] + p[k+1, j * w[k+1, j]])

合并:e[i, j] = Min(p[i, k-1] * w[i, k- 1] + w[i, j] + p[k+1, j * w[k+1, j]])

得出 :e[i, j] = Min(e[i, k-1] + w[i, j] + e[k+1, j])

可得结论:

i 到 j 的概率等于 i 到 j-1 的概率加上 j 区间和 j 节点的概率

w[i,j]=w[i,j-1]+pj+qj

i 到 j 的代价等于 i 到 k-1 的最小代价加上 k+1 到 j 的最小代价加上 i 到 j 的概率

e[i, j] = Min(e[i, k-1] + w[i, j] + e[k+1, j])

代码思路:

  • interval表示区间 node表示节点

  • 初始化

//对概率和代价数组进行初始化
for (int i = 1; i <= length; i++) {
    //赋值有利于计算 w[1][1]=w[1][0]+n1+i1; =》w[1,0]=i0;
    w[i][i-1]=interval[i-1];
    //实际上代价就是前一个区间概率
    e[i][i-1]=interval[i-1];
}
  • 首先步长为1:

求1,2的代价:

e[i, j] = Min(e[i, k-1] + w[i, j] + e[k+1, j])

e[1,1]=e[1, 1-1] + w[1, 1] + e[1+1, 1] ---k=1

e[2,2]=e[2, 2-1] + w[2, 2] + e[2+1, 2] ---k=1

image.png image.png

代码:

package week6;

public class Search1 {
    // 节点概率
    static double[] node={-1, 0.15, 0.1, 0.05, 0.1, 0.2};
    //区间概率
    static double[] interval={0.05, 0.1, 0.05, 0.05, 0.05, 0.1};
    //节点个数
    static int length=node.length;
    //存放代价
    static double[][] e=new double[length+1][length+1];
    //w(i,j)-表示区间i到j的概率
    static double[][] w=new double[length+1][length+1];

    public static void main(String[] args) {
        //对概率和代价数组进行初始化
        for (int i = 1; i <= length; i++) {
            //赋值有利于计算 w[1][1]=w[1][0]+n1+i1; =》w[1,0]=i0;
            w[i][i-1]=interval[i-1];
            //实际上代价就是前一个区间概率
            e[i][i-1]=interval[i-1];
        }
        // printArr(w);
        // printArr(e);
        // len表示步长
        // 步长为1时,范围为:(1,2),(2,3),(3,4)
        // 步长为2时,范围为:(1,3),(2,4),(3,5)
        for (int len = 1; len < length; len++) {
            //left表示左边界,总长度=right-left+1,所有right=len+left-1要小于数组长度
            for (int left = 1; left+len-1 < length; left++) {
                //右边界等于左边界加步长
                int rigth=left+len-1;
                double min=Double.MAX_VALUE;
                //w(1,2)的值等于w(1,1)的值加上节点2和区间2的值 -- 以此类推
                w[left][rigth]=w[left][rigth-1]+interval[rigth]+node[rigth];
                //当k是根节点时,计算代价,将最小代价赋值给e
                for(int k=left;k<=rigth;k++){
                    //e[left][right]的代价等于e[left,k-1]+e[k+1][right]+w[i][j]
                    double tmp=e[left][k-1]+e[k+1][rigth]+w[left][rigth];
                    if(tmp<min){
                        min=tmp;
                    }
                }
                e[left][rigth]=min;
            }
        }
        // printArr(w);
        // printArr(e);
        //由于节点概率第一个为空,所以长度要少一个 即1,5的最小代价
        System.out.println(e[1][length-1]);
    }
    public static void printArr(double[][] arr){
        for (int i = 0; i < arr.length; i++) {
            for (int j = 0; j <arr[0].length; j++) {
                System.out.printf("%.2f  ",arr[i][j]);
            }
            System.out.println();
        }
        System.out.println();
    }
}