Graduation-Project/image_fusion/Image_Registration_test.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import time
import argparse
import cv2
import numpy as np
import torch
from ultralytics import YOLO

# Initialize the YOLOv8 model (yolov8n can be swapped for yolov8s/m/l, etc.)
yolo_model = YOLO("yolov8n.pt")
# Use GPU acceleration when available; fall back to CPU otherwise
yolo_model.to('cuda' if torch.cuda.is_available() else 'cpu')
# Truncated linear RGB contrast stretch: discard values below the 2nd percentile
# and above the 98th (the lower and upper cut-offs are usually equal), then map
# the remaining range onto the configured output limits.
def truncated_linear_stretch(image, truncated_value=2, maxout=255, min_out=0):
    """
    :param image: input BGR image
    :param truncated_value: percentile clipped at each end (2 -> keep [2%, 98%])
    :param maxout: upper bound of the output range
    :param min_out: lower bound of the output range
    :return: contrast-stretched image
    """
    def gray_process(gray, maxout=maxout, minout=min_out):
        truncated_down = np.percentile(gray, truncated_value)
        truncated_up = np.percentile(gray, 100 - truncated_value)
        # Map [truncated_down, truncated_up] linearly onto [minout, maxout]
        gray_new = ((maxout - minout) / (truncated_up - truncated_down)) * (gray - truncated_down) + minout
        gray_new[gray_new < minout] = minout
        gray_new[gray_new > maxout] = maxout
        return np.uint8(gray_new)

    (b, g, r) = cv2.split(image)
    b = gray_process(b)
    g = gray_process(g)
    r = gray_process(r)
    result = cv2.merge((b, g, r))  # merge the channels back together
    return result
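
# A minimal usage sketch for the stretch above (illustrative only): a synthetic
# low-contrast ramp shows the percentile cut-offs expanding to the full range.
def _demo_truncated_stretch():
    ramp = np.tile(np.linspace(100, 150, 256, dtype=np.uint8), (64, 1))
    img = cv2.merge((ramp, ramp, ramp))  # low-contrast 3-channel image
    stretched = truncated_linear_stretch(img, truncated_value=2)
    print(img.min(), img.max(), "->", stretched.min(), stretched.max())
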
# Registration for RGB images: uses the daytime visible-light image and the
# infrared grayscale image, and estimates the homography between their shared
# SIFT feature points.
def Images_matching(img_base, img_target):
    """
    :param img_base: visible-light (base) image, BGR
    :param img_target: infrared (target) image to match, grayscale
    :return: (success flag, homography matrix, number of matched points)
    """
    start = time.time()
    img_base = cv2.cvtColor(img_base, cv2.COLOR_BGR2GRAY)
    sift = cv2.SIFT_create()
    # Detect SIFT keypoints and compute their descriptors
    st1 = time.time()
    kp1, des1 = sift.detectAndCompute(img_base, None)
    kp2, des2 = sift.detectAndCompute(img_target, None)
    en1 = time.time()
    # print(en1 - st1, "feature extraction")
    # KNN feature matching (a cv2.FlannBasedMatcher would also work here)
    st2 = time.time()
    matcher = cv2.BFMatcher()
    matches = matcher.knnMatch(des1, des2, k=2)
    de2 = time.time()
    # print(de2 - st2, "feature matching")
    good = []
    # Keep only the distinctive matches (Lowe's ratio test)
    for m, n in matches:
        if m.distance < 0.75 * n.distance:  # keep m only if it beats the runner-up by a clear margin
            good.append(m)
    src_pts = np.array([kp1[m.queryIdx].pt for m in good])  # keypoint coordinates in the query image
    dst_pts = np.array([kp2[m.trainIdx].pt for m in good])  # keypoint coordinates in the train (template) image
    if len(src_pts) <= 4:
        print("Not enough matches are found - {}/{}".format(len(good), 4))
        return 0, None, 0
    else:
        print(len(dst_pts), len(src_pts), "registration points")
        H = cv2.findHomography(dst_pts, src_pts, cv2.RANSAC, 4)  # H[0]: 3x3 homography, H[1]: inlier mask
        end = time.time()
        times = end - start
        # print("registration time", times)
        return 1, H[0], len(dst_pts)
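
# A minimal usage sketch for the registration step above (hypothetical paths):
# warp the infrared frame into the visible frame's coordinates.
def _demo_registration(visible_path="visible.jpg", infrared_path="infrared.jpg"):
    vis = cv2.imread(visible_path)
    ir = cv2.imread(infrared_path, cv2.IMREAD_GRAYSCALE)
    flag, H, n_points = Images_matching(vis, ir)
    if flag == 1:
        h, w = vis.shape[:2]
        aligned_ir = cv2.warpPerspective(ir, H, (w, h))  # infrared aligned to the visible view
        print(n_points, "matches; aligned infrared shape:", aligned_ir.shape)
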
def fusions(img_vl, img_inf):
    """
    :param img_vl: visible-light image (BGR)
    :param img_inf: infrared image (grayscale), already registered
    :return: fused image (BGR)
    """
    img_YUV = cv2.cvtColor(img_vl, cv2.COLOR_BGR2YUV)  # input is BGR, so convert before splitting
    y, u, v = cv2.split(img_YUV)  # split the channels to get the Y (luma) channel
    Yf = y * 0.5 + img_inf * 0.5  # average the visible luma with the infrared intensity
    Yf = Yf.astype(np.uint8)
    # Convert back to BGR so the result matches what cv2.imshow/imwrite expect downstream
    fusion = cv2.cvtColor(cv2.merge((Yf, u, v)), cv2.COLOR_YUV2BGR)
    return fusion
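
# A sketch of the same Y-channel fusion with an adjustable blend weight (the
# alpha parameter is an assumption for illustration; the pipeline above uses a
# fixed 50/50 split):
def _demo_weighted_fusion(img_vl, img_inf, alpha=0.6):
    y, u, v = cv2.split(cv2.cvtColor(img_vl, cv2.COLOR_BGR2YUV))
    yf = cv2.addWeighted(y, alpha, img_inf, 1.0 - alpha, 0)  # alpha weights the visible luma
    return cv2.cvtColor(cv2.merge((yf, u, v)), cv2.COLOR_YUV2BGR)
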
def removeBlackBorder(gray):
    """
    Remove the surplus black border left after warping/stitching.
    Input:
        gray: 2-D numpy array, the image to process
    Output:
        the cropped image and its row/column bounds
    """
    threshold = 40  # intensity threshold separating border from content
    nrow = gray.shape[0]  # image height
    ncol = gray.shape[1]  # image width
    rowc = gray[:, ncol // 2]  # centre column; cannot handle borders covering more than half the image
    colc = gray[nrow // 2, :]  # centre row
    rowflag = np.argwhere(rowc > threshold)
    colflag = np.argwhere(colc > threshold)
    left, bottom, right, top = rowflag[0, 0], colflag[-1, 0], rowflag[-1, 0], colflag[0, 0]
    # cv2.imshow('name', gray[left:right, top:bottom])  # preview the crop
    # cv2.waitKey(1)
    return gray[left:right, top:bottom], left, right, top, bottom
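
# A usage sketch mirroring the commented-out call in main() below: crop the
# warped infrared frame, then fuse against the matching visible-light region.
# cropped, left, right, top, bottom = removeBlackBorder(matched_ir)
# fused = fusions(img_visible[left:right, top:bottom], cropped)
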
def main(matchimg_vi, matchimg_in):
    """
    :param matchimg_vi: visible-light image
    :param matchimg_in: infrared image (grayscale)
    :return: fused image with detection results drawn on it
    """
    try:
        orimg_vi = matchimg_vi
        orimg_in = matchimg_in
        h, w = orimg_vi.shape[:2]  # e.g. 480, 640
        flag, H, dot = Images_matching(matchimg_vi, matchimg_in)  # estimate the registration homography
        if flag == 0:
            return 0, None, 0
        else:
            matched_ni = cv2.warpPerspective(orimg_in, H, (w, h))
            # matched_ni, left, right, top, bottom = removeBlackBorder(matched_ni)
            # fusion = fusions(orimg_vi[left:right, top:bottom], matched_ni)
            fusion = fusions(orimg_vi, matched_ni)
            # YOLOv8 object detection on the fused image
            results = yolo_model(fusion)
            annotated_image = results[0].plot()  # draw the detection boxes
            return 1, annotated_image, dot  # return the annotated image
    except Exception as e:
        print(f"Error in fusion/detection: {e}")
        return 0, None, 0
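
# A minimal end-to-end sketch (hypothetical paths): optionally pre-stretch the
# visible image with truncated_linear_stretch, which this file defines but the
# default pipeline never calls.
def _demo_pipeline(visible_path="visible.jpg", infrared_path="infrared.jpg"):
    vis = truncated_linear_stretch(cv2.imread(visible_path))
    ir = cv2.imread(infrared_path, cv2.IMREAD_GRAYSCALE)
    flag, annotated, n_points = main(vis, ir)
    if flag == 1:
        cv2.imwrite("fusion_demo.jpg", annotated)
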
def parse_args():
    """Parse the command-line arguments."""
    # Default visible-light and infrared image paths
    visible_image_path = "../test/visible.jpg"
    infrared_image_path = "../test/infrared.jpg"
    # Default visible-light and infrared video paths
    visible_video_path = "../test/visible.mp4"
    infrared_video_path = "../test/infrared.mp4"
    parser = argparse.ArgumentParser(description='Image fusion and object detection')
    parser.add_argument('--mode', type=str, choices=['video', 'image'], default='image',
                        help='input mode: video (video stream) or image (still image)')
    # Distinguish between camera input and video files
    parser.add_argument('--source', type=str, choices=['camera', 'file'],
                        help='video input type: camera or file (video files)')
    # Video-mode arguments
    parser.add_argument('--video1', type=str, default=visible_video_path,
                        help='visible-light video path (only needed when source=file)')
    parser.add_argument('--video2', type=str, default=infrared_video_path,
                        help='infrared video path (only needed when source=file)')
    # Camera-mode arguments
    parser.add_argument('--camera_id1', type=int, default=0,
                        help='visible-light camera ID (only needed when source=camera, default 0)')
    parser.add_argument('--camera_id2', type=int, default=1,
                        help='infrared camera ID (only needed when source=camera, default 1)')
    parser.add_argument('--output', type=str, default='output.mp4',
                        help='output video path (only needed in video mode)')
    # Image-mode arguments
    parser.add_argument('--visible', type=str, default=visible_image_path,
                        help='visible-light image path (only needed in image mode)')
    parser.add_argument('--infrared', type=str, default=infrared_image_path,
                        help='infrared image path (only needed in image mode)')
    return parser.parse_args()
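
# Example invocations (hypothetical paths, shown for reference):
#   python Image_Registration_test.py --mode image --visible ../test/visible.jpg --infrared ../test/infrared.jpg
#   python Image_Registration_test.py --mode video --source file --video1 ../test/visible.mp4 --video2 ../test/infrared.mp4
#   python Image_Registration_test.py --mode video --source camera --camera_id1 0 --camera_id2 1
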
if __name__ == '__main__':
    time_all = 0
    dots = 0
    i = 0
    args = parse_args()
    if args.mode == 'video':
        if args.source == 'file':
            # ========== Video-file mode ==========
            if not args.video1 or not args.video2:
                raise ValueError("Video mode requires both --video1 and --video2")
            capture = cv2.VideoCapture(args.video1)   # visible-light stream
            capture2 = cv2.VideoCapture(args.video2)  # infrared stream
        elif args.source == 'camera':
            # ========== Camera mode ==========
            capture = cv2.VideoCapture(args.camera_id1)
            capture2 = cv2.VideoCapture(args.camera_id2)
        else:
            raise ValueError("--source must be specified: camera or file")
        # Shared video-processing logic
        fps = capture.get(cv2.CAP_PROP_FPS) if args.source == 'file' else 30
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        # Note: the writer's frame size must match the frames passed to out.write()
        out = cv2.VideoWriter(args.output, fourcc, fps, (640, 480))
        while True:
            ret1, frame_vi = capture.read()   # visible-light frame
            ret2, frame_ir = capture2.read()  # infrared frame
            if not ret1 or not ret2:
                break
            # Convert the infrared frame to grayscale
            frame_ir_gray = cv2.cvtColor(frame_ir, cv2.COLOR_BGR2GRAY)
            # Run fusion and detection
            flag, fusion, _ = main(frame_vi, frame_ir_gray)
            if flag == 1:
                cv2.imshow("Fusion with YOLOv8 Detection", fusion)
                out.write(fusion)
            if cv2.waitKey(1) == ord('q'):
                break
        # Release resources
        capture.release()
        capture2.release()
        out.release()
        cv2.destroyAllWindows()
    elif args.mode == 'image':
        # ========== Image mode ==========
        if not args.infrared or not args.visible:
            raise ValueError("Image mode requires both --visible and --infrared")
        # Load the images
        img_visible = cv2.imread(args.visible)
        img_infrared = cv2.imread(args.infrared)
        if img_visible is None or img_infrared is None:
            print("Error: failed to load the images, please check the paths!")
            exit()
        # Convert the infrared image to grayscale
        img_inf_gray = cv2.cvtColor(img_infrared, cv2.COLOR_BGR2GRAY)
        # Run fusion and detection
        flag, fusion_result, _ = main(img_visible, img_inf_gray)
        if flag == 1:
            # Show and save the result
            cv2.imshow("Fusion with Detection", fusion_result)
            cv2.imwrite("../output/fusion_result.jpg", fusion_result)
            cv2.waitKey(0)
            cv2.destroyAllWindows()
        else:
            print("Fusion failed!")