一、 为什么学这个?
在最近的医学内窥镜图像超分辨率项目中,我采用了“广撒网”的定期存档策略,保存了 latest、best 以及多个 epoch_xxx 权重。但在挑选最终的落地模型时,我遇到了极其痛苦的工程阻力:
- 4K 高清图的“反人类”对比:测试图分辨率高达 3840x2160。把多个模型的输出简单拼成网格图(Grid),缩小看毫无区别;放大后又极难在多张子图间精准对齐同一坐标(比如观察某一根微血管的边缘),来回拖拽对比非常低效。
- 主观先入为主的偏见:人脑会有潜意识,盯着“epoch_200”或“best”标签看时,往往会忽略其真实的视觉伪影。我们需要真正的“盲测”。
- 数据集目录结构复杂:医学数据集通常按患者或视频段划分子文件夹,常规的遍历脚本会将所有输出拍平到同一层级,导致同名图像互相覆盖。
为了彻底解决这些评估瓶颈,我决定花时间打造一个一劳永逸的自动化流水线脚本。
二、 核心内容与实现步骤
整个评测流水线分为三个核心模块:批量推理引擎、高级对比图生成器、交互式 HTML 生成器。
1. 批量推理引擎:相对路径映射与缓存优化
为了完美复刻测试集复杂的嵌套目录,我弃用了简单的 glob,改用 pathlib.Path.rglob 进行递归扫描,并通过计算相对路径来重建输出目录。
- 参数正则解析:通过
re.search直接从.pth文件名中提取网络深度、注意力类型等超参数,动态实例化模型,彻底告别手动传参。 - 目录结构克隆:使用
os.path.relpath(img_path, input_dir)获取相对路径,确保model_A/patient_1/01.png的结构在输出端完美重现。 - Bicubic 内存缓存:基线图像(Bicubic)对所有模型都是一致的,我引入了
bicubic_cache字典进行缓存,避免了每测一个模型就做一次耗时的双三次插值计算。
2. 高阶可视化:ROI 裁剪与误差热力图
相比于简单的整图拼接,我引入了 OpenCV 和 Matplotlib 来深挖图像细节:
- ROI 局部精准裁剪:脚本支持传入
crop_box=(x, y, w, h),自动从庞大的 4K 结果中“抠”出病灶/微血管区域进行并排展示。 - 误差热力图 (Error Heatmap) :利用
cv2.absdiff计算生成图与 HR (Ground Truth) 的绝对像素误差,并用cv2.COLORMAP_JET转为热力图。深红代表伪影重,深蓝代表恢复完美,优劣一目了然。
3. 本地零依赖 HTML 交互盲测工具(核心亮点)
这是整个脚本最“极客”的部分。我用一段原生 HTML+JS 代码,配合 Python 的 json 和 shutil 模块,自动生成了一个静态网页交互评测工具:
- 原位拉片切换:将 Bicubic、HR 和所有微调 epoch 的输出图堆叠在同一坐标,通过键盘
←→方向键瞬间切换。因为像素绝对对齐,多余的伪影在人眼中会像“闪烁”一样刺眼。 - 原生 CSS 放大镜:在 HTML 中加入了一个随鼠标移动的局部放大镜。利用 CSS 的
background-image和 JS 动态计算background-position,实现了极其流畅的 1.5x - 8x 实时放大查看功能。 - 一键盲测 (Blind Test) :加入 Checkbox,勾选后所有模型名称被隐藏替换为
Model 1、Model 2,强迫自己只关注图像的医学真实度进行客观挑选。
三、 遇到的问题与解决方法
坑 1:多子文件夹下同名图片的覆盖灾难
- 问题描述:使用平铺遍历时,
folder_A/001.png会覆盖掉folder_B/001.png。同时,Matplotlib 在保存名称包含/的路径时会抛出异常。 - 解决方法:在生成 HTML 时,利用
rel_path重建物理文件夹结构;在生成平面对比的 Grid 图时,使用safe_name = rel_path.replace(os.sep, '__')将路径安全地转化为单层文件名(如folder_A__001.png)。
坑 2:HTML 跨域加载本地图片失败
- 问题描述:如果 HTML 中的 JavaScript 直接读取外部绝对路径的图片,在某些现代浏览器中会触发本地 CORS 安全限制。
- 解决方法:在 Python 端不仅生成 HTML,还通过
shutil.copy将本次测试涉及的所有图片统一物理拷贝到Interactive_Viewer独立目录下。这样整个文件夹就是一个便携的静态网站,可以直接压缩发给合作者评估。
坑 3:局部放大镜的焦点对齐偏移
-
问题描述:初期实现放大镜时,鼠标中心点对应的图像位置与放大镜内显示的位置有偏差。
-
解决方法:需要精确推算背景图像的偏移量。核心公式为:
bgX = -(x * zoomLevel - glassEl.offsetWidth / 2),通过减去放大镜自身宽高的一半,强制让焦点居中。
四、 收获与总结
这次造轮子的经历让我深刻体会到:在算法工程中,“怎么评估模型”往往比“怎么训练模型”更重要。
从写出简单的面条代码,到引入 rel_path 目录映射、缓存优化显存,再到打通前后端生成全自动交互式网页,这是一个从“跑通即可”向“生产力工具”思维的转变。摆脱了“PSNR 唯分数论”后,这套流水线赋予了我通过严谨盲测和局部细节热力图来寻找最佳模型的底气。
如果你也在做底层视觉(Low-Level Vision)或医学图像恢复任务,强烈建议抛弃低效的手动看图,花半天时间为自己搭建一套专用的评估可视化流水线,绝对物超所值!
这里为你专门编写了一段 “极速上手指南(使用说明)” ,你可以直接将它作为“五、 附录:工具包上手指南” 追加到你博客的末尾。
我保持了干货、精简的风格,用最直观的代码块和目录树形式,帮助读者(包括未来的你自己)看一眼就能跑通。
五、 附录:流水线极速上手指南
如果你也想在自己的超分项目中使用这套评测流水线,可以参考以下使用说明,做到开箱即用。
1. 环境依赖
除了基础的 PyTorch 外,请确保安装了以下图像处理包:
Bash
pip install opencv-python matplotlib pillow numpy
2. 参数配置与一键运行
在 test_and_visualize.py 的 argparse 部分,你可以直接修改默认路径为你自己的本地路径。配置完成后,在终端直接一键运行即可:
Bash
python test_and_visualize.py
如果你只想针对某几张特定的“疑难杂症”图像(例如微血管最密集的图)进行测试,可以使用 --target-images 参数(支持纯文件名或相对路径过滤):
Bash
python test_and_visualize.py --target-images 01.png patient_A/05.jpg
3. 输出目录结构说明
脚本运行完毕后,你的 output-dir 将会自动生成以下结构,所有结果被安排得明明白白:
Plaintext
Visual_Results/
├── epoch_010/ # 第10轮模型输出,保持与测试集相同的子目录结构
├── epoch_best/ # Best模型输出
│
├── Visual_Comparisons/ # 静态对比画廊(科研绘图直接用)
│ ├── Full_01.png # 宏观全局对比图 + 误差热力图
│ └── Crop_01.png # ROI局部裁剪对比图 (若代码中启用了crop_box)
│
└── Interactive_Viewer/ # 交互式盲测工作站(前端网页)
├── Bicubic (Baseline)/ # 自动拷贝的基线图
├── HR (Ground Truth)/ # 自动拷贝的真实图
├── epoch_010/ # 自动拷贝的推理图
└── index.html # 双击打开即可体验拉片!
4. 盲测网页 (index.html) 操作快捷键
双击浏览器打开 index.html 后,请丢掉鼠标点击的繁琐,直接使用全键盘操作体验极致的拉片快感:
- [ ← ] / [ → ] :原位光速切换模型(Bicubic -> 各微调模型 -> HR 真值)。
- [ ↑ ] / [ ↓ ] :快速切换测试图像(无需下拉菜单)。
- 鼠标悬停:在图像上移动鼠标,自动呼出局部放大镜。
- 放大倍率调节:拖动右上角滑动条,支持
1.5x - 8.0x实时无级变焦。 - 开启盲测模式:勾选右上角 Checkbox,强制隐藏所有模型真实名称,助你排除偏见,客观挑图。
代码实现:
import argparse
import time
import os
import glob
import re
import math
import json
import shutil
from pathlib import Path
import torch
import torch.backends.cudnn as cudnn
import numpy as np
import PIL.Image as pil_image
import matplotlib.pyplot as plt
import cv2 # 用于热力图
# 引入自定义的模型结构和工具函数
from models import ESPCN_RDB
from utils import convert_ycbcr_to_rgb, preprocess
def parse_model_kwargs(model_name):
"""从模型文件名中提取超参数"""
parsed_kwargs = {
'growth_channels': 16,
'rdb_layers': 3,
'activation': 'LeakyReLU',
'attention_type': 'pixel'
}
match_gc = re.search(r'growth_channels_(\d+)', model_name)
if match_gc:
parsed_kwargs['growth_channels'] = int(match_gc.group(1))
match_rdb = re.search(r'RDB_(\d+)', model_name)
if match_rdb:
parsed_kwargs['rdb_layers'] = int(match_rdb.group(1))
match_attn = re.search(r'Attn_(pixel|weakened_pixel|none)', model_name)
if match_attn:
parsed_kwargs['attention_type'] = match_attn.group(1)
activation_types = ['ReLU', 'LeakyReLU', 'PReLU', 'Tanh', 'Sigmoid', 'GELU']
for act in activation_types:
if f"_{act}_" in model_name or model_name.endswith(f"_{act}"):
parsed_kwargs['activation'] = act
break
return parsed_kwargs
def create_advanced_comparison(rel_path, models_list, output_root, hr_path=None, bicubic_img=None, crop_box=None):
"""高级对比图生成器:支持局部裁剪放大和误差热力图"""
print(f"🎨 正在绘制可视化对比图: {rel_path} ...")
images_dict = {}
if bicubic_img is not None:
images_dict['Bicubic'] = np.array(bicubic_img)
for model_name in models_list:
img_path = os.path.join(output_root, model_name, rel_path)
if os.path.exists(img_path):
images_dict[model_name] = np.array(pil_image.open(img_path).convert('RGB'))
hr_img_np = None
if hr_path and os.path.exists(hr_path):
hr_img_np = np.array(pil_image.open(hr_path).convert('RGB'))
images_dict['HR (GT)'] = hr_img_np
# 1. 局部裁剪处理
if crop_box is not None:
x, y, w, h = crop_box
for key in images_dict.keys():
images_dict[key] = images_dict[key][y:y + h, x:x + w, :]
if hr_img_np is not None:
hr_crop = hr_img_np[y:y + h, x:x + w, :]
# 2. 生成误差热力图
error_maps = {}
if hr_img_np is not None:
hr_gray = cv2.cvtColor(images_dict['HR (GT)'], cv2.COLOR_RGB2GRAY).astype(np.float32)
for key, img_arr in images_dict.items():
if key == 'HR (GT)': continue
pred_gray = cv2.cvtColor(img_arr, cv2.COLOR_RGB2GRAY).astype(np.float32)
diff = np.abs(hr_gray - pred_gray)
diff_norm = np.clip(diff / 50.0 * 255.0, 0, 255).astype(np.uint8)
heatmap = cv2.applyColorMap(diff_norm, cv2.COLORMAP_JET)
heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
error_maps[key + "\nError Map"] = heatmap
final_plot_dict = {**images_dict, **error_maps}
# 3. 绘制并排网格
cols = len(images_dict)
rows = 2 if len(error_maps) > 0 else 1
fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 5 * rows))
if rows == 1: axes = np.array([axes])
if cols == 1: axes = axes.T
plot_keys = list(images_dict.keys())
for col_idx, key in enumerate(plot_keys):
axes[0, col_idx].imshow(images_dict[key])
short_title = key.replace('growth_channels_16_RDB_3_LeakyReLU_Attn_pixel_', '')
axes[0, col_idx].set_title(short_title, fontsize=14, fontweight='bold')
axes[0, col_idx].axis('off')
if rows == 2:
for col_idx, key in enumerate(plot_keys):
if key == 'HR (GT)':
axes[1, col_idx].axis('off')
continue
error_key = key + "\nError Map"
axes[1, col_idx].imshow(final_plot_dict[error_key])
axes[1, col_idx].set_title("Error Heatmap", fontsize=12, color='red')
axes[1, col_idx].axis('off')
plt.tight_layout()
# 将具有子文件夹属性的路径转化为下划线形式防止报错,比如 folder/01.png -> folder__01.png
safe_name = rel_path.replace(os.sep, '__').replace('/', '__')
prefix = "Crop_" if crop_box else "Full_"
vis_dir = os.path.join(output_root, "Visual_Comparisons")
os.makedirs(vis_dir, exist_ok=True)
save_path = os.path.join(vis_dir, f"{prefix}{safe_name}")
plt.savefig(save_path, dpi=200, bbox_inches='tight')
plt.close(fig)
def create_html_viewer(rel_paths, models_list, output_root, hr_dir=None, bicubic_cache=None):
print("\n🌐 正在生成交互式 HTML 盲测网页...")
html_dir = os.path.join(output_root, "Interactive_Viewer")
os.makedirs(html_dir, exist_ok=True)
final_models = []
if bicubic_cache:
final_models.append("Bicubic (Baseline)")
for rel_path, img_obj in bicubic_cache.items():
dst_path = os.path.join(html_dir, "Bicubic (Baseline)", rel_path)
os.makedirs(os.path.dirname(dst_path), exist_ok=True) # 保持子目录结构
img_obj.save(dst_path)
final_models.extend(models_list)
if hr_dir:
final_models.append("HR (Ground Truth)")
data = {"images": rel_paths, "models": final_models, "paths": {}}
for rel_path in rel_paths:
data["paths"][rel_path] = []
for model in final_models:
if model == "Bicubic (Baseline)":
web_path = f"Bicubic (Baseline)/{rel_path}".replace('\', '/')
elif model == "HR (Ground Truth)":
src_hr = os.path.join(hr_dir, rel_path)
dst_hr = os.path.join(html_dir, "HR (Ground Truth)", rel_path)
if os.path.exists(src_hr):
os.makedirs(os.path.dirname(dst_hr), exist_ok=True)
shutil.copy(src_hr, dst_hr)
web_path = f"HR (Ground Truth)/{rel_path}".replace('\', '/')
else:
src_model_img = os.path.join(output_root, model, rel_path)
dst_model_img = os.path.join(html_dir, model, rel_path)
if os.path.exists(src_model_img):
os.makedirs(os.path.dirname(dst_model_img), exist_ok=True)
shutil.copy(src_model_img, dst_model_img)
web_path = f"{model}/{rel_path}".replace('\', '/')
data["paths"][rel_path].append(web_path)
# ==============================================================
# ✨ 核心改动区域:HTML_CONTENT 中加入了放大镜 CSS 和 JS 逻辑
# ==============================================================
html_content = """
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<title>EndoMamba SR 交互式拉片对比系统 (带放大镜)</title>
<style>
body { font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; background-color: #121212; color: #ffffff; text-align: center; margin: 0; padding: 0; overflow-x: hidden; }
#toolbar { padding: 15px 30px; background-color: #1e1e1e; display: flex; justify-content: space-between; align-items: center; box-shadow: 0 4px 6px rgba(0,0,0,0.3); position: relative; z-index: 10; }
.tools-left, .tools-center, .tools-right { display: flex; align-items: center; gap: 15px; }
#image-container { position: relative; display: inline-block; margin-top: 20px; max-width: 98vw; height: 85vh; cursor: crosshair; }
img { max-width: 100%; max-height: 100%; object-fit: contain; display: block;}
/* ✨ 放大镜核心样式 */
#magnifier {
position: absolute;
border: 3px solid #00ffcc;
border-radius: 50%;
cursor: none;
width: 200px; /* 放大镜窗口大小 */
height: 200px;
background-repeat: no-repeat;
box-shadow: 0 0 15px rgba(0, 0, 0, 0.8), inset 0 0 10px rgba(0,0,0,0.5);
display: none; /* 默认隐藏 */
pointer-events: none; /* 防止放大镜遮挡鼠标移动事件 */
z-index: 100;
}
select { padding: 8px; border-radius: 4px; background: #333; color: white; border: 1px solid #555; outline: none; font-size: 14px;}
.btn { background: #007acc; color: white; border: none; padding: 10px 20px; cursor: pointer; font-size: 14px; border-radius: 4px; font-weight: bold; transition: background 0.2s;}
.btn:hover { background: #005f9e; }
#model-badge { position: absolute; top: 15px; left: 15px; background: rgba(0, 0, 0, 0.7); padding: 10px 20px; font-size: 24px; font-weight: bold; border-radius: 6px; color: #00ffcc; pointer-events: none; border: 1px solid rgba(255,255,255,0.2); backdrop-filter: blur(4px); z-index: 10;}
.switch-container { display: flex; align-items: center; font-size: 14px; font-weight: bold; color: #e74c3c; background: rgba(231, 76, 60, 0.1); padding: 8px 15px; border-radius: 20px; border: 1px solid #e74c3c;}
input[type="checkbox"] { width: 18px; height: 18px; cursor: pointer; margin-right: 8px; }
.zoom-control { display: flex; align-items: center; font-size: 14px; font-weight: bold; background: #333; padding: 8px 15px; border-radius: 20px; border: 1px solid #555; }
.zoom-control input { margin-left: 10px; cursor: pointer;}
#tips { font-size: 12px; color: #888; margin-top: 10px; padding-bottom: 20px;}
</style>
</head>
<body>
<div id="toolbar">
<div class="tools-left">
<label style="font-weight: bold;">🎯 目标图像: </label>
<select id="image-select" onchange="changeImage()"></select>
</div>
<div class="tools-center">
<button class="btn" onclick="prevModel()">⬅ 上一模型 (Left)</button>
<span id="model-counter" style="font-size: 18px; font-weight: bold; min-width: 80px;"></span>
<button class="btn" onclick="nextModel()">下一模型 (Right) ➡</button>
</div>
<div class="tools-right">
<div class="zoom-control">
<label>🔍 放大镜: <span id="zoom-val">3.0x</span></label>
<input type="range" id="zoom-slider" min="1.5" max="8.0" step="0.5" value="3" oninput="updateZoom()">
</div>
<div class="switch-container">
<input type="checkbox" id="blind-mode" onchange="toggleBlindMode()">
<label for="blind-mode" style="cursor: pointer;">开启盲测模式</label>
</div>
</div>
</div>
<div id="image-container">
<img id="display-img" src="" alt="SR Image">
<div id="model-badge">Model Name</div>
<div id="magnifier"></div>
</div>
<div id="tips">💡 快捷操作:使用键盘左右方向键 [←] [→] 原位切换模型,上下方向键 [↑] [↓] 切换图像。鼠标悬停图像显示局部放大镜。</div>
<script>
const data = {DATA_JSON};
let currentImgIdx = 0;
let currentModIdx = 0;
let isBlindMode = false;
// 放大镜配置
let zoomLevel = 3.0;
const imgEl = document.getElementById('display-img');
const glassEl = document.getElementById('magnifier');
function init() {
const select = document.getElementById('image-select');
data.images.forEach((img, i) => {
let opt = document.createElement('option');
opt.value = i;
opt.innerHTML = img;
select.appendChild(opt);
});
updateView();
}
function changeImage() { currentImgIdx = parseInt(document.getElementById('image-select').value); updateView(); }
function prevModel() { currentModIdx = (currentModIdx - 1 + data.models.length) % data.models.length; updateView(); }
function nextModel() { currentModIdx = (currentModIdx + 1) % data.models.length; updateView(); }
function toggleBlindMode() { isBlindMode = document.getElementById('blind-mode').checked; updateView(); }
function updateZoom() {
zoomLevel = parseFloat(document.getElementById('zoom-slider').value);
document.getElementById('zoom-val').innerText = zoomLevel.toFixed(1) + "x";
}
function updateView() {
const imgName = data.images[currentImgIdx];
const rawModelName = data.models[currentModIdx];
const imgPath = data.paths[imgName][currentModIdx];
imgEl.src = imgPath;
// ✨ 同步更新放大镜的背景底图
glassEl.style.backgroundImage = "url('" + imgPath + "')";
const badge = document.getElementById('model-badge');
if (isBlindMode) {
badge.innerHTML = "Model " + (currentModIdx + 1);
badge.style.color = "#ffeb3b";
} else {
let shortName = rawModelName.replace('growth_channels_16_RDB_3_LeakyReLU_Attn_pixel_', '');
badge.innerHTML = shortName;
badge.style.color = "#00ffcc";
}
document.getElementById('model-counter').innerHTML = (currentModIdx + 1) + " / " + data.models.length;
}
// ==========================================
// ✨ 放大镜交互核心逻辑
// ==========================================
imgEl.addEventListener("mousemove", moveMagnifier);
glassEl.addEventListener("mousemove", moveMagnifier);
imgEl.addEventListener("mouseenter", () => glassEl.style.display = "block");
imgEl.addEventListener("mouseleave", () => glassEl.style.display = "none");
function moveMagnifier(e) {
e.preventDefault();
// 获取鼠标在图像上的相对坐标
const rect = imgEl.getBoundingClientRect();
let x = e.clientX - rect.left;
let y = e.clientY - rect.top;
// 防止坐标溢出图像边界
if (x > imgEl.width) { x = imgEl.width;}
if (x < 0) { x = 0; }
if (y > imgEl.height) { y = imgEl.height; }
if (y < 0) { y = 0; }
// 调整放大镜窗口的位置 (跟随鼠标居中)
glassEl.style.left = (x - glassEl.offsetWidth / 2) + "px";
glassEl.style.top = (y - glassEl.offsetHeight / 2) + "px";
// 计算放大倍率并设置背景图片大小
glassEl.style.backgroundSize = (imgEl.width * zoomLevel) + "px " + (imgEl.height * zoomLevel) + "px";
// 核心:推算背景图像的偏移量,让放大的焦点完美对齐鼠标位置
let bgX = -(x * zoomLevel - glassEl.offsetWidth / 2);
let bgY = -(y * zoomLevel - glassEl.offsetHeight / 2);
glassEl.style.backgroundPosition = bgX + "px " + bgY + "px";
}
// 键盘事件监听
document.addEventListener('keydown', (e) => {
if (e.key === "ArrowLeft") { e.preventDefault(); prevModel(); }
else if (e.key === "ArrowRight") { e.preventDefault(); nextModel(); }
else if (e.key === "ArrowUp") { e.preventDefault(); currentImgIdx = Math.max(0, currentImgIdx - 1); document.getElementById('image-select').value = currentImgIdx; updateView(); }
else if (e.key === "ArrowDown") { e.preventDefault(); currentImgIdx = Math.min(data.images.length - 1, currentImgIdx + 1); document.getElementById('image-select').value = currentImgIdx; updateView(); }
});
window.onload = init;
</script>
</body>
</html>
"""
html_path = os.path.join(html_dir, "index.html")
with open(html_path, "w", encoding="utf-8") as f:
f.write(html_content.replace("{DATA_JSON}", json.dumps(data)))
print(f"✅ HTML 网页生成完毕!请双击查阅: {html_path}")
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Batch Test and Visualize SR Models")
# 填入你之前的微调权重文件夹路径
parser.add_argument('--weights-dir', type=str,
default='D:/super-resolution/ESPCN_RDB_Encoscope/output/x2/growth_channels_16_RDB_3_LeakyReLU_Attn_pixel_finetuned',
help='包含多个 .pth 模型的文件夹路径')
# 填入验证集低分辨率图路径
parser.add_argument('--input-dir', type=str,
default='D:/super-resolution/datasets/SurgiSR4K/data/images_crop/test/1920x1080p',
help='低分辨率 (LR) 测试图路径')
# 填入验证集高分辨率真实图路径
parser.add_argument('--hr-dir', type=str,
default='D:/super-resolution/datasets/SurgiSR4K/data/images_crop/test/3840x2160p',
help='高分辨率 (HR) 测试图路径 (可选,用于对比)')
# 将默认值改为 None,我们将在后面动态生成它
parser.add_argument('--output-dir', type=str,
default=None,
help='测试结果和对比图的输出根目录 (默认保存在 weights-dir 下的 Visual_Results 目录)')
parser.add_argument('--scale', type=int, default=2, help='放大倍数')
# 支持通过相对路径或文件名过滤 (例如 --target-images patient1/01.png)
parser.add_argument('--target-images', nargs='+', default=['a (33).jpg', 'a (27).jpg'],
help='指定要测试的图片名或相对路径。默认 all 为全部。')
args = parser.parse_args()
# ✨ 动态设定:如果未指定 output-dir,则自动在权重目录下创建 Visual_Results 文件夹
if args.output_dir is None:
args.output_dir = os.path.join(args.weights_dir, 'Visual_Results')
# 1. 环境设置
cudnn.benchmark = False
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# 2. 搜集所有模型
weight_paths = sorted(glob.glob(os.path.join(args.weights_dir, '*.pth')))
if not weight_paths:
raise ValueError(f"在 {args.weights_dir} 下没有找到任何 .pth 文件!")
# ==========================================
# 递归扫描输入目录下的所有图片
# ==========================================
all_image_paths = []
valid_extensions = {'.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff'}
for p in Path(args.input_dir).rglob('*'):
if p.is_file() and p.suffix.lower() in valid_extensions:
all_image_paths.append(str(p))
all_image_paths = sorted(all_image_paths)
# 过滤指定的图像(支持按 basename 或 rel_path 过滤)
if 'all' not in args.target_images:
image_paths = []
for p in all_image_paths:
basename = os.path.basename(p)
rel_path = os.path.relpath(p, args.input_dir).replace('\', '/')
if basename in args.target_images or rel_path in args.target_images:
image_paths.append(p)
else:
image_paths = all_image_paths
if not image_paths:
raise ValueError(f"未找到指定的测试图像,请检查 input-dir 路径。")
print(f"==> 共找到 {len(weight_paths)} 个待测模型。")
print(f"==> 共指定 {len(image_paths)} 张测试图像。")
models_list = []
bicubic_cache = {}
# ===================================================================
# 阶段一:遍历所有模型,进行推理
# ===================================================================
for w_idx, weight_path in enumerate(weight_paths, 1):
model_name = os.path.splitext(os.path.basename(weight_path))[0]
models_list.append(model_name)
model_out_dir = os.path.join(args.output_dir, model_name)
print(f"\n[{w_idx}/{len(weight_paths)}] 🚀 正在加载并测试模型: {model_name}")
parsed_kwargs = parse_model_kwargs(model_name)
model = ESPCN_RDB(scale_factor=args.scale, num_channels=1, **parsed_kwargs).to(device)
checkpoint = torch.load(weight_path, map_location=device)
state_dict = checkpoint["model_state_dict"] if "model_state_dict" in checkpoint else checkpoint
clean_state_dict = {k: v for k, v in state_dict.items() if
not (k.endswith('total_ops') or k.endswith('total_params'))}
model.load_state_dict(clean_state_dict, strict=True)
model.eval()
for img_path in image_paths:
rel_path = os.path.relpath(img_path, args.input_dir)
save_path = os.path.join(model_out_dir, rel_path)
os.makedirs(os.path.dirname(save_path), exist_ok=True)
if os.path.exists(save_path):
if rel_path not in bicubic_cache:
lr_img = pil_image.open(img_path).convert('RGB')
bicubic_cache[rel_path] = lr_img.resize((lr_img.width * args.scale, lr_img.height * args.scale),
resample=pil_image.BICUBIC)
continue
lr_img = pil_image.open(img_path).convert('RGB')
bicubic = lr_img.resize((lr_img.width * args.scale, lr_img.height * args.scale), resample=pil_image.BICUBIC)
if rel_path not in bicubic_cache:
bicubic_cache[rel_path] = bicubic
lr_tensor, _ = preprocess(lr_img, device)
_, ycbcr = preprocess(bicubic, device)
with torch.no_grad():
preds = model(lr_tensor).clamp(0.0, 1.0)
preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)
output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])
output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)
output_img = pil_image.fromarray(output)
output_img.save(save_path)
del model
torch.cuda.empty_cache()
# ===================================================================
# 阶段二:聚合生成对比视图和 HTML
# ===================================================================
print("\n" + "=" * 50)
print("✨ 开始生成全目录的可视化对比图...")
print("=" * 50)
rel_paths_list = []
for img_path in image_paths:
rel_path = os.path.relpath(img_path, args.input_dir)
rel_paths_list.append(rel_path)
hr_path = os.path.join(args.hr_dir, rel_path) if args.hr_dir else None
# 1. 绘制完整大图的对比
create_advanced_comparison(
rel_path=rel_path,
models_list=models_list,
output_root=args.output_dir,
hr_path=hr_path,
bicubic_img=bicubic_cache.get(rel_path),
crop_box=None
)
# 2. 提取重点区域放大对比(默认在 x=1500, y=1000 大小400x400,可自行修改)
# create_advanced_comparison(
# rel_path=rel_path,
# models_list=models_list,
# output_root=args.output_dir,
# hr_path=hr_path,
# bicubic_img=bicubic_cache.get(rel_path),
# crop_box=(1500, 1000, 400, 400)
# )
# 调用 HTML 生成器
create_html_viewer(rel_paths_list, models_list, args.output_dir, args.hr_dir, bicubic_cache)