评价指标优化

This commit is contained in:
myh 2025-04-22 21:41:58 +08:00
parent 89d8f4c0df
commit ba4508507b

View File

@ -11,7 +11,7 @@ from ultralytics import YOLO
from skimage.metrics import structural_similarity as ssim from skimage.metrics import structural_similarity as ssim
# 添加YOLOv8模型初始化 # 添加YOLOv8模型初始化
yolo_model = YOLO("yolov8n.pt") # 可替换为yolov8s/m/l等 yolo_model = YOLO("best.pt") # 可替换为yolov8s/m/l等
yolo_model.to('cuda') # 启用GPU加速 yolo_model.to('cuda') # 启用GPU加速
@ -177,21 +177,21 @@ def main(matchimg_vi, matchimg_in):
# (3, 3)//获取对应的配准坐标点 # (3, 3)//获取对应的配准坐标点
flag, H, dot = Images_matching(matchimg_vi, matchimg_in) flag, H, dot = Images_matching(matchimg_vi, matchimg_in)
if flag == 0: if flag == 0:
return 0, None, 0 return 0, None, 0, 0.0, 0.0, 0.0, 0.0
else: else:
# 配准处理 # 配准处理
matched_ni = cv2.warpPerspective(orimg_in, H, (w, h)) matched_ni = cv2.warpPerspective(orimg_in, H, (w, h))
matched_ni, left, right, top, bottom = removeBlackBorder(matched_ni) matched_ni, left, right, top, bottom = removeBlackBorder(matched_ni)
# 裁剪可见光图像 # 裁剪可见光图像
cropped_vi = orimg_vi[left:right, top:bottom]
# fusion = fusions(orimg_vi[left:right, top:bottom], matched_ni) # fusion = fusions(orimg_vi[left:right, top:bottom], matched_ni)
fusion = fusions(cropped_vi, matched_ni)
# 不裁剪可见光图像
fusion = fusions(orimg_vi, matched_ni)
# 转换为灰度计算指标 # 转换为灰度计算指标
fusion_gray = cv2.cvtColor(fusion, cv2.COLOR_RGB2GRAY) fusion_gray = cv2.cvtColor(fusion, cv2.COLOR_RGB2GRAY)
cropped_vi_gray = cv2.cvtColor(cropped_vi, cv2.COLOR_BGR2GRAY) cropped_vi_gray = cv2.cvtColor(orimg_vi, cv2.COLOR_BGR2GRAY)
matched_ni_gray = matched_ni # 红外图已经是灰度 matched_ni_gray = matched_ni # 红外图已经是灰度
# 计算指标 # 计算指标
@ -200,9 +200,15 @@ def main(matchimg_vi, matchimg_in):
mi_visible = calculate_mi(fusion_gray, cropped_vi_gray) mi_visible = calculate_mi(fusion_gray, cropped_vi_gray)
mi_infrared = calculate_mi(fusion_gray, matched_ni_gray) mi_infrared = calculate_mi(fusion_gray, matched_ni_gray)
mi_total = mi_visible + mi_infrared mi_total = mi_visible + mi_infrared
# 添加SSIM容错处理
try:
ssim_visible = calculate_ssim(fusion_gray, cropped_vi_gray) ssim_visible = calculate_ssim(fusion_gray, cropped_vi_gray)
ssim_infrared = calculate_ssim(fusion_gray, matched_ni_gray) ssim_infrared = calculate_ssim(fusion_gray, matched_ni_gray)
ssim_avg = (ssim_visible + ssim_infrared) / 2 ssim_avg = (ssim_visible + ssim_infrared) / 2
except Exception as ssim_error:
print(f"SSIM计算错误: {ssim_error}")
ssim_avg = -1 # 用-1表示计算失败
# YOLOv8目标检测 # YOLOv8目标检测
results = yolo_model(fusion) # 输入融合后的图像 results = yolo_model(fusion) # 输入融合后的图像
@ -212,16 +218,16 @@ def main(matchimg_vi, matchimg_in):
return 1, annotated_image, dot, en, sf, mi_total, ssim_avg return 1, annotated_image, dot, en, sf, mi_total, ssim_avg
except Exception as e: except Exception as e:
print(f"Error in fusion/detection: {e}") print(f"Error in fusion/detection: {e}")
return 0, None, 0 return 0, None, 0, 0.0, 0.0, 0.0, 0.0
def parse_args(): def parse_args():
# 输入可见光和红外图像路径 # 输入可见光和红外图像路径
visible_image_path = "test/visible.jpg" # 可见光图片路径 visible_image_path = "./test/visible/visibleI0195.jpg" # 可见光图片路径
infrared_image_path = "test/infrared.jpg" # 红外图片路径 infrared_image_path = "./test/infrared/infraredI0195.jpg" # 红外图片路径
# 输入可见光和红外视频路径 # 输入可见光和红外视频路径
visible_video_path = "test/visible.mp4" # 可见光视频路径 visible_video_path = "./test/visible.mp4" # 可见光视频路径
infrared_video_path = "test/infrared.mp4" # 红外视频路径 infrared_video_path = "./test/infrared.mp4" # 红外视频路径
"""解析命令行参数""" """解析命令行参数"""
parser = argparse.ArgumentParser(description='图像融合与目标检测') parser = argparse.ArgumentParser(description='图像融合与目标检测')
@ -331,12 +337,18 @@ if __name__ == '__main__':
print(f"信息熵EN: {en:.2f}") print(f"信息熵EN: {en:.2f}")
print(f"空间频率SF: {sf:.2f}") print(f"空间频率SF: {sf:.2f}")
print(f"互信息MI: {mi:.2f}") print(f"互信息MI: {mi:.2f}")
print(f"结构相似性SSIM: {ssim_val:.4f}")
# 条件显示SSIM
if ssim_val >= 0:
print(f"结构相似性SSIM: {ssim_val:.4f}")
else:
print("结构相似性SSIM: 计算失败(已跳过)")
print(f"配准点数: {dot}")
# 显示并保存结果 # 显示并保存结果
cv2.imshow("Fusion with Detection", fusion_result) # cv2.imshow("Fusion with Detection", fusion_result)
cv2.imwrite("output/fusion_result.jpg", fusion_result) cv2.imwrite("output/fusion_result.jpg", fusion_result)
cv2.waitKey(0) # cv2.waitKey(0)
cv2.destroyAllWindows() # cv2.destroyAllWindows()
else: else:
print("融合失败!") print("融合失败!")