【力扣roadmap】3529. 统计水平子串和垂直子串重叠格子的数目

21 阅读3分钟

题目描述

  • 竞赛积分2105

image.png

思路

我是用BKDRHash做的。 我这里只讲

  • 水平方向怎么在母串中找匹配上的模式串。(竖直方向同理,所以这里省略)
  • 找到模式串后进行标记并且需要合并区间
  • bytearray的用法
  • 获取竖直情况下的字符串区间 -> 遍历竖直情况下区间坐标 -> 映射回二维坐标点 -> 映射回水平情况的一维坐标 掌握以上四点,这道题就可以用python AC掉了。

首先,第一如何在水平方向上找到匹配字符串。 你先通过BKDRHash算法计算出模式串的哈希值。另外将grid矩阵的字符串通过水平方向进行拼接形成大字符串,在这个母串上用滑动窗口进行滑动,遍历右端点,移入r字符,如果窗口大小大于模式串的长度,那么就移除l字符。时时刻刻查看当前hash是否等于模式串的hash。如果等于那么说明我们在母串上找到了一个模式串,记录当前的左右端点。

第二,因为你找到的区间可能有大量的重叠,为了降低后续的操作次数,我们需要进行一个简单的区间合并。区间合并是hot100的经典题目了,默写一遍就行。

第三,针对已经完成合并的区间,我们将它的水平坐标区间记录在bytearray中。 h_hits[curr_l : curr_r + 1] = b'\x01' * (curr_r - curr_l + 1) 是一种切片操作,b'\x01'是byte的字面量,这行代码表示从[curr_l : curr_r + 1]左闭右开区间全部用1填充。

最后,你同理炮制竖直情况的时候,会获得完成竖直区间合并的区间,你需要遍历这个区间的下标 matrix_r = 竖直情况的下标 % R ; matrix_c = 竖直情况的下标 // R 你得到了二维坐标点,然后再映射会水平方向一维坐标点 水平情况的下标 = matrix_r * C + matrix_c 然后我们看看bytearray的这个位置,是不是已经标记成1了 if h_hits[水平情况的下标] == 1 如果是的,那么答案加1.

代码

from typing import List

class Solution:
    def countCells(self, grid: List[List[str]], pattern: str) -> int:
        R, C = len(grid), len(grid[0])
        n = len(pattern)
        
        h_hits = bytearray(R * C) 
        
        P = 131
        MOD = 2**64 
        
        p = [1] * (n + 1)
        for i in range(1, n + 1):
            p[i] = (p[i-1] * P) % MOD
            
        # 计算 pattern 的 hash
        pp = 0
        for char in pattern:
            pp = (pp * P + (ord(char) - ord('a') + 1)) % MOD
            
        # --- 水平方向处理 ---
        matrix_str = ''.join([''.join(r) for r in grid])
        l = 0
        cur_hash = 0
        intervals = []
        for r in range(R * C):
            # 滚动哈希计算
            val_in = ord(matrix_str[r]) - ord('a') + 1
            cur_hash = (cur_hash * P + val_in) % MOD
            
            if r - l + 1 > n:
                val_out = ord(matrix_str[l]) - ord('a') + 1
                cur_hash = (cur_hash - val_out * p[n]) % MOD
                l += 1
            
            if cur_hash == pp:
                intervals.append((l, r))

        # 合并水平区间并标记
        if intervals:
            intervals.sort()
            curr_l, curr_r = intervals[0]
            for i in range(1, len(intervals)):
                next_l, next_r = intervals[i]
                if next_l > curr_r:
                    h_hits[curr_l : curr_r + 1] = b'\x01' * (curr_r - curr_l + 1)
                    curr_l, curr_r = next_l, next_r
                else:
                    curr_r = max(curr_r, next_r)
            h_hits[curr_l : curr_r + 1] = b'\x01' * (curr_r - curr_l + 1)

        # --- 垂直方向处理 ---
        matrix_str = ''.join([''.join(s) for s in zip(*grid)])
        l = 0
        cur_hash = 0
        intervals = []
        
        for r in range(R*C):
            val_in = ord(matrix_str[r]) - ord('a') + 1
            cur_hash = (cur_hash * P + val_in) % MOD
            
            if r - l + 1 > n:
                val_out = ord(matrix_str[l]) - ord('a') + 1
                cur_hash = (cur_hash - val_out * p[n]) % MOD
                l += 1
            
            if cur_hash == pp:
                intervals.append((l, r))
        

        def count_range(start, end):
            count = 0
            for k in range(start, end + 1):
                row = k % R
                col = k // R
                original_idx = row * C + col
                if h_hits[original_idx] == 1:
                    count += 1
            return count
        ans = 0
        if intervals:
            intervals.sort()
            curr_l, curr_r = intervals[0]
            for i in range(1, len(intervals)):
                next_l, next_r = intervals[i]
                if next_l > curr_r:
                    ans += count_range(curr_l, curr_r)
                    curr_l, curr_r = next_l, next_r
                else:
                    curr_r = max(curr_r, next_r)
            ans += count_range(curr_l, curr_r)
        return ans