#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Register a visible/infrared image pair, fuse them and run YOLOv8 detection.

The script works on still images (``--mode image``) or on two synchronised
video streams (``--mode video`` with ``--source file`` or ``--source camera``).
For still images it also prints fusion-quality metrics (EN, SF, MI, SSIM).
"""
import argparse
import os
import time

import cv2
import numpy as np
from skimage.metrics import structural_similarity as ssim
from ultralytics import YOLO

# Initialise the YOLOv8 detector (the weights can be swapped for yolov8s/m/l, etc.).
yolo_model = YOLO("best.pt")
yolo_model.to('cuda')  # enable GPU acceleration

# A homography needs 4 correspondences; the script requires strictly more than this.
MIN_MATCH_COUNT = 4
# Frame size (width, height) used for the output video writer.
VIDEO_FRAME_SIZE = (640, 480)
# Frame rate used for cameras and for files that do not report a valid FPS.
DEFAULT_FPS = 30


def calculate_en(img):
    """Return the information entropy (bits) of a grayscale image.

    The entropy is computed from the normalised 256-bin intensity histogram;
    a small epsilon keeps ``log2`` finite for empty bins.
    """
    hist = cv2.calcHist([img], [0], None, [256], [0, 256])
    hist = hist / hist.sum()
    return -np.sum(hist * np.log2(hist + 1e-10))


def calculate_sf(img):
    """Return the spatial frequency of a grayscale image.

    The image is converted to float64 first: differencing and squaring uint8
    data directly wraps around modulo 256 and yields a meaningless result.
    """
    pixels = np.asarray(img, dtype=np.float64)
    rf = np.sqrt(np.mean(np.square(np.diff(pixels, axis=0))))
    cf = np.sqrt(np.mean(np.square(np.diff(pixels, axis=1))))
    return np.sqrt(rf ** 2 + cf ** 2)


def calculate_mi(img1, img2):
    """Return the mutual information (bits) between two equally sized grayscale images."""
    hist_2d = np.histogram2d(img1.ravel(), img2.ravel(), 256)[0]
    pxy = hist_2d / hist_2d.sum()
    px = np.sum(pxy, axis=1)
    py = np.sum(pxy, axis=0)
    # The epsilons guard the division and the logarithm against empty bins.
    return np.sum(pxy * np.log2(pxy / (px[:, None] * py[None, :] + 1e-10) + 1e-10))


def calculate_ssim(img1, img2):
    """Return the SSIM between two grayscale images with an intensity range of 255."""
    return ssim(img1, img2, data_range=255)


def truncated_linear_stretch(image, truncated_value=2, maxout=255, min_out=0):
    """Apply a percentile-clipped linear contrast stretch to each channel.

    Values below the ``truncated_value`` percentile and above the
    ``100 - truncated_value`` percentile are clipped, and the remaining range
    is mapped linearly onto ``[min_out, maxout]``.

    :param image: three-channel image (it is split into three planes).
    :param truncated_value: percentile cut off at both ends of the histogram.
    :param maxout: upper bound of the output range.
    :param min_out: lower bound of the output range.
    :return: the stretched three-channel uint8 image.
    """

    def gray_process(gray):
        truncated_down = np.percentile(gray, truncated_value)
        truncated_up = np.percentile(gray, 100 - truncated_value)
        if truncated_up <= truncated_down:
            # Flat channel: nothing to stretch, and the scale would divide by zero.
            return np.uint8(np.clip(gray, min_out, maxout))
        scale = (maxout - min_out) / (truncated_up - truncated_down)
        # Shift by the lower percentile before scaling so it maps to min_out.
        gray_new = (gray.astype(np.float64) - truncated_down) * scale + min_out
        return np.uint8(np.clip(gray_new, min_out, maxout))

    b, g, r = cv2.split(image)
    return cv2.merge((gray_process(b), gray_process(g), gray_process(r)))


def Images_matching(img_base, img_target):
    """Estimate the homography that maps ``img_target`` onto ``img_base``.

    SIFT keypoints are matched with a brute-force 2-NN matcher and filtered
    with the 0.75 ratio test; the homography is then fitted with RANSAC.

    :param img_base: reference image (BGR, or already grayscale).
    :param img_target: grayscale image to be registered onto the reference.
    :return: ``(flag, H, count)`` where ``flag`` is 1 on success and 0 on
        failure, ``H`` is the 3x3 homography (``None`` on failure) and
        ``count`` is the number of matches kept by the ratio test.
    """
    if img_base.ndim == 3:
        img_base = cv2.cvtColor(img_base, cv2.COLOR_BGR2GRAY)

    sift = cv2.SIFT_create()
    kp1, des1 = sift.detectAndCompute(img_base, None)
    kp2, des2 = sift.detectAndCompute(img_target, None)
    if des1 is None or des2 is None:
        # A featureless image yields no descriptors; knnMatch would raise on None.
        print("Not enough matches are found - {}/{}".format(0, MIN_MATCH_COUNT))
        return 0, None, 0

    matcher = cv2.BFMatcher()
    matches = matcher.knnMatch(des1, des2, k=2)

    # Keep a match only when the best neighbour is clearly closer than the
    # second best. Pairs with fewer than two neighbours cannot be tested.
    good = [
        pair[0]
        for pair in matches
        if len(pair) == 2 and pair[0].distance < 0.75 * pair[1].distance
    ]

    src_pts = np.array([kp1[m.queryIdx].pt for m in good])  # points in img_base
    dst_pts = np.array([kp2[m.trainIdx].pt for m in good])  # points in img_target
    if len(good) <= MIN_MATCH_COUNT:
        print("Not enough matches are found - {}/{}".format(len(good), MIN_MATCH_COUNT))
        return 0, None, 0

    print(len(dst_pts), len(src_pts), "配准坐标点")
    homography, _inlier_mask = cv2.findHomography(dst_pts, src_pts, cv2.RANSAC, 4)
    if homography is None:
        # RANSAC could not find a consistent model.
        print("Homography estimation failed")
        return 0, None, 0
    return 1, homography, len(dst_pts)


def fusions(img_vl, img_inf):
    """Fuse a visible image with a registered infrared image.

    The luminance (Y) channel of the visible image is averaged with the
    infrared image while the chroma channels are kept unchanged.

    :param img_vl: visible image in BGR order.
    :param img_inf: single-channel infrared image with the same height/width.
    :return: the fused image in BGR order, as expected by the OpenCV and YOLO
        calls that consume it.
    :raises ValueError: if the two images differ in height or width.
    """
    img_yuv = cv2.cvtColor(img_vl, cv2.COLOR_BGR2YUV)
    y, u, v = cv2.split(img_yuv)
    if y.shape != img_inf.shape:
        raise ValueError(
            "visible luminance {} and infrared {} shapes differ".format(y.shape, img_inf.shape)
        )
    fused_y = (y * 0.5 + img_inf * 0.5).astype(np.uint8)
    return cv2.cvtColor(cv2.merge((fused_y, u, v)), cv2.COLOR_YUV2BGR)


def removeBlackBorder(gray):
    """Crop the black border left around a warped image.

    The extent of the valid area is probed along the central column and the
    central row only, so a border covering the image centre is not detected.

    :param gray: image to crop (2-D, or 3-D with a trailing channel axis).
    :return: ``(cropped, left, right, top, bottom)`` such that
        ``cropped`` equals ``gray[left:right, top:bottom]``; ``left``/``right``
        are row bounds and ``top``/``bottom`` are column bounds, with the end
        bounds exclusive.
    :raises ValueError: if the probed row or column contains no pixel above
        the brightness threshold.
    """
    threshold = 40  # pixels at or below this value count as black border
    nrow, ncol = gray.shape[:2]
    center_col = gray[:, ncol // 2]
    center_row = gray[nrow // 2, :]
    if gray.ndim == 3:
        # A pixel is valid as soon as any channel is above the threshold.
        center_col = center_col.max(axis=1)
        center_row = center_row.max(axis=1)

    valid_rows = np.flatnonzero(center_col > threshold)
    valid_cols = np.flatnonzero(center_row > threshold)
    if valid_rows.size == 0 or valid_cols.size == 0:
        raise ValueError("image centre is entirely black; cannot locate the valid region")

    # End bounds are exclusive so the last valid row/column is kept.
    left, right = int(valid_rows[0]), int(valid_rows[-1]) + 1
    top, bottom = int(valid_cols[0]), int(valid_cols[-1]) + 1
    return gray[left:right, top:bottom], left, right, top, bottom


def main(matchimg_vi, matchimg_in):
    """Register, fuse and run detection on one visible/infrared pair.

    :param matchimg_vi: visible image (BGR).
    :param matchimg_in: infrared image (grayscale).
    :return: ``(flag, image, dot, en, sf, mi, ssim)``. ``flag`` is 1 on
        success and 0 on failure; ``image`` is the fused image with the
        detections drawn on it (cropped to the valid registered region);
        ``dot`` is the number of matched points; the remaining values are the
        entropy, spatial frequency, total mutual information and mean SSIM
        (-1 when SSIM could not be computed). On failure the tuple is
        ``(0, None, 0, 0.0, 0.0, 0.0, 0.0)``.
    """
    failure = (0, None, 0, 0.0, 0.0, 0.0, 0.0)
    try:
        h, w = matchimg_vi.shape[:2]

        flag, H, dot = Images_matching(matchimg_vi, matchimg_in)
        if flag == 0:
            return failure

        # Warp the infrared image into the visible frame, then keep only the
        # region that actually contains infrared data.
        matched_ni = cv2.warpPerspective(matchimg_in, H, (w, h))
        matched_ni, left, right, top, bottom = removeBlackBorder(matched_ni)
        matched_ni = np.ascontiguousarray(matched_ni)

        # Crop the visible image to the same window so both inputs share one
        # shape; fusing the uncropped image fails with a broadcast error.
        cropped_vi = np.ascontiguousarray(matchimg_vi[left:right, top:bottom])
        fusion = fusions(cropped_vi, matched_ni)

        # Metrics are computed on grayscale versions.
        fusion_gray = cv2.cvtColor(fusion, cv2.COLOR_BGR2GRAY)
        cropped_vi_gray = cv2.cvtColor(cropped_vi, cv2.COLOR_BGR2GRAY)
        matched_ni_gray = matched_ni  # the infrared image is already grayscale

        en = calculate_en(fusion_gray)
        sf = calculate_sf(fusion_gray)
        mi_visible = calculate_mi(fusion_gray, cropped_vi_gray)
        mi_infrared = calculate_mi(fusion_gray, matched_ni_gray)
        mi_total = mi_visible + mi_infrared

        # SSIM is best-effort: a failure here must not discard the fusion.
        try:
            ssim_visible = calculate_ssim(fusion_gray, cropped_vi_gray)
            ssim_infrared = calculate_ssim(fusion_gray, matched_ni_gray)
            ssim_avg = (ssim_visible + ssim_infrared) / 2
        except Exception as ssim_error:
            print(f"SSIM计算错误: {ssim_error}")
            ssim_avg = -1  # -1 marks a failed SSIM computation

        # Run YOLOv8 on the fused image and draw the detections.
        results = yolo_model(fusion)
        annotated_image = results[0].plot()

        return 1, annotated_image, dot, en, sf, mi_total, ssim_avg
    except Exception as e:
        print(f"Error in fusion/detection: {e}")
        return failure


def parse_args():
    """Parse the command-line arguments."""
    # Default visible / infrared image paths.
    visible_image_path = "./test/visible/visibleI0195.jpg"
    infrared_image_path = "./test/infrared/infraredI0195.jpg"
    # Default visible / infrared video paths.
    visible_video_path = "./test/visible.mp4"
    infrared_video_path = "./test/infrared.mp4"

    parser = argparse.ArgumentParser(description='图像融合与目标检测')
    parser.add_argument('--mode', type=str, choices=['video', 'image'], default='image',
                        help='输入模式:video(视频流) 或 image(静态图片)')
    # Distinguishes a camera from a video file.
    parser.add_argument('--source', type=str, choices=['camera', 'file'],
                        help='视频输入类型:camera(摄像头)或 file(视频文件)')
    # Video-file options.
    parser.add_argument('--video1', type=str, default=visible_video_path,
                        help='可见光视频路径(仅在source=file时需要)')
    parser.add_argument('--video2', type=str, default=infrared_video_path,
                        help='红外视频路径(仅在source=file时需要)')
    # Camera options.
    parser.add_argument('--camera_id1', type=int, default=0,
                        help='可见光摄像头ID(仅在source=camera时需要,默认0)')
    parser.add_argument('--camera_id2', type=int, default=1,
                        help='红外摄像头ID(仅在source=camera时需要,默认1)')
    parser.add_argument('--output', type=str, default='output.mp4',
                        help='输出视频路径(仅在video模式需要)')
    # Still-image options.
    parser.add_argument('--visible', type=str, default=visible_image_path,
                        help='可见光图片路径(仅在image模式需要)')
    parser.add_argument('--infrared', type=str, default=infrared_image_path,
                        help='红外图片路径(仅在image模式需要)')
    return parser.parse_args()


def _run_video(args):
    """Fuse two video streams frame by frame and write the annotated result."""
    if args.source == 'file':
        if not args.video1 or not args.video2:
            raise ValueError("视频模式需要指定 --video1 和 --video2 参数")
        # --video1 is the visible stream and --video2 the infrared stream.
        capture_vi = cv2.VideoCapture(args.video1)
        capture_ir = cv2.VideoCapture(args.video2)
    elif args.source == 'camera':
        capture_vi = cv2.VideoCapture(args.camera_id1)
        capture_ir = cv2.VideoCapture(args.camera_id2)
    else:
        raise ValueError("必须指定 --source 参数(camera 或 file)")

    out = None
    try:
        if not capture_vi.isOpened() or not capture_ir.isOpened():
            print("Error: failed to open the visible or infrared video source")
            raise SystemExit(1)

        fps = capture_vi.get(cv2.CAP_PROP_FPS) if args.source == 'file' else DEFAULT_FPS
        if not fps or fps <= 0:
            # Some containers report 0 FPS, which would produce a broken file.
            fps = DEFAULT_FPS
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        out = cv2.VideoWriter(args.output, fourcc, fps, VIDEO_FRAME_SIZE)

        while True:
            ret_vi, frame_vi = capture_vi.read()  # visible frame
            ret_ir, frame_ir = capture_ir.read()  # infrared frame
            if not ret_vi or not ret_ir:
                break

            frame_ir_gray = cv2.cvtColor(frame_ir, cv2.COLOR_BGR2GRAY)

            # main() returns seven values; only the first two matter here.
            flag, fusion, *_metrics = main(frame_vi, frame_ir_gray)
            if flag == 1:
                # The fused frame is cropped to the registered region, while
                # the writer only accepts frames of its configured size.
                if (fusion.shape[1], fusion.shape[0]) != VIDEO_FRAME_SIZE:
                    fusion = cv2.resize(fusion, VIDEO_FRAME_SIZE)
                cv2.imshow("Fusion with YOLOv8 Detection", fusion)
                out.write(fusion)

            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Release the devices and finalise the file even after an error.
        capture_vi.release()
        capture_ir.release()
        if out is not None:
            out.release()
        cv2.destroyAllWindows()


def _run_image(args):
    """Fuse one image pair, print the quality metrics and save the result."""
    if not args.infrared or not args.visible:
        raise ValueError("图片模式需要指定 --visible 和 --infrared 参数")

    img_visible = cv2.imread(args.visible)
    img_infrared = cv2.imread(args.infrared)
    if img_visible is None or img_infrared is None:
        print("Error: 图片加载失败,请检查路径!")
        raise SystemExit(1)

    # The infrared image is processed as grayscale.
    img_inf_gray = cv2.cvtColor(img_infrared, cv2.COLOR_BGR2GRAY)

    flag, fusion_result, dot, en, sf, mi, ssim_val = main(img_visible, img_inf_gray)
    if flag != 1:
        print("融合失败!")
        return

    print("\n======== 融合质量评价 ========")
    print(f"信息熵(EN): {en:.2f}")
    print(f"空间频率(SF): {sf:.2f}")
    print(f"互信息(MI): {mi:.2f}")
    if ssim_val >= 0:
        print(f"结构相似性(SSIM): {ssim_val:.4f}")
    else:
        print("结构相似性(SSIM): 计算失败(已跳过)")
    print(f"配准点数: {dot}")

    # cv2.imwrite does not create missing directories and only signals
    # failure through its return value.
    os.makedirs("output", exist_ok=True)
    if not cv2.imwrite("output/fusion_result.jpg", fusion_result):
        print("Error: failed to write output/fusion_result.jpg")


if __name__ == '__main__':
    args = parse_args()
    if args.mode == 'video':
        _run_video(args)
    elif args.mode == 'image':
        _run_image(args)