题目描述
- 竞赛积分
2105
思路
我是用BKDRHash做的。 我这里只讲
- 水平方向怎么在母串中找匹配上的模式串。(竖直方向同理,所以这里省略)
- 找到模式串后进行标记并且需要
合并区间 - bytearray的用法
- 获取竖直情况下的字符串区间 -> 遍历竖直情况下区间坐标 -> 映射回二维坐标点 -> 映射回水平情况的一维坐标 掌握以上四点,这道题就可以用python AC掉了。
首先,第一如何在水平方向上找到匹配字符串。 你先通过BKDRHash算法计算出模式串的哈希值。另外将grid矩阵的字符串通过水平方向进行拼接形成大字符串,在这个母串上用滑动窗口进行滑动,遍历右端点,移入r字符,如果窗口大小大于模式串的长度,那么就移除l字符。时时刻刻查看当前hash是否等于模式串的hash。如果等于那么说明我们在母串上找到了一个模式串,记录当前的左右端点。
第二,因为你找到的区间可能有大量的重叠,为了降低后续的操作次数,我们需要进行一个简单的区间合并。区间合并是hot100的经典题目了,默写一遍就行。
第三,针对已经完成合并的区间,我们将它的水平坐标区间记录在bytearray中。
h_hits[curr_l : curr_r + 1] = b'\x01' * (curr_r - curr_l + 1) 是一种切片操作,b'\x01'是byte的字面量,这行代码表示从[curr_l : curr_r + 1]左闭右开区间全部用1填充。
最后,你同理炮制竖直情况的时候,会获得完成竖直区间合并的区间,你需要遍历这个区间的下标
matrix_r = 竖直情况的下标 % R ; matrix_c = 竖直情况的下标 // R
你得到了二维坐标点,然后再映射会水平方向一维坐标点
水平情况的下标 = matrix_r * C + matrix_c
然后我们看看bytearray的这个位置,是不是已经标记成1了
if h_hits[水平情况的下标] == 1
如果是的,那么答案加1.
代码
from typing import List
class Solution:
def countCells(self, grid: List[List[str]], pattern: str) -> int:
R, C = len(grid), len(grid[0])
n = len(pattern)
h_hits = bytearray(R * C)
P = 131
MOD = 2**64
p = [1] * (n + 1)
for i in range(1, n + 1):
p[i] = (p[i-1] * P) % MOD
# 计算 pattern 的 hash
pp = 0
for char in pattern:
pp = (pp * P + (ord(char) - ord('a') + 1)) % MOD
# --- 水平方向处理 ---
matrix_str = ''.join([''.join(r) for r in grid])
l = 0
cur_hash = 0
intervals = []
for r in range(R * C):
# 滚动哈希计算
val_in = ord(matrix_str[r]) - ord('a') + 1
cur_hash = (cur_hash * P + val_in) % MOD
if r - l + 1 > n:
val_out = ord(matrix_str[l]) - ord('a') + 1
cur_hash = (cur_hash - val_out * p[n]) % MOD
l += 1
if cur_hash == pp:
intervals.append((l, r))
# 合并水平区间并标记
if intervals:
intervals.sort()
curr_l, curr_r = intervals[0]
for i in range(1, len(intervals)):
next_l, next_r = intervals[i]
if next_l > curr_r:
h_hits[curr_l : curr_r + 1] = b'\x01' * (curr_r - curr_l + 1)
curr_l, curr_r = next_l, next_r
else:
curr_r = max(curr_r, next_r)
h_hits[curr_l : curr_r + 1] = b'\x01' * (curr_r - curr_l + 1)
# --- 垂直方向处理 ---
matrix_str = ''.join([''.join(s) for s in zip(*grid)])
l = 0
cur_hash = 0
intervals = []
for r in range(R*C):
val_in = ord(matrix_str[r]) - ord('a') + 1
cur_hash = (cur_hash * P + val_in) % MOD
if r - l + 1 > n:
val_out = ord(matrix_str[l]) - ord('a') + 1
cur_hash = (cur_hash - val_out * p[n]) % MOD
l += 1
if cur_hash == pp:
intervals.append((l, r))
def count_range(start, end):
count = 0
for k in range(start, end + 1):
row = k % R
col = k // R
original_idx = row * C + col
if h_hits[original_idx] == 1:
count += 1
return count
ans = 0
if intervals:
intervals.sort()
curr_l, curr_r = intervals[0]
for i in range(1, len(intervals)):
next_l, next_r = intervals[i]
if next_l > curr_r:
ans += count_range(curr_l, curr_r)
curr_l, curr_r = next_l, next_r
else:
curr_r = max(curr_r, next_r)
ans += count_range(curr_l, curr_r)
return ans