灵神力扣题单之二分算法(下)

21 阅读38分钟

1 介绍

本博客用来记录灵神力扣题单之二分算法

2 训练

2.5 最小化最大值

题目50:2513. 最小化两个数组中的最大值

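解题思路:二分答案。对于候选上界mid,统计1~mid中的三类数:
- 仅适合vec1的数:是d2的倍数,但不是d1的倍数。
- 仅适合vec2的数:是d1的倍数,但不是d2的倍数。
- 同时适合vec1和vec2的数:既不是d1的倍数,也不是d2的倍数。

vec1、vec2先使用各自专属的数,不足的部分再从公共部分中补。若公共部分够用,则mid可行。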

C++代码如下,

class Solution {
public:
    // LeetCode 2513: binary-search the smallest upper bound `hi` such that the
    // numbers 1..hi can supply cnt1 numbers not divisible by d1 (arr1) and
    // cnt2 numbers not divisible by d2 (arr2), with no number used twice.
    int minimizeSet(int d1, int d2, int cnt1, int cnt2) {
        const long long l = 1ll * d1 * d2 / gcd(d1, d2);  // lcm(d1, d2)

        auto feasible = [&](long long hi) {
            long long only1 = hi / d2 - hi / l;               // multiples of d2, not of d1: arr1 only
            long long only2 = hi / d1 - hi / l;               // multiples of d1, not of d2: arr2 only
            long long both = hi - hi / d1 - hi / d2 + hi / l; // divisible by neither: shared pool
            long long need = max(0ll, (long long)cnt1 - only1) + max(0ll, (long long)cnt2 - only2);
            return need <= both;
        };

        long long lo = 0, hi = 1e18, ans = -1;
        while (lo <= hi) {
            long long mid = (lo + hi) / 2;
            if (feasible(mid)) {
                ans = mid;      // feasible: try a smaller bound
                hi = mid - 1;
            } else {
                lo = mid + 1;
            }
        }
        return ans;
    }
};

python3代码如下,

import math 

class Solution:
    def minimizeSet(self, d1: int, d2: int, cnt1: int, cnt2: int) -> int:
        """Return the smallest possible maximum over both arrays (LeetCode 2513).

        Binary-search the answer ``hi``. Among 1..hi:

        * multiples of d2 that are not multiples of d1 can only go to arr1;
        * multiples of d1 that are not multiples of d2 can only go to arr2;
        * numbers divisible by neither can go to either array (shared pool).

        ``hi`` is feasible when each array's shortfall beyond its exclusive
        numbers can be covered by the shared pool.
        """
        lcm = d1 * d2 // math.gcd(d1, d2)

        def feasible(hi: int) -> bool:
            fit1 = hi - hi // d1                               # usable by arr1
            fit2 = hi - hi // d2                               # usable by arr2
            shared = hi - (hi // d1 + hi // d2 - hi // lcm)    # usable by both
            need1 = max(cnt1 - (fit1 - shared), 0)
            need2 = max(cnt2 - (fit2 - shared), 0)
            return need1 + need2 <= shared

        lo, hi = 0, 10 ** 15
        ans = -1
        while lo <= hi:
            mid = (lo + hi) // 2
            if feasible(mid):
                ans, hi = mid, mid - 1   # look for a smaller feasible bound
            else:
                lo = mid + 1
        return ans
        

题目51:LCP 12. 小张刷题计划

解题思路:二分。必须按照题目编号顺序刷完所有题目(不能排序)。每天可以把当天用时最长的一道题交给小杨完成,因此当天实际用时为“当天用时之和 − 当天最大用时”。

二分每天的最大用时mid:遍历时记录当天的用时之和与当天最大用时max_t_in_day,一旦实际用时超过mid就开启新的一天。若所需天数days<=m,则mid可行。

C++代码如下,

class Solution {
public:
    // LCP 12: problems are solved in the given order. Each day the longest
    // problem of that day is free, so a day's cost is (sum - max).
    // Binary-search the smallest daily limit that fits within m days.
    int minTime(vector<int>& time, int m) {
        auto daysNeeded = [&](int limit) {
            int days = 1, sum = 0, longest = 0;
            for (int t : time) {
                sum += t;
                longest = max(longest, t);
                if (sum - longest > limit) {
                    ++days;              // start a new day with problem t
                    sum = longest = t;
                }
            }
            return days;
        };

        int lo = 0, hi = accumulate(time.begin(), time.end(), 0), ans = -1;
        while (lo <= hi) {
            int mid = lo + (hi - lo) / 2;
            if (daysNeeded(mid) <= m) {
                ans = mid;
                hi = mid - 1;
            } else {
                lo = mid + 1;
            }
        }
        return ans;
    }
};

python3代码如下,

class Solution:
    def minTime(self, time: List[int], m: int) -> int:
        """Return the minimal daily workload T for LCP 12.

        Problems must be done in the given order (``time`` is not sorted).
        Each day the single longest problem may be handed off, so a day's cost
        is its total time minus its maximum. The smallest T that needs at most
        ``m`` days is found by binary search.
        """
        def can_finish(limit: int) -> bool:
            days, total, longest = 1, 0, 0
            for t in time:
                total += t
                longest = max(longest, t)
                if total - longest > limit:
                    # today would exceed the limit: problem t starts a new day
                    days += 1
                    total = longest = t
            return days <= m

        lo, hi, ans = 0, sum(time), -1
        while lo <= hi:
            mid = (lo + hi) // 2
            if can_finish(mid):
                ans, hi = mid, mid - 1
            else:
                lo = mid + 1
        return ans

2.6 最大化最小值

题目52:2517. 礼盒的最大甜蜜度

解题思路:二分+贪心。将price排序后二分甜蜜度mid:从最小的价格开始,依次选出与上一个选中价格之差>=mid的数。若选出的总数cnt>=k,则mid可行。

C++代码如下,

class Solution {
public:
    // LeetCode 2517: sort, then binary-search the largest gap `g` such that a
    // greedy left-to-right pick (take a price when it is >= last pick + g)
    // selects at least k candies.
    int maximumTastiness(vector<int>& price, int k) {
        sort(price.begin(), price.end());

        auto canPick = [&](int gap) {
            int picked = 1;
            int last = price.front();
            for (size_t i = 1; i < price.size(); ++i) {
                if (price[i] - last < gap) continue;
                ++picked;
                last = price[i];
            }
            return picked >= k;
        };

        int lo = 0, hi = 1e9, ans = -1;
        while (lo <= hi) {
            int mid = lo + (hi - lo) / 2;
            if (canPick(mid)) {
                ans = mid;       // feasible: try a larger gap
                lo = mid + 1;
            } else {
                hi = mid - 1;
            }
        }
        return ans;
    }
};

python3代码如下,

class Solution:
    def maximumTastiness(self, price: List[int], k: int) -> int:
        """Return the maximum tastiness of a basket of k candies (LeetCode 2517).

        Tastiness is the smallest absolute difference between two chosen
        prices. After sorting, a gap ``g`` is achievable iff greedily taking
        every price at least ``g`` above the previously taken one yields at
        least k picks. The largest achievable gap is found by binary search.
        Note: sorts ``price`` in place.
        """
        price.sort()

        def enough(gap: int) -> bool:
            taken, last = 1, price[0]
            for p in price[1:]:
                if p - last >= gap:
                    taken += 1
                    last = p
            return taken >= k

        lo, hi, best = 0, 10 ** 9, -1
        while lo <= hi:
            mid = (lo + hi) // 2
            if enough(mid):
                best, lo = mid, mid + 1   # feasible: try a larger gap
            else:
                hi = mid - 1
        return best

题目53:1552. 两球之间的磁力

解题思路:二分+贪心。将position排序后二分最小磁力mid:第一个球放在最左端,之后每当某个位置与上一个球的距离>=mid时就放一个球。若能放下的球数>=m,则mid可行。

C++代码如下,

class Solution {
public:
    // LeetCode 1552: maximize the minimum distance between any two of m balls.
    // Greedy check after sorting: put the first ball at the leftmost position,
    // then place a ball at every position that is at least `d` away from the
    // previously placed ball. `d` is feasible when at least m balls fit.
    int maxDistance(vector<int>& position, int m) {
        sort(position.begin(), position.end());

        auto placeable = [&](int d) {
            int placed = 1;
            int prev = position[0];
            for (auto it = position.begin() + 1; it != position.end(); ++it) {
                if (*it - prev >= d) {
                    ++placed;
                    prev = *it;
                }
            }
            return placed >= m;
        };

        int lo = 0, hi = 1e9, ans = -1;
        while (lo <= hi) {
            int mid = lo + (hi - lo) / 2;
            if (placeable(mid)) {
                ans = mid;       // feasible: try a larger distance
                lo = mid + 1;
            } else {
                hi = mid - 1;
            }
        }
        return ans;
    }
};

python3代码如下,

class Solution:
    def maxDistance(self, position: List[int], m: int) -> int:
        """Return the largest possible minimum distance between m balls (LeetCode 1552).

        A distance ``d`` is feasible iff, after sorting, greedily placing a
        ball at every position at least ``d`` past the previous ball (starting
        from the leftmost position) places at least ``m`` balls.
        Note: sorts ``position`` in place.
        """
        position.sort()

        def placeable(d: int) -> bool:
            placed, prev = 1, position[0]
            for p in position[1:]:
                if p - prev < d:
                    continue
                placed += 1
                prev = p
            return placed >= m

        lo, hi, best = 0, 10 ** 9, -1
        while lo <= hi:
            mid = (lo + hi) // 2
            if placeable(mid):
                best, lo = mid, mid + 1
            else:
                hi = mid - 1
        return best

题目54:2812. 找出最安全路径

解题思路:二分+bfs。先从所有小偷出发做一遍多源bfs,求出每个格子到最近小偷的曼哈顿距离d[i][j]。再二分安全系数mid,用bfs判断能否只经过d[i][j]>=mid的格子,从(0,0)走到(n-1,n-1)。

C++代码如下,

class Solution {
public:
    // LeetCode 2812: a multi-source BFS from every thief gives each cell's
    // distance to the nearest thief. Then binary-search the largest `s` for
    // which (0,0) reaches (n-1,m-1) using only cells with distance >= s.
    int maximumSafenessFactor(vector<vector<int>>& grid) {
        const int n = grid.size(), m = grid[0].size();
        const int dx[4] = {-1, 1, 0, 0};
        const int dy[4] = {0, 0, -1, 1};
        auto inside = [&](int x, int y) { return x >= 0 && x < n && y >= 0 && y < m; };

        // dist[x][y] = BFS distance to the nearest thief (-1 = not reached yet).
        vector<vector<int>> dist(n, vector<int>(m, -1));
        queue<pair<int, int>> frontier;
        for (int x = 0; x < n; ++x) {
            for (int y = 0; y < m; ++y) {
                if (grid[x][y] == 1) {
                    dist[x][y] = 0;
                    frontier.emplace(x, y);
                }
            }
        }
        while (!frontier.empty()) {
            auto [x, y] = frontier.front();
            frontier.pop();
            for (int d = 0; d < 4; ++d) {
                int nx = x + dx[d], ny = y + dy[d];
                if (!inside(nx, ny) || dist[nx][ny] != -1) continue;
                dist[nx][ny] = dist[x][y] + 1;
                frontier.emplace(nx, ny);
            }
        }

        // Can we walk from the top-left to the bottom-right corner while
        // staying on cells whose distance is at least `s`?
        auto reachable = [&](int s) {
            if (dist[0][0] < s) return false;
            vector<vector<bool>> seen(n, vector<bool>(m, false));
            queue<pair<int, int>> q;
            q.emplace(0, 0);
            seen[0][0] = true;
            while (!q.empty()) {
                auto [x, y] = q.front();
                q.pop();
                if (x == n - 1 && y == m - 1) return true;
                for (int d = 0; d < 4; ++d) {
                    int nx = x + dx[d], ny = y + dy[d];
                    if (!inside(nx, ny) || dist[nx][ny] < s || seen[nx][ny]) continue;
                    seen[nx][ny] = true;
                    q.emplace(nx, ny);
                }
            }
            return false;
        };

        int lo = 0, hi = n + m, ans = -1;
        while (lo <= hi) {
            int mid = (lo + hi) / 2;
            if (reachable(mid)) {
                ans = mid;      // feasible: try a larger safeness factor
                lo = mid + 1;
            } else {
                hi = mid - 1;
            }
        }
        return ans;
    }
};

python3代码如下,

class Solution:
    def maximumSafenessFactor(self, grid: List[List[int]]) -> int:
        """Return the maximum safeness factor of a path from (0,0) to (n-1,m-1) (LeetCode 2812).

        1. Multi-source BFS from every thief (cell == 1) gives ``dist``, the
           BFS distance from each cell to its nearest thief.
        2. Binary-search the largest ``s`` such that a BFS restricted to cells
           with ``dist >= s`` connects the two corners.
        """
        n, m = len(grid), len(grid[0])
        steps = ((-1, 0), (1, 0), (0, -1), (0, 1))

        dist = [[-1] * m for _ in range(n)]
        frontier = collections.deque()
        for i, row in enumerate(grid):
            for j, cell in enumerate(row):
                if cell == 1:
                    dist[i][j] = 0
                    frontier.append((i, j))
        while frontier:
            x, y = frontier.popleft()
            for dx, dy in steps:
                nx, ny = x + dx, y + dy
                if 0 <= nx < n and 0 <= ny < m and dist[nx][ny] == -1:
                    dist[nx][ny] = dist[x][y] + 1
                    frontier.append((nx, ny))

        def reachable(s: int) -> bool:
            """Can (0,0) reach (n-1,m-1) moving only through cells with dist >= s?"""
            if dist[0][0] < s:
                return False
            seen = [[False] * m for _ in range(n)]
            seen[0][0] = True
            queue = collections.deque([(0, 0)])
            while queue:
                x, y = queue.popleft()
                if (x, y) == (n - 1, m - 1):
                    return True
                for dx, dy in steps:
                    nx, ny = x + dx, y + dy
                    if 0 <= nx < n and 0 <= ny < m and not seen[nx][ny] and dist[nx][ny] >= s:
                        seen[nx][ny] = True
                        queue.append((nx, ny))
            return False

        lo, hi, best = 0, n + m, -1
        while lo <= hi:
            mid = (lo + hi) // 2
            if reachable(mid):
                best, lo = mid, mid + 1
            else:
                hi = mid - 1
        return best

题目55:2528. 最大化城市的最小电量

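解题思路:二分+贪心+差分+前缀和。先用前缀和求出每个城市的初始电量nums[i],再二分最小电量mid:
- 从左到右遍历,当城市i的电量(nums[i]加上已新建供电站的贡献)不足mid时,贪心地在i+r处新建供电站,它覆盖[i, i+2r]。
- 新建供电站的贡献用差分数组维护。
- 若新建总数<=k,则mid可行。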

C++代码如下,

class Solution {
public:
    // LeetCode 2528: power[i] = sum of stations in [i-r, i+r] via prefix sums.
    // Binary-search the best minimum power `target`. Greedily, when city i is
    // short by x, build x stations at i+r (covering cities [i, i+2r]), tracked
    // with a difference array. `target` is feasible iff the total built is <= k.
    long long maxPower(vector<int>& stations, int r, int k) {
        const int n = stations.size();
        vector<long long> pre(n + 1, 0);
        for (int i = 0; i < n; ++i) pre[i + 1] = pre[i] + stations[i];

        vector<long long> power(n);
        for (int i = 0; i < n; ++i) {
            power[i] = pre[min(i + r + 1, n)] - pre[max(i - r, 0)];
        }

        // Number of new stations needed so every city reaches `target`.
        auto needed = [&](long long target) {
            vector<long long> diff(n, 0);
            long long added = 0, built = 0;
            for (int i = 0; i < n; ++i) {
                added += diff[i];
                long long shortfall = target - added - power[i];
                if (shortfall <= 0) continue;
                built += shortfall;
                added += shortfall;
                if (i + 2 * r + 1 < n) diff[i + 2 * r + 1] -= shortfall;  // coverage ends after i+2r
            }
            return built;
        };

        long long lo = *min_element(power.begin(), power.end());
        long long hi = lo + k + 1;
        long long ans = -1;
        while (lo <= hi) {
            long long mid = (lo + hi) / 2;
            if (needed(mid) <= k) {
                ans = mid;
                lo = mid + 1;
            } else {
                hi = mid - 1;
            }
        }
        return ans;
    }
};

python3代码如下,

class Solution:
    def maxPower(self, stations: List[int], r: int, k: int) -> int:
        """Return the maximum possible minimum city power after building k stations (LeetCode 2528).

        ``power[i]`` (the stations within distance r of city i) comes from
        prefix sums. A target is feasible iff a left-to-right greedy pass,
        which builds each city's shortfall at position i+r (covering
        [i, i+2r]), uses at most k new stations. The covered range is tracked
        with a difference array.
        """
        n = len(stations)
        prefix = [0] * (n + 1)
        for i, s in enumerate(stations):
            prefix[i + 1] = prefix[i] + s
        # window [i-r, i+r] corresponds to prefix slice (i-r, i+r+1]
        power = [prefix[min(i + r + 1, n)] - prefix[max(i - r, 0)] for i in range(n)]

        def feasible(target: int) -> bool:
            diff = [0] * n
            extra = 0   # stations built so far that still cover the current city
            used = 0    # total new stations built
            for i, base in enumerate(power):
                extra += diff[i]
                lack = target - extra - base
                if lack <= 0:
                    continue
                used += lack
                if used > k:
                    return False
                extra += lack
                end = i + 2 * r + 1
                if end < n:
                    diff[end] -= lack
            return True

        lo = min(power)
        hi = lo + k + 1
        best = -1
        while lo <= hi:
            mid = (lo + hi) // 2
            if feasible(mid):
                best, lo = mid, mid + 1
            else:
                hi = mid - 1
        return best

2.7 第K小/大

题目56:378. 有序矩阵中第 K 小的元素

解题思路:二分(类似lower_bound())。二分答案mid,从最后一行第一列开始遍历,统计矩阵中<=mid的元素个数,求使该个数>=k的最小mid。二分范围为left = matrix[0][0],right = matrix[n-1][n-1]。

C++代码如下,

class Solution {
public:
    // LeetCode 378: binary-search over the value range. For a candidate v,
    // count the entries <= v with a staircase walk that starts at the
    // bottom-left corner. The answer is the smallest v whose count reaches k.
    int kthSmallest(vector<vector<int>>& matrix, int k) {
        const int n = matrix.size();

        auto countNotAbove = [&](int v) {
            int total = 0;
            for (int row = n - 1, col = 0; row >= 0 && col < n;) {
                if (matrix[row][col] > v) {
                    --row;
                } else {
                    total += row + 1;   // rows 0..row of this column are all <= v
                    ++col;
                }
            }
            return total;
        };

        int lo = matrix[0][0], hi = matrix[n - 1][n - 1], ans = -1;
        while (lo <= hi) {
            int mid = (lo + hi) / 2;
            if (countNotAbove(mid) >= k) {
                ans = mid;      // feasible: try a smaller value
                hi = mid - 1;
            } else {
                lo = mid + 1;
            }
        }
        return ans;
    }
};

python3代码如下,

class Solution:
    def kthSmallest(self, matrix: List[List[int]], k: int) -> int:
        """Return the k-th smallest value in an n x n matrix (LeetCode 378).

        Binary-search the smallest value ``v`` such that at least k entries
        are <= v. Entries <= v are counted column by column, moving a row
        pointer upward from the bottom row.
        """
        n = len(matrix)

        def count_le(v: int) -> int:
            total, row = 0, n - 1
            for col in range(n):
                while row >= 0 and matrix[row][col] > v:
                    row -= 1
                if row < 0:
                    break
                total += row + 1   # rows 0..row in this column are <= v
            return total

        lo, hi, ans = matrix[0][0], matrix[n - 1][n - 1], -1
        while lo <= hi:
            mid = (lo + hi) // 2
            if count_le(mid) >= k:
                ans, hi = mid, mid - 1   # smallest value with >= k entries <= it
            else:
                lo = mid + 1
        return ans

题目57:668. 乘法表中第k小的数

解题思路:二分。统计矩阵中<=mid的数目cnt,最终cnt >= k。求mid的最小值。

C++代码如下,

class Solution {
public:
    // LeetCode 668: row i of the n x m multiplication table holds i, 2i, ..., m*i,
    // so it has min(v / i, m) entries <= v. Binary-search the smallest v whose
    // total count reaches k.
    int findKthNumber(int n, int m, int k) {
        auto countNotAbove = [&](int v) {
            int total = 0;
            for (int i = 1; i <= n; ++i) total += min(v / i, m);
            return total;
        };

        int lo = 1, hi = n * m, ans = -1;
        while (lo <= hi) {
            int mid = lo + (hi - lo) / 2;
            if (countNotAbove(mid) >= k) {
                ans = mid;
                hi = mid - 1;
            } else {
                lo = mid + 1;
            }
        }
        return ans;
    }
};

python3代码如下,

class Solution:
    def findKthNumber(self, n: int, m: int, k: int) -> int:
        """Return the k-th smallest entry of the n x m multiplication table (LeetCode 668).

        Row i contains i, 2i, ..., m*i, so it has ``min(v // i, m)`` entries <= v.
        Binary-search the smallest v for which at least k entries are <= v.
        """
        def count_le(v: int) -> int:
            return sum(min(v // i, m) for i in range(1, n + 1))

        lo, hi, ans = 1, n * m, -1
        while lo <= hi:
            mid = (lo + hi) // 2
            if count_le(mid) >= k:
                ans, hi = mid, mid - 1
            else:
                lo = mid + 1
        return ans

题目58:719. 找出第 K 小的数对距离

解题思路:二分+双指针。至少k个数对<=mid。求最小值。

C++代码如下,

class Solution {
public:
    // LeetCode 719: after sorting, count the pairs with distance <= d using a
    // sliding window (the right pointer only ever moves forward), then
    // binary-search the smallest d with at least k such pairs.
    int smallestDistancePair(vector<int>& nums, int k) {
        sort(nums.begin(), nums.end());
        const int n = nums.size();

        auto pairsWithin = [&](int d) {
            int total = 0;
            for (int i = 0, j = 0; i < n; ++i) {
                while (j < n && nums[j] - nums[i] <= d) ++j;
                total += j - i - 1;   // partners i+1 .. j-1
            }
            return total;
        };

        int lo = 0, hi = nums.back() - nums.front(), ans = -1;
        while (lo <= hi) {
            int mid = lo + (hi - lo) / 2;
            if (pairsWithin(mid) >= k) {
                ans = mid;      // at least k pairs within mid: try smaller
                hi = mid - 1;
            } else {
                lo = mid + 1;
            }
        }
        return ans;
    }
};

python3代码如下,

class Solution:
    def smallestDistancePair(self, nums: List[int], k: int) -> int:
        """Return the k-th smallest pairwise distance (LeetCode 719).

        After sorting, the number of pairs with distance <= d is counted with
        two pointers. The smallest d with at least k such pairs is found by
        binary search. Note: sorts ``nums`` in place.
        """
        nums.sort()
        size = len(nums)

        def pairs_within(d: int) -> int:
            total = right = 0
            for left, value in enumerate(nums):
                while right < size and nums[right] - value <= d:
                    right += 1
                total += right - left - 1   # partners left+1 .. right-1
            return total

        lo, hi, ans = 0, max(nums) - min(nums), -1
        while lo <= hi:
            mid = (lo + hi) // 2
            if pairs_within(mid) >= k:
                ans, hi = mid, mid - 1
            else:
                lo = mid + 1
        return ans

题目59:878. 第 N 个神奇数字

解题思路:二分+容斥原理。1~x中神奇数字的个数为x//a + x//b - x//lcm(a,b),求使该个数>=n的最小x。

C++代码如下,

class Solution {
public:
    // LeetCode 878: by inclusion-exclusion, the count of magical numbers in
    // [1, x] is x/a + x/b - x/lcm(a, b). Binary-search the smallest x whose
    // count reaches n, and return it modulo 1e9+7.
    int nthMagicalNumber(int n, int a, int b) {
        const long long MOD = 1000000007LL;
        // lcm in 64-bit, dividing first. The old `int c = a * b / gcd(a, b)`
        // multiplied in int and overflowed as soon as a * b exceeded INT_MAX.
        const long long c = 1ll * a / gcd(a, b) * b;

        auto enough = [&](long long x) { return x / a + x / b - x / c >= n; };

        // count(x) <= x, so the answer is at least n. The n-th multiple of
        // min(a, b) is itself magical, so it is at most n * min(a, b).
        long long lo = n;
        long long hi = 1ll * n * min(a, b);
        long long ans = -1;
        while (lo <= hi) {
            long long mid = lo + (hi - lo) / 2;
            if (enough(mid)) {
                ans = mid;      // at least n magical numbers <= mid: try smaller
                hi = mid - 1;
            } else {
                lo = mid + 1;
            }
        }
        return (int)(ans % MOD);
    }
};

python3代码如下,

import math 

class Solution:
    def nthMagicalNumber(self, n: int, a: int, b: int) -> int:
        """Return the n-th positive integer divisible by a or b, modulo 1e9+7 (LeetCode 878).

        The count of such numbers in [1, x] is x//a + x//b - x//lcm(a, b)
        (inclusion-exclusion). Binary-search the smallest x with count >= n.
        """
        lcm = a * b // math.gcd(a, b)

        def count_upto(x: int) -> int:
            return x // a + x // b - x // lcm

        lo, hi, ans = 0, 10 ** 14, -1
        while lo <= hi:
            mid = (lo + hi) // 2
            if count_upto(mid) >= n:
                ans, hi = mid, mid - 1   # smallest x with at least n magical numbers
            else:
                lo = mid + 1
        return ans % int(1e9 + 7)

题目60:1201. 丑数 III

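解题思路:二分+容斥原理。1~x中丑数的个数为x//a + x//b + x//c - x//lcm(a,b) - x//lcm(a,c) - x//lcm(b,c) + x//lcm(a,b,c),求使该个数>=n的最小x。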

C++代码如下,

class Solution {
public:
    // lcm(x, y) in 64-bit. Dividing by the gcd before multiplying keeps the
    // intermediate value no larger than the result.
    long long lcm(long long x, long long y) {
        return x / gcd(x, y) * y;
    }

    // LeetCode 1201: by inclusion-exclusion, the count of ugly numbers in [1, x] is
    //   x/a + x/b + x/c - x/lcm(a,b) - x/lcm(a,c) - x/lcm(b,c) + x/lcm(a,b,c).
    // Binary-search the smallest x whose count reaches n.
    int nthUglyNumber(int n, int a, int b, int c) {
        const long long ab = lcm(a, b);
        const long long ac = lcm(a, c);
        const long long bc = lcm(b, c);
        const long long abc = lcm(ab, c);

        // The predicate takes long long end to end. It was previously stored in
        // a std::function<bool(int)>, which narrowed every argument to int.
        auto enough = [&](long long x) {
            long long cnt = x / a + x / b + x / c - x / ab - x / ac - x / bc + x / abc;
            return cnt >= n;
        };

        // count(x) <= x, so the answer is at least n. The upper bound is kept
        // from the original search range (assumes the answer is <= 2e9 + 10).
        long long lo = n;
        long long hi = 2e9 + 10;
        long long ans = -1;
        while (lo <= hi) {
            long long mid = lo + (hi - lo) / 2;
            if (enough(mid)) {
                ans = mid;      // at least n ugly numbers <= mid: try smaller
                hi = mid - 1;
            } else {
                lo = mid + 1;
            }
        }
        return (int)ans;
    }
};

python3代码如下,

import math 

class Solution:
    def nthUglyNumber(self, n: int, a: int, b: int, c: int) -> int:
        """Return the n-th positive integer divisible by a, b or c (LeetCode 1201).

        The count of such numbers in [1, x] follows from inclusion-exclusion
        over the three divisors. Binary-search the smallest x with count >= n.
        """
        def lcm(x: int, y: int) -> int:
            return x // math.gcd(x, y) * y

        ab, ac, bc = lcm(a, b), lcm(a, c), lcm(b, c)
        abc = lcm(ab, c)

        def count_upto(x: int) -> int:
            singles = x // a + x // b + x // c
            pairs = x // ab + x // ac + x // bc
            return singles - pairs + x // abc

        lo, hi, ans = n, 2 * 10 ** 9 + 10, -1
        while lo <= hi:
            mid = (lo + hi) // 2
            if count_upto(mid) >= n:
                ans, hi = mid, mid - 1   # smallest x with at least n ugly numbers
            else:
                lo = mid + 1
        return ans

题目61:793. 阶乘函数后 K 个零

解题思路:二分+数论。

(1)x!中末尾0的个数:它是由因子10的个数决定的。因为10是由2和5相乘而来的,而阶乘中2的因子总是比5多。因此,末尾0的个数等于因子5的出现次数。

(2)为了计算阶乘中包含多少个5的因子,我们可以从5的倍数入手。每当一个数是5的倍数时,它会为阶乘贡献至少一个5的因子。接着,如果一个数是25的倍数,它会多贡献一个5的因子。类似地,125的倍数会再贡献一个5的因子。故x!中5的因子的个数为:x//5 + x//25 + x//125 + ...

分别表示:

  1. 1~x中所有能被5整除的数的个数。
  2. 1~x中所有能被25整除的数的个数。
  3. 1~x中所有能被125整除的数的个数。
  4. 以此类推……

C++代码如下,

class Solution {
public:
    // Number of trailing zeros of x!, i.e. the number of factors 5 in x!:
    // x/5 + x/25 + x/125 + ...
    long long f(long long x) {
        long long res = 0;
        while (x > 0) {
            res += x / 5;
            x /= 5;
        }
        return res;
    }

    // Smallest x with f(x) >= k. Since f(5(k+1)) >= k+1, 5(k+1) is a valid
    // upper bound. Returns long long: f(x) < x/4, so the result exceeds 4k.
    // The old int return type wrapped once k > INT_MAX/4 (k near 1e9 is allowed).
    long long search_left(long long k) {
        long long left = 0;
        long long right = 5 * (k + 1);
        long long res = -1;
        while (left <= right) {
            long long mid = (left + right) / 2;
            if (f(mid) >= k) {
                res = mid;
                right = mid - 1;
            } else {
                left = mid + 1;
            }
        }
        return res;
    }

    // f is non-decreasing, so the x with exactly k zeros form the range
    // [search_left(k), search_left(k+1)). The difference is computed in 64-bit.
    int preimageSizeFZF(int k) {
        return (int)(search_left(k + 1ll) - search_left(k));
    }
};

python3代码如下,

class Solution:
    def preimageSizeFZF(self, k: int) -> int:
        """Count the non-negative x whose factorial x! ends in exactly k zeros (LeetCode 793).

        The number of trailing zeros of x! equals the number of factors 5 in x!,
        i.e. x//5 + x//25 + x//125 + ... . This count is non-decreasing in x,
        so the x with exactly k zeros form the range
        [first(k), first(k + 1)), where first(t) is the smallest x with at least t zeros.
        """
        def trailing_zeros(x: int) -> int:
            total = 0
            while x:
                x //= 5
                total += x
            return total

        def first_with_at_least(target: int) -> int:
            # trailing_zeros(5 * (target + 1)) >= target + 1, so this bound is safe
            lo, hi, found = 0, 5 * (target + 1), -1
            while lo <= hi:
                mid = (lo + hi) // 2
                if trailing_zeros(mid) >= target:
                    found, hi = mid, mid - 1
                else:
                    lo = mid + 1
            return found

        return first_with_at_least(k + 1) - first_with_at_least(k)

题目62:373. 查找和最小的 K 对数字

解题思路:贪心+最小堆(多路归并)。

(1)堆的初始化:将nums1中的前k个元素与nums2[0]配对,放入一个最小堆中。堆中的每个元素是一个三元组(total_sum, i, j),分别表示nums1[i] + nums2[j]的和,以及索引i和j。

(2)从堆中取数对:每次从堆中取出和最小的数对(nums1[i], nums2[j]),然后将下一个数对(nums1[i], nums2[j+1])放入堆中(如果j+1是有效的索引)。

C++代码如下,

class Solution {
public:
    // LeetCode 373: k-way merge with a min-heap. Row i is the sequence
    // nums1[i] + nums2[0], nums1[i] + nums2[1], .... Seed the heap with the
    // head of each of the first min(n, k) rows, then repeatedly pop the
    // smallest pair and push its successor in the same row.
    vector<vector<int>> kSmallestPairs(vector<int>& nums1, vector<int>& nums2, int k) {
        struct Entry {
            int sum;   // nums1[i] + nums2[j]
            int i;
            int j;
            Entry(int sum, int i, int j) : sum(sum), i(i), j(j) {}
            bool operator>(const Entry& other) const { return sum > other.sum; }
        };

        const int rows = nums1.size();
        const int cols = nums2.size();
        priority_queue<Entry, vector<Entry>, greater<Entry>> heap;
        for (int i = 0; i < min(rows, k); ++i) {
            heap.push(Entry(nums1[i] + nums2[0], i, 0));
        }

        vector<vector<int>> pairs;
        for (; k > 0 && !heap.empty(); --k) {
            Entry top = heap.top();
            heap.pop();
            pairs.push_back({nums1[top.i], nums2[top.j]});
            if (top.j + 1 < cols) {
                heap.push(Entry(nums1[top.i] + nums2[top.j + 1], top.i, top.j + 1));
            }
        }
        return pairs;
    }
};

python3代码如下,

import heapq

class Solution:
    def kSmallestPairs(self, nums1: List[int], nums2: List[int], k: int) -> List[List[int]]:
        """Return the k pairs [u, v] (u from nums1, v from nums2) with the smallest sums (LeetCode 373).

        K-way merge: row i is nums1[i] + nums2[0], nums1[i] + nums2[1], ...
        (non-decreasing assuming nums2 is sorted ascending, as the problem states).
        The heap holds the current head (sum, i, j) of each of the first
        min(k, len(nums1)) rows.
        """
        heap = [(nums1[i] + nums2[0], i, 0) for i in range(min(k, len(nums1)))]
        heapq.heapify(heap)

        pairs = []
        remaining = k
        while remaining > 0 and heap:
            _, i, j = heapq.heappop(heap)
            pairs.append([nums1[i], nums2[j]])
            remaining -= 1
            if j + 1 < len(nums2):
                # advance row i to its next element
                heapq.heappush(heap, (nums1[i] + nums2[j + 1], i, j + 1))
        return pairs

题目63:1439. 有序矩阵中的第 k 个最小数组和

解题思路:二分+bfs+小根堆。

C++代码如下,

class Solution {
public:
    // LeetCode 1439: return the k-th smallest sum obtainable by picking one
    // element from each row of mat.
    // Binary-searches the sum bound `mid` between the sum of the first column
    // and the sum of the last column. check(mid) enumerates index vectors with
    // a min-heap and reports whether at least k of them have sum <= mid.
    // NOTE(review): assumes every row of mat is sorted ascending (as the
    // problem states); the bounds and the heap ordering argument rely on it.
    int kthSmallest(vector<vector<int>>& mat, int k) {
        int n = mat.size();
        int m = mat[0].size();

        // Heap entry: current sum s and the chosen column index for each row.
        struct node {
            int s;
            vector<int> idxes;

            node(int in_s, vector<int> in_idxes) : s(in_s), idxes(in_idxes) {}

            bool operator> (const node& node1) const { // min-heap: greater<node> relies on operator>
                return s > node1.s;
            }
        };

        // True iff at least k index vectors have sum <= mid.
        function<bool(int)> check =[&] (int mid) -> bool {
            priority_queue<node,vector<node>,greater<node>> hp;
            // Start from column 0 in every row.
            int s = 0;
            for (auto& row : mat) {
                s += row[0];
            }
            hp.push(node(s,vector<int>(n,0)));
            // Index vectors already pushed. The all-zero start is not recorded,
            // but indices only increase, so it is never generated again.
            set<vector<int>> seen;
            int cnt = 0;
            while (!hp.empty()) {
                node t = hp.top();
                hp.pop();
                int curr_s = t.s;
                vector<int> curr_idxes = t.idxes;
                cnt += 1;
                // With sorted rows, pops come out in non-decreasing sum order,
                // so once one exceeds mid, none of the remaining can be <= mid.
                if (curr_s > mid) {
                    break;
                }
                if (cnt >= k) { // early exit: k sums <= mid found
                    return true;
                }
                // Successors: advance exactly one row's column index by 1.
                for (int i = 0; i < n; ++i) {
                    int j = curr_idxes[i];
                    if (j + 1 < m) {
                        vector<int> new_idxes = curr_idxes;
                        new_idxes[i] += 1;
                        int new_s = curr_s - mat[i][j] + mat[i][j+1];
                        if (seen.find(new_idxes) == seen.end()) {
                            seen.insert(new_idxes);
                            hp.push(node(new_s,new_idxes));
                        }
                    }
                }
            }
            return false;
        };

        // Search range: smallest possible sum (first column) to largest (last column).
        int left = 0;
        int right = 0;
        for (auto& row : mat) {
            left += row[0];
            right += row.back();
        }
        int res = -1;
        while (left <= right) {
            int mid = (left + right) / 2;
            if (check(mid)) {
                res = mid;
                right = mid - 1; // search for the smallest feasible mid
            } else {
                left = mid + 1;
            }
        }
        return res;
    }
};

python3代码如下,

import heapq

class Solution:
    def kthSmallest(self, mat: List[List[int]], k: int) -> int:
        """Return the k-th smallest sum obtainable by picking one element per row (LeetCode 1439).

        Binary-search the sum bound ``limit`` between the first-column sum and
        the last-column sum. For each bound, index tuples are popped from a
        min-heap, each successor advancing one row's column by 1, and we check
        whether at least k of them have sum <= limit.
        Assumes each row of ``mat`` is sorted ascending (per the problem statement).
        """
        n, m = len(mat), len(mat[0])
        base = sum(row[0] for row in mat)

        def at_least_k(limit: int) -> bool:
            start = (0,) * n
            heap = [(base, start)]
            # Must be a set *containing* the start tuple; the former
            # ``set(tuple([0] * n))`` built the set {0} instead.
            visited = {start}
            cnt = 0
            while heap:
                s, idx = heapq.heappop(heap)
                if s > limit:   # sums pop in non-decreasing order
                    break
                cnt += 1
                if cnt >= k:    # early exit: k sums <= limit found
                    return True
                for i, j in enumerate(idx):
                    if j + 1 < m:
                        nxt = idx[:i] + (j + 1,) + idx[i + 1:]
                        if nxt not in visited:
                            visited.add(nxt)
                            heapq.heappush(heap, (s - mat[i][j] + mat[i][j + 1], nxt))
            return False

        lo, hi = base, sum(row[-1] for row in mat)
        res = -1
        while lo <= hi:
            mid = (lo + hi) // 2
            if at_least_k(mid):
                res, hi = mid, mid - 1   # smallest bound with >= k sums below it
            else:
                lo = mid + 1
        return res

题目64:786. 第 K 个最小的质数分数

解题思路:直接模拟。

C++代码如下,

class Solution {
public:
    // LeetCode 786 (brute force): enumerate every fraction arr[i] / arr[j]
    // with i < j, sort by value, and return the k-th smallest as
    // {numerator, denominator}.
    vector<int> kthSmallestPrimeFraction(vector<int>& arr, int k) {
        const int n = arr.size();
        vector<tuple<double, int, int>> fracs;   // (value, numerator, denominator)
        for (int i = 0; i < n; ++i) {
            for (int j = i + 1; j < n; ++j) {
                fracs.emplace_back(1.0 * arr[i] / arr[j], arr[i], arr[j]);
            }
        }
        sort(fracs.begin(), fracs.end());
        const auto& kth = fracs[k - 1];
        return {get<1>(kth), get<2>(kth)};
    }
};

python3代码如下,

class Solution:
    def kthSmallestPrimeFraction(self, arr: List[int], k: int) -> List[int]:
        """Return [numerator, denominator] of the k-th smallest fraction.

        Brute force: every fraction arr[lo] / arr[hi] with lo < hi is built
        as a [value, numerator, denominator] triple, the triples are sorted,
        and the (k - 1)-th one is returned (k is 1-indexed).
        """
        size = len(arr)
        fractions = sorted(
            [arr[lo] / arr[hi], arr[lo], arr[hi]]
            for lo in range(size)
            for hi in range(lo + 1, size)
        )
        _, numerator, denominator = fractions[k - 1]
        return [numerator, denominator]

题目65:3116. 单面值组合的第 K 小金额

解题思路:二分+容斥原理。

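假设有 n 个集合 $A_1, A_2, \dots, A_n$,用 $\cup$ 表示并集、$\cap$ 表示交集,则它们并集的元素个数可由容斥原理计算:

$$\left|\bigcup_{i=1}^{n} A_i\right| = \sum_{i}|A_i| - \sum_{i<j}|A_i \cap A_j| + \sum_{i<j<l}|A_i \cap A_j \cap A_l| - \cdots + (-1)^{n-1}\,|A_1 \cap A_2 \cap \cdots \cap A_n|$$

即:加上单个集合的元素个数之和,减去两两相交部分的元素个数之和,加上三三相交部分的元素个数之和……奇数个集合的交集取正号,偶数个集合的交集取负号。本题中 $A_i$ 为 $[1, \mathrm{mid}]$ 中 coins[i] 的倍数,若干个 $A_i$ 的交集恰好是这些面值最小公倍数的倍数,其个数为 $\lfloor \mathrm{mid} / \mathrm{lcm} \rfloor$。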

C++代码如下,

class Solution {
public:
    long long findKthSmallest(vector<int>& coins, int k) {
        // Binary search on the answer. For a candidate bound, the number of
        // distinct amounts in [1, bound] (multiples of at least one coin) is
        // counted with inclusion-exclusion over every non-empty coin subset.
        const int n = coins.size();
        const long long subsets = 1ll << n;

        auto enoughAmounts = [&](long long bound) -> bool {
            long long total = 0;
            for (long long mask = 1; mask < subsets; ++mask) {
                long long multiple = 1;   // lcm of the coins selected by mask
                int picked = 0;
                bool within = true;
                for (int bit = 0; bit < n && within; ++bit) {
                    if (!((mask >> bit) & 1)) continue;
                    multiple = lcm(multiple, 1ll * coins[bit]);
                    ++picked;
                    within = multiple <= bound;
                }
                if (!within) continue;    // lcm beyond bound: no multiples in range
                // odd-sized subsets are added, even-sized ones subtracted
                total += (picked % 2 == 1) ? bound / multiple : -(bound / multiple);
            }
            return total >= k;
        };

        long long lo = 0;
        long long hi = 1ll * k * ranges::min(coins);
        long long answer = -1;
        while (lo <= hi) {
            const long long mid = lo + (hi - lo) / 2;
            if (enoughAmounts(mid)) {
                answer = mid;   // feasible: look for a smaller one
                hi = mid - 1;
            } else {
                lo = mid + 1;
            }
        }
        return answer;
    }
};

python3代码如下,

class Solution:
    def findKthSmallest(self, coins: List[int], k: int) -> int:
        n = len(coins)
        def check(target: int):
            #容斥原理
            cnt = 0
            for i in range(1,1<<n): #至少选择1个数
                curr = 1 #当前选法的最小公倍数
                for j in range(n):
                    if i >> j & 1:
                        curr = math.lcm(curr, coins[j])
                        if curr > mid:
                            break 
                else:
                    if i.bit_count() % 2 == 1:
                        cnt += mid // curr 
                    else:
                        cnt -= mid // curr 
            return cnt >= k 

        
        left = 0
        right = min(coins) * k 
        res = -1
        while left <= right:
            mid = (left + right) // 2
            if check(mid):
                res = mid 
                right = mid - 1
            else:
                left = mid + 1
        return res 

题目66:3134. 找出唯一性数组的中位数

解题思路:二分+滑动窗口。中位数即唯一性数组中第 k 小的数(k 从 1 开始编号,k = (m+1)/2,m 为子数组总数)。二分答案 mid,用滑动窗口统计"不同元素个数 ≤ mid"的子数组个数是否 ≥ k。

C++代码如下,

// Binary search on the answer + sliding window.
// The uniqueness array lists, for every subarray, its number of distinct
// elements in sorted order; the median taken here is its k-th smallest entry
// with k = (m + 1) / 2 (1-indexed), where m = n(n+1)/2 is the subarray count.
class Solution {
public:
    int medianOfUniquenessArray(vector<int>& nums) {
        int n = nums.size();
        long long m = (1ll + n) * n / 2;  // total number of subarrays
        long long k = (m + 1) / 2;        // rank of the median, 1-indexed

        // check(mid): are there at least k subarrays with <= mid distinct values?
        function<bool(int)> check =[&] (int mid) -> bool {
            // map_val_cnt holds the multiset of nums[i..j] (both ends counted
            // while j < n); j is advanced while the window has <= mid distinct values.
            unordered_map<int,int> map_val_cnt;
            map_val_cnt[nums[0]] += 1;
            int curr = 1; // number of distinct values currently in map_val_cnt
            long long res = 0;
            int i = 0;
            int j = 0; 
            while (i < n) {
                while (j < n && curr <= mid) {
                    j += 1;
                    if (j < n) {
                        map_val_cnt[nums[j]] += 1;
                        if (map_val_cnt[nums[j]] == 1) {
                            curr += 1;
                        }
                    }
                }
                // Subarrays starting at i that are counted: nums[i..i] .. nums[i..j-1].
                res += j - i;
                if (res >= k) { // early exit once k subarrays have been counted
                    return true;
                }
                // Slide the left edge: drop nums[i] from the window.
                map_val_cnt[nums[i]] -= 1;
                if (map_val_cnt[nums[i]] == 0) {
                    curr -= 1;
                }
                i += 1;
            }
            return false;
        };

        // Smallest mid in [0, n] for which check(mid) holds.
        int left = 0;
        int right = n;
        int res = -1;
        while (left <= right) {
            int mid = (left + right) / 2;
            if (check(mid)) {
                res = mid;
                right = mid - 1;// feasible: look for a smaller value
            } else {
                left = mid + 1;
            }
        } 
        return res;
    }
};

python3代码如下,

class Solution:
    def medianOfUniquenessArray(self, nums: List[int]) -> int:
        """Return the median of the uniqueness array of ``nums``.

        The uniqueness array holds, for every subarray, its number of distinct
        elements in sorted order; the median is its k-th smallest entry with
        ``k = (total + 1) // 2`` (1-indexed, the smaller middle when there are
        two). We binary search the smallest ``mid`` such that at least k
        subarrays contain at most ``mid`` distinct values.
        """
        n = len(nums)
        total = (n + 1) * n // 2  # number of non-empty subarrays
        k = (total + 1) // 2      # rank of the median, 1-indexed

        def check(mid: int) -> bool:
            # True iff at least k subarrays contain <= mid distinct values.
            # Sliding window [left, right] kept with <= mid distinct values;
            # every subarray ending at `right` and starting in it qualifies.
            cnt = 0
            freq = collections.defaultdict(int)
            distinct = 0
            left = 0
            for right, x in enumerate(nums):
                freq[x] += 1
                if freq[x] == 1:
                    distinct += 1
                while distinct > mid:
                    y = nums[left]
                    freq[y] -= 1
                    if freq[y] == 0:
                        distinct -= 1
                    left += 1
                cnt += right - left + 1
                if cnt >= k:
                    return True  # early exit, as in the C++ version
            return False

        lo, hi = 1, n
        res = -1
        while lo <= hi:
            mid = (lo + hi) // 2
            if check(mid):
                res = mid
                hi = mid - 1  # feasible: search for a smaller value
            else:
                lo = mid + 1
        return res

题目67:2040. 两个有序数组的第 K 小乘积

解题思路:先按符号统计乘积为负数、零、正数的个数,确定第 k 小的乘积落在哪一部分,再在该部分中对乘积的绝对值进行二分。

C++代码如下,

// Split both sorted arrays by sign, count the negative / zero / positive
// products in closed form, then binary search on absolute value inside the
// group that contains the k-th smallest product.
class Solution {
public:
    // Returns the k-th smallest |product| among the negative products.
    // Negative products come from pos1 x neg2 and neg1 x pos2; neg2 is
    // expected to already hold absolute values sorted ascending (that is how
    // kthSmallestProduct prepares it), while neg1 still holds negative values.
    long long func1(vector<int>& pos1, vector<int>& pos2, vector<int>& neg1, vector<int>& neg2, long long k) {
        function<bool(long long)> check =[&] (long long mid) -> bool {
            // Count pairs with |product| <= mid; true as soon as the count reaches k.
            long long cnt = 0;
            for (int x : pos1) {
                long long y = mid / x;  // partner's absolute value must be <= mid / x
                auto iter = upper_bound(neg2.begin(), neg2.end(), y);
                cnt += distance(neg2.begin(), iter);
                if (cnt >= k) {
                    return true; // early exit
                }
            }

            for (int x : neg1) {
                long long y = mid / (-x);  // -x is |x|
                auto iter = upper_bound(pos2.begin(), pos2.end(), y);
                cnt += distance(pos2.begin(), iter);
                if (cnt >= k) {
                    return true; // early exit
                }
            }
            return false;
        };

        // Smallest mid in [0, 1e10] with check(mid) true.
        // NOTE(review): 1e10 assumes |nums[i]| <= 1e5 so every |product| fits — confirm.
        long long left = 0;
        long long right = 1e10;
        long long res = -1;
        while (left <= right) {
            long long mid = (left + right) / 2;
            if (check(mid)) {
                res = mid;
                right = mid - 1; 
            } else {
                left = mid + 1;
            }
        }
        return res;
    }

    // Returns the k-th smallest positive product. Positive products come from
    // pos1 x pos2 and neg1 x neg2 (neg2 holds absolute values, ascending).
    long long func2(vector<int>& pos1, vector<int>& pos2, vector<int>& neg1, vector<int>& neg2, long long k) {
        function<bool(long long)> check =[&] (long long mid) -> bool {
            // Count pairs with product <= mid; true as soon as the count reaches k.
            long long cnt = 0;
            for (int x : pos1) {
                long long y = mid / x;
                auto iter = upper_bound(pos2.begin(), pos2.end(), y);
                cnt += distance(pos2.begin(), iter);
                if (cnt >= k) {
                    return true; // early exit
                }
            }

            for (int x : neg1) {
                long long y = mid / (-x);  // -x is |x|
                auto iter = upper_bound(neg2.begin(), neg2.end(), y);
                cnt += distance(neg2.begin(), iter);
                if (cnt >= k) {
                    return true; // early exit
                }
            }
            return false;
        };

        // Smallest mid in [0, 1e10] with check(mid) true (same bound as func1).
        long long left = 0;
        long long right = 1e10;
        long long res = -1;
        while (left <= right) {
            long long mid = (left + right) / 2;
            if (check(mid)) {
                res = mid;
                right = mid - 1; 
            } else {
                left = mid + 1;
            }
        }
        return res;
    }

    // Returns the k-th smallest (1-indexed) product nums1[i] * nums2[j].
    long long kthSmallestProduct(vector<int>& nums1, vector<int>& nums2, long long k) {
        int n = nums1.size(), m = nums2.size();

        auto iter = lower_bound(nums1.begin(), nums1.end(), 0);
        int i0 = distance(nums1.begin(), iter); // number of negatives in nums1
        iter = upper_bound(nums1.begin(), nums1.end(), 0);
        int i1 = distance(nums1.begin(), iter); // number of non-positives in nums1
        
        iter = lower_bound(nums2.begin(), nums2.end(), 0);
        int j0 = distance(nums2.begin(), iter); // number of negatives in nums2
        iter = upper_bound(nums2.begin(), nums2.end(), 0);
        int j1 = distance(nums2.begin(), iter); // number of non-positives in nums2

        vector<int> neg1, neg2, pos1, pos2;
        neg1.insert(neg1.end(), nums1.begin(), nums1.begin()+i0);
        pos1.insert(pos1.end(), nums1.begin()+i1, nums1.end());
        neg2.insert(neg2.end(), nums2.begin(), nums2.begin()+j0);
        pos2.insert(pos2.end(), nums2.begin()+j1, nums2.end());

        // Turn neg2 into ascending absolute values so upper_bound can search it.
        for (int& x : neg2) {
            x = -x;
        }
        reverse(neg2.begin(), neg2.end());

        long long cnt1 = 1ll * i0 * (m -j1) + 1ll * j0 * (n-i1); // number of negative products
        long long cnt2 = 1ll * i0 * j0 + 1ll * (n-i1) * (m-j1); // number of positive products
        long long cnt3 = 1ll * n * m - cnt1 - cnt2; // number of zero products

        if (k <= cnt1) {
            // The k-th smallest negative product is the one whose absolute
            // value ranks (cnt1 - k + 1)-th from the smallest.
            k = cnt1 - k + 1;
            return -func1(pos1, pos2, neg1, neg2, k);
        } else if (k <= cnt1 + cnt3) {
            return 0;
        } else {
            k -= cnt1 + cnt3;
            // k-th smallest among the positive products
            return func2(pos1, pos2, neg1, neg2, k);
        }
    }
};

python3代码如下,

#self
class Solution:
    @staticmethod
    def _kth_abs_product(pairs, k):
        """Return the k-th smallest value of |x| * y over all pairs from ``pairs``.

        ``pairs`` is a sequence of ``(xs, ys)``: ``xs`` holds non-zero ints and
        ``ys`` holds positive ints sorted ascending. Binary search on
        [0, 10**10]; the upper bound assumes |values| <= 1e5 (LeetCode 2040
        constraints) — presumably; confirm if reused elsewhere.
        """
        def check(mid: int) -> bool:
            # At least k pairs with |x| * y <= mid?  (y <= mid // |x|)
            cnt = 0
            for xs, ys in pairs:
                for x in xs:
                    cnt += bisect.bisect_right(ys, mid // abs(x))
                    if cnt >= k:
                        return True  # early exit
            return False

        left, right = 0, 10**10
        while left < right:
            mid = (left + right) // 2
            if check(mid):
                right = mid
            else:
                left = mid + 1
        return left

    def func1(self, pos1, pos2, neg1, neg2, k):
        """Return the k-th smallest |product| among the negative products.

        Negative products are pos1 x neg2 and neg1 x pos2; ``neg2`` must hold
        absolute values sorted ascending, ``neg1`` may keep its negative values.
        """
        return self._kth_abs_product(((pos1, neg2), (neg1, pos2)), k)

    def func2(self, pos1, pos2, neg1, neg2, k):
        """Return the k-th smallest positive product (pos1 x pos2, neg1 x neg2)."""
        return self._kth_abs_product(((pos1, pos2), (neg1, neg2)), k)

    def kthSmallestProduct(self, nums1: List[int], nums2: List[int], K: int) -> int:
        """Return the K-th smallest (1-indexed) product nums1[i] * nums2[j].

        Both inputs are sorted ascending. Products are grouped by sign; the
        group holding the answer is chosen from K and searched by |value|.
        """
        i0 = bisect.bisect_left(nums1, 0)   # number of negatives in nums1
        i1 = bisect.bisect_right(nums1, 0)  # number of non-positives in nums1
        j0 = bisect.bisect_left(nums2, 0)   # number of negatives in nums2
        j1 = bisect.bisect_right(nums2, 0)  # number of non-positives in nums2
        n, m = len(nums1), len(nums2)

        cnt1 = i0 * (m - j1) + j0 * (n - i1)  # negative products
        cnt2 = i0 * j0 + (n - i1) * (m - j1)  # positive products
        cnt3 = n * m - cnt1 - cnt2            # zero products

        neg1, pos1, pos2 = nums1[:i0], nums1[i1:], nums2[j1:]
        # Absolute values of nums2's negatives, ascending, so bisect works on them.
        neg2 = [-x for x in reversed(nums2[:j0])]

        k = K
        if k <= cnt1:
            # The k-th smallest negative product has the (cnt1 - k + 1)-th
            # smallest absolute value.
            return -self.func1(pos1, pos2, neg1, neg2, cnt1 - k + 1)
        if k <= cnt1 + cnt3:
            return 0
        return self.func2(pos1, pos2, neg1, neg2, k - cnt1 - cnt3)

#参考版本
class Solution:
    def kthSmallestProduct(self, nums1: List[int], nums2: List[int], K: int) -> int:
        """Return the K-th smallest (1-indexed) product nums1[i] * nums2[j].

        Both inputs are sorted. The numbers of negative, zero and positive
        products follow from the sign counts of each array; K selects the
        group, and a binary search over absolute values locates the answer
        inside that group.
        """
        size1, size2 = len(nums1), len(nums2)
        # neg_cnt*: how many negatives; pos_from*: index of the first positive
        neg_cnt1, pos_from1 = bisect_left(nums1, 0), bisect_left(nums1, 1)
        neg_cnt2, pos_from2 = bisect_left(nums2, 0), bisect_left(nums2, 1)

        pos_cnt1, pos_cnt2 = size1 - pos_from1, size2 - pos_from2
        negatives = neg_cnt1 * pos_cnt2 + neg_cnt2 * pos_cnt1
        positives = neg_cnt1 * neg_cnt2 + pos_cnt1 * pos_cnt2
        zeros = size1 * size2 - negatives - positives

        pos1, pos2 = nums1[pos_from1:], nums2[pos_from2:]
        neg1 = nums1[:neg_cnt1]
        # |x| for the negatives of nums2, ascending, so bisect can search them
        abs_neg2 = [-x for x in reversed(nums2[:neg_cnt2])]

        def smallest_bound(groups, rank):
            # Smallest t in [0, 10**10] such that at least `rank` pairs (x, y)
            # drawn from `groups` satisfy |x| * y <= t.
            def covers(t):
                total = sum(
                    bisect_right(ys, t // abs(x)) for xs, ys in groups for x in xs
                )
                return total >= rank

            lo, hi = 0, 10**10
            while lo < hi:
                mid = (lo + hi) // 2
                if covers(mid):
                    hi = mid
                else:
                    lo = mid + 1
            return lo

        if K <= negatives:
            # The K-th smallest negative product is the one whose absolute
            # value ranks (negatives - K + 1)-th from the smallest.
            return -smallest_bound(((pos1, abs_neg2), (neg1, pos2)), negatives - K + 1)
        if K <= negatives + zeros:
            return 0
        return smallest_bound(((pos1, pos2), (neg1, abs_neg2)), K - negatives - zeros)

题目68:2386. 找出数组的第 K 大和

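解题思路:思维转换+二分。最大子序列和 s 为所有正数之和;把所有数取绝对值后,任一子序列和都等于 s 减去 |nums| 的某个子集和,因此第 k 大的子序列和 = s − (|nums| 中第 k 小的子集和),后者通过二分答案求得。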

C++代码如下,

class Solution {
public:
    long long kSum(vector<int>& nums, int k) {
        // The largest subsequence sum takes every positive element. Every
        // other subsequence sum equals that maximum minus some subset sum of
        // |nums[i]|, so the k-th largest sum is maxSum minus the k-th
        // smallest subset sum of the absolute values.
        // Note: nums is rewritten in place (absolute values, sorted).
        long long maxSum = 0;
        for (int& v : nums) {
            if (v > 0) {
                maxSum += v;
            } else {
                v = -v;
            }
        }
        sort(nums.begin(), nums.end());

        // enough(limit): do at least k subsets (the empty one included) of
        // nums have sum <= limit? Explored with an explicit stack; since nums
        // is sorted, once sum + nums[idx] > limit no later element fits either.
        auto enough = [&](long long limit) -> bool {
            long long found = 1;  // the empty subset
            vector<pair<int, long long>> todo;
            todo.emplace_back(0, 0LL);
            while (!todo.empty() && found < k) {
                auto [idx, sum] = todo.back();
                todo.pop_back();
                if (idx == static_cast<int>(nums.size()) || sum + nums[idx] > limit) {
                    continue;
                }
                ++found;                                    // subset sum + nums[idx]
                todo.emplace_back(idx + 1, sum);            // skip nums[idx]
                todo.emplace_back(idx + 1, sum + nums[idx]); // take nums[idx]
            }
            return found == k;
        };

        long long lo = 0;
        long long hi = 0;
        for (int v : nums) hi += v;
        long long best = -1;
        while (lo <= hi) {
            const long long mid = lo + (hi - lo) / 2;
            if (enough(mid)) {
                best = mid;
                hi = mid - 1;
            } else {
                lo = mid + 1;
            }
        }
        return maxSum - best;
    }
};

python3代码如下,

class Solution:
    def kSum(self, nums: List[int], k: int) -> int:
        """Return the k-th largest subsequence sum of ``nums`` (1-indexed).

        The largest sum takes every positive element. Any other subsequence
        sum equals that maximum minus a subset sum of ``|nums[i]|``, so the
        answer is ``max_sum - (k-th smallest subset sum of |nums|)``; the
        latter is found by binary search on its value. ``nums`` is not
        modified.
        """
        max_sum = sum(x for x in nums if x > 0)
        # Sorted absolute values, built as a copy so the caller's list is untouched.
        vals = sorted(abs(x) for x in nums)

        def check(mid: int, k: int) -> bool:
            # True iff at least k subsets of vals (empty included) sum to <= mid.
            # Explicit stack instead of recursion: the search can go ~k levels
            # deep (e.g. many zeros), past Python's default recursion limit.
            cnt = 1  # the empty subset
            stack = [(0, 0)]
            while stack and cnt < k:
                i, curr = stack.pop()
                if i == len(vals) or curr + vals[i] > mid:
                    continue  # vals is sorted: no later element fits either
                cnt += 1      # the subset with sum curr + vals[i]
                stack.append((i + 1, curr))            # skip vals[i]
                stack.append((i + 1, curr + vals[i]))  # take vals[i]
            return cnt >= k

        left = 0
        right = sum(vals)
        res = -1
        while left <= right:
            mid = (left + right) // 2
            if check(mid, k):
                res = mid
                right = mid - 1
            else:
                left = mid + 1

        return max_sum - res

题目69:1508. 子数组和排序后的区间和

解题思路:数据规模较小,直接枚举所有子数组和并排序,再对排序后下标 [left, right](从 1 开始)的元素求和并取模。

C++代码如下,

class Solution {
public:
    int rangeSum(vector<int>& nums, int n, int left, int right) {
        // Enumerate all n*(n+1)/2 subarray sums, sort them, and add up the
        // entries at 1-indexed positions left..right modulo 1e9+7.
        constexpr int MOD = 1000000007;
        vector<int> sums;
        sums.reserve(n * (n + 1) / 2);
        for (int start = 0; start < n; ++start) {
            int running = 0;
            for (int end = start; end < n; ++end) {
                running += nums[end];
                sums.push_back(running);
            }
        }
        sort(sums.begin(), sums.end());

        int total = 0;
        for (int pos = left; pos <= right; ++pos) {
            total = (total + sums[pos - 1] % MOD) % MOD;
        }
        return total;
    }
};

python3代码如下,

class Solution:
    def rangeSum(self, nums: List[int], n: int, left: int, right: int) -> int:
        """Sum the sorted subarray sums at 1-indexed positions left..right, mod 1e9+7.

        All subarray sums are enumerated (O(n^2)), sorted, and the requested
        slice is summed.
        """
        mod = int(1e9 + 7)
        sub_sums = []
        for start in range(len(nums)):
            running = 0
            for value in nums[start:]:
                running += value
                sub_sums.append(running)
        sub_sums.sort()
        return sum(sub_sums[left - 1:right]) % mod

2.8 其它

题目70:2476. 二叉搜索树最近节点查询

解题思路:二分。

C++代码如下,

/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode() : val(0), left(nullptr), right(nullptr) {}
 *     TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
 *     TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
 * };
 */
class Solution {
public:
    vector<vector<int>> closestNodes(TreeNode* root, vector<int>& queries) {
        vector<int> nums;
        function<void(TreeNode*)> dfs = [&] (TreeNode* node) -> void {
            if (node == nullptr) {
                return;
            }
            dfs(node->left);
            nums.push_back(node->val);
            dfs(node->right);
            return;
        };
        dfs(root);
        vector<vector<int>> res;
        for (auto query : queries) {
            vector<int> t = {-1,-1};
            auto iter = lower_bound(nums.begin(), nums.end(), query+1);
            int idx1 = distance(nums.begin(), iter);
            idx1 -= 1;
            if (0 <= idx1 && idx1 < nums.size()) {
                t[0] = nums[idx1];
            }
            iter = lower_bound(nums.begin(), nums.end(), query);
            int idx2 = distance(nums.begin(), iter);
            if (0 <= idx2 && idx2 < nums.size()) {
                t[1] = nums[idx2];
            }
            res.push_back(t);
        }
        return res;
    }
};

python3代码如下,

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def closestNodes(self, root: Optional[TreeNode], queries: List[int]) -> List[List[int]]:
        """For each query q return [largest value <= q, smallest value >= q].

        A missing value is reported as -1. The BST is flattened into an
        ascending list with an in-order traversal, then each query is answered
        with two binary searches.
        """
        # Iterative in-order traversal: a degenerate (chain-shaped) BST can be
        # as deep as it has nodes, far beyond Python's default recursion limit.
        nums = []
        stack = []
        node = root
        while node is not None or stack:
            while node is not None:
                stack.append(node)
                node = node.left
            node = stack.pop()
            nums.append(node.val)
            node = node.right

        res = []
        for query in queries:
            idx1 = bisect.bisect_right(nums, query) - 1  # last value <= query
            floor_val = nums[idx1] if idx1 >= 0 else -1
            idx2 = bisect.bisect_left(nums, query)       # first value >= query
            ceil_val = nums[idx2] if idx2 < len(nums) else -1
            res.append([floor_val, ceil_val])
        return res

题目71:74. 搜索二维矩阵

解题思路:把矩阵按行展开看作一个长度为 n*m 的有序数组,二分查找第一个 ≥ target 的位置(模拟 bisect_left),再判断该位置的值是否等于 target。

C++代码如下,

class Solution {
public:
    bool searchMatrix(vector<vector<int>>& matrix, int target) {
        // View the row-major matrix as one ascending array of rows*cols cells
        // and find the first cell whose value is >= target (a lower bound).
        const int rows = matrix.size();
        const int cols = matrix[0].size();
        const int cells = rows * cols;
        int lo = 0, hi = cells;   // half-open search range [lo, hi)
        while (lo < hi) {
            const int mid = lo + (hi - lo) / 2;
            if (matrix[mid / cols][mid % cols] >= target) {
                hi = mid;
            } else {
                lo = mid + 1;
            }
        }
        return lo < cells && matrix[lo / cols][lo % cols] == target;
    }
};

python3代码如下,

class Solution:
    def searchMatrix(self, matrix: List[List[int]], target: int) -> bool:
        """Report whether ``target`` occurs in a row-major sorted matrix.

        The matrix is viewed as one flat ascending array of rows * cols
        cells; a hand-written lower bound finds the first cell >= target.
        """
        cols = len(matrix[0])
        cells = len(matrix) * cols

        def value_at(flat: int) -> int:
            row, col = divmod(flat, cols)
            return matrix[row][col]

        lo, hi = 0, cells  # half-open search range [lo, hi)
        while lo < hi:
            mid = (lo + hi) // 2
            if value_at(mid) >= target:
                hi = mid
            else:
                lo = mid + 1
        return lo < cells and value_at(lo) == target

题目72:240. 搜索二维矩阵 II

解题思路:从左下角出发的阶梯搜索(并非二分):当前值小于 target 时右移一列,大于 target 时上移一行,时间复杂度 O(n+m)。

C++代码如下,

class Solution {
public:
    bool searchMatrix(vector<vector<int>>& matrix, int target) {
        // Staircase walk from the bottom-left corner: a smaller cell rules
        // out its column (everything above is smaller), a larger cell rules
        // out its row (everything to the right is larger).
        int row = static_cast<int>(matrix.size()) - 1;
        int col = 0;
        const int cols = matrix[0].size();
        while (row >= 0 && col < cols) {
            const int cell = matrix[row][col];
            if (cell == target) return true;
            if (cell < target) {
                ++col;
            } else {
                --row;
            }
        }
        return false;
    }
};

python3代码如下,

class Solution:
    def searchMatrix(self, matrix: List[List[int]], target: int) -> bool:
        """Report whether ``target`` occurs in a row- and column-sorted matrix.

        Staircase walk from the bottom-left corner, O(rows + cols).
        """
        row, col = len(matrix) - 1, 0
        width = len(matrix[0])
        while row >= 0 and col < width:
            value = matrix[row][col]
            if value == target:
                return True
            if value < target:
                col += 1  # cells above in this column are even smaller
            else:
                row -= 1  # cells to the right in this row are even larger
        return False

题目73:278. 第一个错误的版本

解题思路:二分。

C++代码如下,

// The API isBadVersion is defined for you.
// bool isBadVersion(int version);

class Solution {
public:
    int firstBadVersion(int n) {
        // Versions are good up to some point and bad from then on; binary
        // search [1, n] for the first bad one (-1 if none is bad).
        int lo = 1, hi = n;
        int firstBad = -1;
        while (lo <= hi) {
            const int mid = lo + (hi - lo) / 2;   // no int overflow near INT_MAX
            if (!isBadVersion(mid)) {
                lo = mid + 1;
                continue;
            }
            firstBad = mid;
            hi = mid - 1;
        }
        return firstBad;
    }
};

python3代码如下,

# The isBadVersion API is already defined for you.
# def isBadVersion(version: int) -> bool:

class Solution:
    def firstBadVersion(self, n: int) -> int:
        """Return the first bad version in 1..n, or -1 if none is bad.

        Relies on the judge-provided ``isBadVersion`` predicate, which is
        False for a prefix of versions and True afterwards.
        """
        lo, hi, first_bad = 1, n, -1
        while lo <= hi:
            mid = (lo + hi) // 2
            if not isBadVersion(mid):
                lo = mid + 1
                continue
            first_bad, hi = mid, mid - 1
        return first_bad
        

题目74:374. 猜数字大小

解题思路:二分。

C++代码如下,

/** 
 * Forward declaration of guess API.
 * @param  num   your guess
 * @return 	     -1 if num is higher than the picked number
 *			      1 if num is lower than the picked number
 *               otherwise return 0
 * int guess(int num);
 */

class Solution {
public:
    int guessNumber(int n) {
        // Binary search on [1, n] driven by the guess() API:
        // 0 = hit, -1 = pick is lower than mid, 1 = pick is higher.
        int lo = 1, hi = n;
        while (lo <= hi) {
            const int mid = (1ll * lo + hi) / 2;
            switch (guess(mid)) {
                case 0:
                    return mid;
                case -1:
                    hi = mid - 1;
                    break;
                default:
                    lo = mid + 1;
                    break;
            }
        }
        return -100;
    }
};

python3代码如下,

# The guess API is already defined for you.
# @param num, your guess
# @return -1 if num is higher than the picked number
#          1 if num is lower than the picked number
#          otherwise return 0
# def guess(num: int) -> int:

class Solution:
    def guessNumber(self, n: int) -> int:
        """Find the picked number in [1, n] with the judge's ``guess`` API.

        ``guess`` returns 0 on a hit, -1 if the pick is lower than the
        guess, 1 if it is higher. Returns -100 if the range is exhausted.
        """
        lo, hi = 1, n
        while lo <= hi:
            mid = (lo + hi) // 2
            verdict = guess(mid)
            if verdict == 0:
                return mid
            if verdict == -1:
                hi = mid - 1  # pick is below mid
            else:
                lo = mid + 1  # pick is above mid
        return -100

题目75:162. 寻找峰值

解题思路:二分。

C++代码如下,

class Solution {
public:
    int findPeakElement(vector<int>& nums) {
        // rises(i): nums[i] exceeds its left neighbour (index 0 always rises).
        // The last rising index found by the search is followed by a
        // non-rising one (or the end), so it is a peak.
        auto rises = [&](int i) { return i == 0 || nums[i - 1] < nums[i]; };

        int lo = 0, hi = static_cast<int>(nums.size()) - 1;
        int last = -1;
        while (lo <= hi) {
            const int mid = lo + (hi - lo) / 2;
            if (rises(mid)) {
                last = mid;
                lo = mid + 1;
            } else {
                hi = mid - 1;
            }
        }
        return last;
    }
};

python3代码如下,

class Solution:
    def findPeakElement(self, nums: List[int]) -> int:
        """Return the index of a peak element via binary search.

        An index "rises" when its value exceeds its left neighbour (index 0
        always rises). The search keeps the last rising index it finds; its
        right neighbour does not rise (or does not exist), so it is a peak.
        """
        def rises(idx: int) -> bool:
            return idx == 0 or nums[idx - 1] < nums[idx]

        lo, hi, last = 0, len(nums) - 1, -1
        while lo <= hi:
            mid = (lo + hi) // 2
            if not rises(mid):
                hi = mid - 1
                continue
            last, lo = mid, mid + 1
        return last

题目76:1901. 寻找峰值 II

解题思路:转换思维,对行号二分。设第 i 行最大值位于第 mj 列,若 i 是首行或 mat[i-1][mj] <= mat[i][mj],则称第 i 行满足条件;二分找到满足条件的最大行号,该行的最大值即为一个峰值。

C++代码如下,

class Solution {
public:
    vector<int> findPeakGrid(vector<vector<int>>& mat) {
        // Column of the first maximum in row r.
        auto argmaxCol = [&](int r) {
            return static_cast<int>(max_element(mat[r].begin(), mat[r].end()) - mat[r].begin());
        };

        // Binary search over rows: row r qualifies when it is the first row
        // or the cell above its maximum is not larger. The last qualifying
        // row holds a peak at its row maximum.
        int lo = 0, hi = static_cast<int>(mat.size()) - 1;
        int row = -1;
        while (lo <= hi) {
            const int mid = (lo + hi) / 2;
            const int c = argmaxCol(mid);
            if (mid == 0 || mat[mid - 1][c] <= mat[mid][c]) {
                row = mid;
                lo = mid + 1;
            } else {
                hi = mid - 1;
            }
        }
        return {row, argmaxCol(row)};
    }
};

python3代码如下,

class Solution:
    def findPeakGrid(self, mat: List[List[int]]) -> List[int]:
        """Return [row, col] of a peak cell via binary search over rows.

        With c the column of row r's (first) maximum, row r qualifies when it
        is the first row or mat[r-1][c] <= mat[r][c]. The last qualifying row
        found holds a peak at its row maximum.
        """
        def argmax(r: int) -> int:
            values = mat[r]
            return values.index(max(values))

        lo, hi, best = 0, len(mat) - 1, -1
        while lo <= hi:
            mid = (lo + hi) // 2
            c = argmax(mid)
            if mid == 0 or mat[mid - 1][c] <= mat[mid][c]:
                best, lo = mid, mid + 1
            else:
                hi = mid - 1
        return [best, argmax(best)]

题目77:852. 山脉数组的峰顶索引

解题思路:二分。

C++代码如下,

class Solution {
public:
    int peakIndexInMountainArray(vector<int>& arr) {
        // Index 0 and every index on the ascending slope satisfy
        // arr[i-1] <= arr[i]; the last such index is the peak.
        int lo = 0, hi = static_cast<int>(arr.size()) - 1;
        int peak = -1;
        while (lo <= hi) {
            const int mid = (lo + hi) / 2;
            const bool ascending = mid == 0 || arr[mid - 1] <= arr[mid];
            if (ascending) {
                peak = mid;
                lo = mid + 1;
            } else {
                hi = mid - 1;
            }
        }
        return peak;
    }
};

python3代码如下,

class Solution:
    def peakIndexInMountainArray(self, arr: List[int]) -> int:
        """Return the peak index of a mountain array via binary search.

        Index 0 and every index on the ascending slope satisfy
        arr[i-1] < arr[i]; the last such index is the peak.
        """
        lo, hi, peak = 0, len(arr) - 1, -1
        while lo <= hi:
            mid = (lo + hi) // 2
            if mid > 0 and arr[mid - 1] >= arr[mid]:
                hi = mid - 1  # already on the descending slope
            else:
                peak, lo = mid, mid + 1
        return peak

题目78:1095. 山脉数组中查找目标值

解题思路:二分。

C++代码如下,

/**
 * // This is the MountainArray's API interface.
 * // You should not implement it, or speculate about its implementation
 * class MountainArray {
 *   public:
 *     int get(int index);
 *     int length();
 * };
 */

class Solution {
public:
    // Returns the smallest index i with mountainArr.get(i) == target, or -1.
    // Step 1 finds the peak, step 2 searches the ascending part [0, peak],
    // step 3 searches the descending part [peak, n-1].
    int findInMountainArray(int target, MountainArray &mountainArr) {
        // Step 1: peak = last index i with i == 0 or get(i-1) < get(i).
        int peak = -1;
        int lo = 0;
        int hi = mountainArr.length() - 1;
        while (lo <= hi) {
            int mid = (lo + hi) / 2;
            bool rising = mid == 0 || mountainArr.get(mid - 1) < mountainArr.get(mid);
            if (rising) {
                peak = mid;
                lo = mid + 1;
            } else {
                hi = mid - 1;
            }
        }

        // Step 2: ascending part -> last index whose value is <= target.
        int up = -1;
        lo = 0;
        hi = peak;
        while (lo <= hi) {
            int mid = (lo + hi) / 2;
            if (mountainArr.get(mid) > target) {
                hi = mid - 1;
            } else {
                up = mid;
                lo = mid + 1;
            }
        }
        if (up >= 0 && mountainArr.get(up) == target) {
            return up;
        }

        // Step 3: descending part -> first index whose value is <= target.
        int down = -1;
        lo = peak;
        hi = mountainArr.length() - 1;
        while (lo <= hi) {
            int mid = (lo + hi) / 2;
            if (mountainArr.get(mid) > target) {
                lo = mid + 1;
            } else {
                down = mid;
                hi = mid - 1;
            }
        }
        if (down >= 0 && mountainArr.get(down) == target) {
            return down;
        }
        return -1;
    }
};

python3代码如下,

# """
# This is MountainArray's API interface.
# You should not implement it, or speculate about its implementation
# """
#class MountainArray:
#    def get(self, index: int) -> int:
#    def length(self) -> int:

class Solution:
    def findInMountainArray(self, target: int, mountain_arr: 'MountainArray') -> int:
        """Return the smallest index i with ``mountain_arr.get(i) == target``, else -1.

        Three binary searches: locate the peak, then search the ascending part
        ``[0, peak]``, then the descending part ``[peak, n - 1]``.
        """
        get = mountain_arr.get

        # 1) Peak: last index i with i == 0 or get(i - 1) <= get(i).
        peak = -1
        lo, hi = 0, mountain_arr.length() - 1
        while lo <= hi:
            mid = (lo + hi) // 2
            if mid == 0 or get(mid - 1) <= get(mid):
                peak, lo = mid, mid + 1
            else:
                hi = mid - 1

        # 2) Ascending part: last index in [0, peak] whose value is <= target.
        found = -1
        lo, hi = 0, peak
        while lo <= hi:
            mid = (lo + hi) // 2
            if get(mid) > target:
                hi = mid - 1
            else:
                found, lo = mid, mid + 1
        if found != -1 and get(found) == target:
            return found

        # 3) Descending part: first index in [peak, n - 1] whose value is <= target.
        found = -1
        lo, hi = peak, mountain_arr.length() - 1
        while lo <= hi:
            mid = (lo + hi) // 2
            if get(mid) > target:
                lo = mid + 1
            else:
                found, hi = mid, mid - 1
        return found if found != -1 and get(found) == target else -1

题目79:153. 寻找旋转排序数组中的最小值

解题思路:二分。先二分找到转折点,即第一段升序区间的最后一个下标 idx(该段元素都 >= nums[0]),最小值就是 nums[(idx+1) % n];未旋转时 idx = n-1,取模后回到 nums[0]。

C++代码如下,

class Solution {
public:
    // Returns the minimum of a rotated ascending array. The binary search finds
    // the last index of the leading ascending run (elements >= nums[0]); the
    // minimum is the next element, wrapping to index 0 if there is no rotation.
    int findMin(vector<int>& nums) {
        const int n = nums.size();
        auto inFirstRun = [&nums](int i) {
            return (i == 0 || nums[i - 1] <= nums[i]) && nums[0] <= nums[i];
        };

        int lastOfFirstRun = -1;
        for (int lo = 0, hi = n - 1; lo <= hi; ) {
            const int mid = (lo + hi) / 2;
            if (inFirstRun(mid)) {
                lastOfFirstRun = mid;
                lo = mid + 1;
            } else {
                hi = mid - 1;
            }
        }
        return nums[(lastOfFirstRun + 1) % n];
    }
};

python3代码如下,

class Solution:
    def findMin(self, nums: List[int]) -> int:
        """Return the minimum of rotated ascending list *nums*.

        Binary-search the last index of the leading ascending run (elements
        that are >= nums[0]). The minimum is the element after it, wrapping
        to index 0 when the list is not rotated.
        """
        size = len(nums)

        def in_first_run(i: int) -> bool:
            return (i == 0 or nums[i - 1] <= nums[i]) and nums[0] <= nums[i]

        lo, hi, last = 0, size - 1, -1
        while lo <= hi:
            mid = (lo + hi) // 2
            if in_first_run(mid):
                last, lo = mid, mid + 1
            else:
                hi = mid - 1
        return nums[(last + 1) % size]

题目80:33. 搜索旋转排序数组

解题思路:二分。先二分找到转折点 idx(第一段升序区间的最后一个下标),再分别在有序区间 [0, idx] 和 [idx+1, n-1] 中二分查找 target。

C++代码如下,

class Solution {
public:
    // Searches target in a rotated ascending array; returns its index or -1.
    int search(vector<int>& nums, int target) {
        const int n = nums.size();

        // Last index of the leading ascending run (elements >= nums[0]).
        int pivot = -1;
        for (int lo = 0, hi = n - 1; lo <= hi; ) {
            int mid = (1ll * lo + hi) / 2;
            if ((mid == 0 || nums[mid - 1] <= nums[mid]) && nums[0] <= nums[mid]) {
                pivot = mid;
                lo = mid + 1;
            } else {
                hi = mid - 1;
            }
        }

        // In the sorted slice nums[lo..hi]: find the last index whose value is
        // <= target, and report it only if it actually equals target.
        auto findIn = [&](int lo, int hi) -> int {
            int cand = -1;
            while (lo <= hi) {
                int mid = (1ll * lo + hi) / 2;
                if (nums[mid] <= target) {
                    cand = mid;
                    lo = mid + 1;
                } else {
                    hi = mid - 1;
                }
            }
            return (cand != -1 && nums[cand] == target) ? cand : -1;
        };

        const int hit = findIn(0, pivot);
        if (hit != -1) {
            return hit;
        }
        return findIn(pivot + 1, n - 1);
    }
};

python3代码如下,

class Solution:
    def search(self, nums: List[int], target: int) -> int:
        """Return the index of *target* in rotated ascending list *nums*, or -1.

        First binary-search the pivot (last index of the leading ascending
        run). Then binary-search each of the two sorted halves.
        """
        n = len(nums)

        def is_head(i: int) -> bool:
            # True while i lies in the leading ascending run (values >= nums[0]).
            return (i == 0 or nums[i - 1] < nums[i]) and nums[0] <= nums[i]

        pivot, lo, hi = -1, 0, n - 1
        while lo <= hi:
            mid = (lo + hi) // 2
            if is_head(mid):
                pivot, lo = mid, mid + 1
            else:
                hi = mid - 1

        def find_sorted(lo: int, hi: int) -> int:
            # Last index in nums[lo..hi] whose value is <= target; kept only if equal.
            cand = -1
            while lo <= hi:
                mid = (lo + hi) // 2
                if nums[mid] <= target:
                    cand, lo = mid, mid + 1
                else:
                    hi = mid - 1
            return cand if cand != -1 and nums[cand] == target else -1

        hit = find_sorted(0, pivot)
        return hit if hit != -1 else find_sorted(pivot + 1, n - 1)

题目81:222. 完全二叉树的节点个数

解题思路:简单题简单做,直接DFS遍历统计节点个数,时间复杂度O(n)(没有利用完全二叉树的性质进行二分)。

C++代码如下,

/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode() : val(0), left(nullptr), right(nullptr) {}
 *     TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
 *     TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
 * };
 */
class Solution {
public:
    int countNodes(TreeNode* root) {
        int cnt = 0;
        function<void(TreeNode* node)> dfs =[&] (TreeNode* node) -> void {
            if (node == nullptr) return;
            dfs(node->left);
            cnt += 1;
            dfs(node->right);
            return;
        };
        dfs(root);
        return cnt;
    }
};

python3代码如下,

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def countNodes(self, root: Optional[TreeNode]) -> int:
        """Return the number of nodes in the tree rooted at *root*.

        Straightforward O(n) recursive count; it does not exploit the
        complete-tree shape.
        """
        def size(node) -> int:
            if node is None:
                return 0
            return 1 + size(node.left) + size(node.right)

        return size(root)

题目82:1539. 第 k 个缺失的正整数

解题思路:简单题简单做,把arr放进集合,从1开始逐个枚举正整数,统计不在集合中的数,数到第k个即为答案。

C++代码如下,

class Solution {
public:
    // Returns the k-th positive integer missing from arr by walking 1, 2, 3, ...
    // and counting the values that are not present in arr.
    int findKthPositive(vector<int>& arr, int k) {
        const set<int> present(arr.begin(), arr.end());
        int missing = 0;
        for (int v = 1; ; ++v) {
            if (present.count(v) == 0 && ++missing == k) {
                return v;
            }
        }
        return -100;  // unreachable: the loop above only exits by returning
    }
};

python3代码如下,

class Solution:
    def findKthPositive(self, arr: List[int], k: int) -> int:
        """Return the k-th positive integer that does not occur in *arr*.

        Walks 1, 2, 3, ... and counts down *k* for each integer missing from *arr*.
        """
        present = set(arr)
        candidate = 0
        remaining = k
        while True:
            candidate += 1
            if candidate in present:
                continue
            remaining -= 1
            if remaining == 0:
                return candidate

题目83:540. 有序数组中的单一元素

解题思路:二分。在单一元素之前,成对的相同元素从偶数下标开始;在它之后,成对元素从奇数下标开始,据此二分出单一元素的位置。

C++代码如下,

class Solution {
public:
    // Returns the only element of sorted array a that appears once
    // (every other value appears exactly twice).
    int singleNonDuplicate(vector<int>& a) {
        const int n = a.size();
        if (n == 1) {
            return a[0];  // nothing to compare against
        }

        // For valid input this holds exactly for the indices at or after the
        // single element, except when the single element is the last one: then
        // no index qualifies, which is handled after the loop.
        auto atOrAfterSingle = [&](int i) -> bool {
            if (i % 2 != 0) {
                return a[i] == a[i + 1];
            }
            if (i + 1 < n) {
                return a[i] != a[i + 1];
            }
            return a[i] == a[i - 1];
        };

        int first = -1;
        int lo = 0, hi = n - 1;
        while (lo <= hi) {
            const int mid = (lo + hi) / 2;
            if (atOrAfterSingle(mid)) {
                first = mid;   // smallest qualifying index so far; look left
                hi = mid - 1;
            } else {
                lo = mid + 1;
            }
        }
        // No qualifying index means the single element is the last one, e.g. [1,1,2].
        return a[first == -1 ? n - 1 : first];
    }
};

python3代码如下,

class Solution:
    def singleNonDuplicate(self, a: List[int]) -> int:
        """Return the element of sorted list *a* that appears exactly once.

        Every other value appears exactly twice, so ``len(a)`` is odd. View *a*
        as pairs ``(a[2p], a[2p + 1])``. Every pair before the single element
        holds two equal values. From the pair that starts with the single
        element onward, every pair holds two different values. Binary-search
        the first such "broken" pair. If none exists, the single element is
        the last one, ``a[n - 1]``.

        Neither ``n == 1`` nor "single element is last" needs special casing,
        and no index can go negative.
        """
        n = len(a)
        left, right = 0, n // 2 - 1  # pair indices p with 2p + 1 < n
        res = n // 2                 # default: the unpaired last slot a[n - 1]
        while left <= right:
            mid = (left + right) // 2
            if a[2 * mid] != a[2 * mid + 1]:
                res = mid          # broken pair: single element is at 2*mid or earlier
                right = mid - 1    # look for an earlier broken pair
            else:
                left = mid + 1
        return a[2 * res]

题目84:4. 寻找两个正序数组的中位数

解题思路:转化为求两个数组合并后第k小的数。对答案的值域二分,找到使“两个数组中小于等于mid的元素个数 >= k”成立的最小mid。

C++代码如下,

class Solution {
public:

    // Returns the k-th smallest value (1-based) of the union of two sorted
    // arrays by binary-searching over the value range: the answer is the
    // smallest v such that at least k elements are <= v (-1 if none found).
    int compute_kth_num(const vector<int>& nums1, const vector<int>& nums2, const int k) {
        int lo = 1e7;
        int hi = -1e7;
        for (const vector<int>* v : {&nums1, &nums2}) {
            if (!v->empty()) {
                lo = min(lo, v->front());
                hi = max(hi, v->back());
            }
        }

        // Number of elements <= value across both arrays.
        auto countAtMost = [&](int value) {
            return (upper_bound(nums1.begin(), nums1.end(), value) - nums1.begin())
                 + (upper_bound(nums2.begin(), nums2.end(), value) - nums2.begin());
        };

        int answer = -1;
        while (lo <= hi) {
            const int mid = (lo + hi) / 2;
            if (countAtMost(mid) >= k) {
                answer = mid;   // feasible; try a smaller value
                hi = mid - 1;
            } else {
                lo = mid + 1;
            }
        }
        return answer;
    }


    double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
        const int total = nums1.size() + nums2.size();
        const int upper = total / 2 + 1;  // 1-based rank of the upper middle element
        if (total % 2 == 1) {
            return compute_kth_num(nums1, nums2, upper);
        }
        const int lower = upper - 1;
        const int a = compute_kth_num(nums1, nums2, lower);
        const int b = compute_kth_num(nums1, nums2, upper);
        return (1.0 * a + b) / 2.0;
    }
};

python3代码如下,

class Solution:
    def compute_kth_num(self, nums1: List[int], nums2: List[int], k: int) -> int:
        """Return the k-th smallest value (1-based) across two sorted lists.

        Binary-searches the answer over the value range. The result is the
        smallest v such that at least k elements are <= v, or -1 if no such
        value lies in the searched range.
        """
        lo, hi = int(1e7), -int(1e7)
        for seq in (nums1, nums2):
            if seq:
                lo = min(lo, seq[0])
                hi = max(hi, seq[-1])

        def enough(v: int) -> bool:
            # bisect_left(seq, v + 1) counts the elements <= v in a sorted int list.
            return bisect_left(nums1, v + 1) + bisect_left(nums2, v + 1) >= k

        found = -1
        while lo <= hi:
            mid = (lo + hi) // 2
            if enough(mid):
                found, hi = mid, mid - 1  # feasible; try a smaller value
            else:
                lo = mid + 1
        return found

    def findMedianSortedArrays(self, nums1: List[int], nums2: List[int]) -> float:
        """Return the median of the union of two sorted lists."""
        total = len(nums1) + len(nums2)
        half = total // 2
        if total % 2:
            return self.compute_kth_num(nums1, nums2, half + 1)
        lower = self.compute_kth_num(nums1, nums2, half)
        upper = self.compute_kth_num(nums1, nums2, half + 1)
        return (lower + upper) / 2

3 参考

灵神力扣题单

灵神力扣题单之二分算法(上)