添加评价指标

This commit is contained in:
myh 2025-04-22 16:35:29 +08:00
parent d1ed958db5
commit 89d8f4c0df

View File

@ -8,12 +8,41 @@ import cv2
import numpy as np
from ultralytics import YOLO
from skimage.metrics import structural_similarity as ssim
# 添加YOLOv8模型初始化
yolo_model = YOLO("yolov8n.pt") # 可替换为yolov8s/m/l等
yolo_model.to('cuda') # 启用GPU加速
def calculate_en(img):
"""计算信息熵(处理灰度图)"""
hist = cv2.calcHist([img], [0], None, [256], [0, 256])
hist = hist / hist.sum()
return -np.sum(hist * np.log2(hist + 1e-10))
def calculate_sf(img):
"""计算空间频率(处理灰度图)"""
rf = np.sqrt(np.mean(np.square(np.diff(img, axis=0))))
cf = np.sqrt(np.mean(np.square(np.diff(img, axis=1))))
return np.sqrt(rf ** 2 + cf ** 2)
def calculate_mi(img1, img2):
"""计算互信息(处理灰度图)"""
hist_2d = np.histogram2d(img1.ravel(), img2.ravel(), 256)[0]
pxy = hist_2d / hist_2d.sum()
px = np.sum(pxy, axis=1)
py = np.sum(pxy, axis=0)
return np.sum(pxy * np.log2(pxy / (px[:, None] * py[None, :] + 1e-10) + 1e-10))
def calculate_ssim(img1, img2):
"""计算SSIM处理灰度图"""
return ssim(img1, img2, data_range=255)
# 裁剪线性RGB对比度拉伸去掉2%百分位以下的数去掉98%百分位以上的数,上下百分位数一般相同,并设置输出上下限)
def truncated_linear_stretch(image, truncated_value=2, maxout=255, min_out=0):
"""
@ -145,20 +174,42 @@ def main(matchimg_vi, matchimg_in):
orimg_vi = matchimg_vi
orimg_in = matchimg_in
h, w = orimg_vi.shape[:2] # 480 640
flag, H, dot = Images_matching(matchimg_vi, matchimg_in) # (3, 3)//获取对应的配准坐标点
# (3, 3)//获取对应的配准坐标点
flag, H, dot = Images_matching(matchimg_vi, matchimg_in)
if flag == 0:
return 0, None, 0
else:
# 配准处理
matched_ni = cv2.warpPerspective(orimg_in, H, (w, h))
matched_ni, left, right, top, bottom = removeBlackBorder(matched_ni)
# 裁剪可见光图像
cropped_vi = orimg_vi[left:right, top:bottom]
# fusion = fusions(orimg_vi[left:right, top:bottom], matched_ni)
fusion = fusions(orimg_vi, matched_ni)
fusion = fusions(cropped_vi, matched_ni)
# 转换为灰度计算指标
fusion_gray = cv2.cvtColor(fusion, cv2.COLOR_RGB2GRAY)
cropped_vi_gray = cv2.cvtColor(cropped_vi, cv2.COLOR_BGR2GRAY)
matched_ni_gray = matched_ni # 红外图已经是灰度
# 计算指标
en = calculate_en(fusion_gray)
sf = calculate_sf(fusion_gray)
mi_visible = calculate_mi(fusion_gray, cropped_vi_gray)
mi_infrared = calculate_mi(fusion_gray, matched_ni_gray)
mi_total = mi_visible + mi_infrared
ssim_visible = calculate_ssim(fusion_gray, cropped_vi_gray)
ssim_infrared = calculate_ssim(fusion_gray, matched_ni_gray)
ssim_avg = (ssim_visible + ssim_infrared) / 2
# YOLOv8目标检测
results = yolo_model(fusion) # 输入融合后的图像
annotated_image = results[0].plot() # 绘制检测框
return 1, annotated_image, dot # 返回带检测结果的图像
# 返回带检测结果的图像
return 1, annotated_image, dot, en, sf, mi_total, ssim_avg
except Exception as e:
print(f"Error in fusion/detection: {e}")
return 0, None, 0
@ -272,9 +323,16 @@ if __name__ == '__main__':
img_inf_gray = cv2.cvtColor(img_infrared, cv2.COLOR_BGR2GRAY)
# 执行融合与检测
flag, fusion_result, _ = main(img_visible, img_inf_gray)
flag, fusion_result, dot, en, sf, mi, ssim_val = main(img_visible, img_inf_gray)
if flag == 1:
# 展示评价指标
print("\n======== 融合质量评价 ========")
print(f"信息熵EN: {en:.2f}")
print(f"空间频率SF: {sf:.2f}")
print(f"互信息MI: {mi:.2f}")
print(f"结构相似性SSIM: {ssim_val:.4f}")
# 显示并保存结果
cv2.imshow("Fusion with Detection", fusion_result)
cv2.imwrite("output/fusion_result.jpg", fusion_result)