def solution(n, cats_levels):
    """Return the minimum total number of fish to hand out to a row of cats.

    Only the first ``n`` entries of ``cats_levels`` are considered. Every cat
    receives at least one fish. A cat whose level is strictly higher than an
    adjacent neighbour's receives strictly more fish than that neighbour.
    Equal neighbours impose no constraint on each other.

    Args:
        n: Number of cats to consider (indices ``0`` .. ``n - 1``).
        cats_levels: Indexable sequence of comparable levels. It must have at
            least ``n`` entries once ``n >= 2``; otherwise indexing raises
            ``IndexError``.

    Returns:
        The sum of the per-cat fish amounts (``0`` when ``n <= 0``).
    """
    portions = [1 for _ in range(n)]

    # Left-to-right: a cat on a rising step needs one more fish than the cat
    # immediately before it.
    for right in range(1, n):
        left = right - 1
        if cats_levels[right] > cats_levels[left]:
            portions[right] = portions[left] + 1

    # Right-to-left: a cat on a falling step needs more fish than the cat
    # immediately after it, without dropping below what the first pass set.
    for left in reversed(range(n - 1)):
        right = left + 1
        if cats_levels[left] > cats_levels[right]:
            needed = portions[right] + 1
            if needed > portions[left]:
                portions[left] = needed

    return sum(portions)
if __name__ == "__main__":
    # Smoke checks as (cat count, levels, expected total). Each printed line
    # is True when solution() returns the expected total for that case.
    cases = (
        (3, [1, 2, 2], 4),
        (6, [6, 5, 4, 3, 2, 16], 17),
        (20, [1, 2, 2, 3, 3, 20, 1, 2, 3, 3, 2, 1, 5, 6, 6, 5, 5, 7, 7, 4], 35),
    )
    for count, levels, expected in cases:
        print(solution(count, levels) == expected)