#!/usr/bin/env python
# -*- coding: utf-8 -*-

import time
import argparse

import cv2
import numpy as np

from ultralytics import YOLO
from skimage.metrics import structural_similarity as ssim

# 添加YOLOv8模型初始化
yolo_model = YOLO("best.pt")  # 可替换为yolov8s/m/l等
yolo_model.to('cuda')  # 启用GPU加速


def calculate_en(img):
    """计算信息熵(处理灰度图)"""
    hist = cv2.calcHist([img], [0], None, [256], [0, 256])
    hist = hist / hist.sum()
    return -np.sum(hist * np.log2(hist + 1e-10))


def calculate_sf(img):
    """计算空间频率(处理灰度图)"""
    rf = np.sqrt(np.mean(np.square(np.diff(img, axis=0))))
    cf = np.sqrt(np.mean(np.square(np.diff(img, axis=1))))
    return np.sqrt(rf ** 2 + cf ** 2)


def calculate_mi(img1, img2):
    """计算互信息(处理灰度图)"""
    hist_2d = np.histogram2d(img1.ravel(), img2.ravel(), 256)[0]
    pxy = hist_2d / hist_2d.sum()
    px = np.sum(pxy, axis=1)
    py = np.sum(pxy, axis=0)
    return np.sum(pxy * np.log2(pxy / (px[:, None] * py[None, :] + 1e-10) + 1e-10))


def calculate_ssim(img1, img2):
    """计算SSIM(处理灰度图)"""
    return ssim(img1, img2, data_range=255)


# 裁剪线性RGB对比度拉伸:(去掉2%百分位以下的数,去掉98%百分位以上的数,上下百分位数一般相同,并设置输出上下限)
def truncated_linear_stretch(image, truncated_value=2, maxout=255, min_out=0):
    """
    :param image:
    :param truncated_value:
    :param maxout:
    :param min_out:
    :return:
    """
    
    def gray_process(gray, maxout=maxout, minout=min_out):
        truncated_down = np.percentile(gray, truncated_value)
        truncated_up = np.percentile(gray, 100 - truncated_value)
        gray_new = ((maxout - minout) / (truncated_up - truncated_down)) * gray
        gray_new[gray_new < minout] = minout
        gray_new[gray_new > maxout] = maxout
        return np.uint8(gray_new)
    
    (b, g, r) = cv2.split(image)
    b = gray_process(b)
    g = gray_process(g)
    r = gray_process(r)
    result = cv2.merge((b, g, r))  # 合并每一个通道
    return result


# RGB图片配准函数,采用白天的可见光与红外灰度图,计算两者Surf共同特征点,之间的仿射矩阵。
def Images_matching(img_base, img_target):
    """
    :param img_base:
    :param img_target:匹配图像
    :return: 返回仿射矩阵
    """
    start = time.time()
    orb = cv2.ORB_create()
    
    # 对可见光图像进行对比度拉伸
    # img_base = truncated_linear_stretch(img_base)
    
    img_base = cv2.cvtColor(img_base, cv2.COLOR_BGR2GRAY)
    sift = cv2.SIFT_create()
    # 使用sift算子计算特征点和特征点周围的特征向量
    st1 = time.time()
    kp1, des1 = sift.detectAndCompute(img_base, None)  # 1136    1136, 64
    kp2, des2 = sift.detectAndCompute(img_target, None)
    en1 = time.time()
    
    # print(en1 - st1, "特征提取")
    
    # 进行KNN特征匹配
    # FLANN_INDEX_KDTREE = 0  # 建立FLANN匹配器的参数
    # indexParams = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)  # 配置索引,密度树的数量为5
    # searchParams = dict(checks=50)  # 指定递归次数
    # flann = cv2.FlannBasedMatcher(indexParams, searchParams)  # 建立匹配器
    # matches = flann.knnMatch(des1, des2, k=2)  # 得出匹配的关键点  list: 1136
    # FLANN_INDEX_KDTREE = 1
    # index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
    # search_params = dict(checks=50)
    # flann = cv2.FlannBasedMatcher(index_params, search_params)
    # matches = flann.knnMatch(des1, des2, k=2)
    
    st2 = time.time()
    matcher = cv2.BFMatcher()
    matches = matcher.knnMatch(des1, des2, k=2)
    de2 = time.time()
    # print(de2 - st2, "特征匹配")
    good = []
    # 提取优秀的特征点
    for m, n in matches:
        if m.distance < 0.75 * n.distance:  # 如果第一个邻近距离比第二个邻近距离的0.7倍小,则保留
            good.append(m)  # 134
    src_pts = np.array([kp1[m.queryIdx].pt for m in good])  # 查询图像的特征描述子索引  # 134, 2
    dst_pts = np.array([kp2[m.trainIdx].pt for m in good])  # 训练(模板)图像的特征描述子索引
    if len(src_pts) <= 4:
        print("Not enough matches are found - {}/{}".format(len(good), 4))
        return 0, None, 0
    else:
        print(len(dst_pts), len(src_pts), "配准坐标点")
        H = cv2.findHomography(dst_pts, src_pts, cv2.RANSAC, 4)  # 生成变换矩阵  H[0]: 3, 3  H[1]: 134, 1
        end = time.time()
        times = end - start
        # print("配准时间", times)
        return 1, H[0], len(dst_pts)


def fusions(img_vl, img_inf):
    """
    :param img_vl: 原图像
    :param img_inf: 红外图像
    :return:
    """
    img_YUV = cv2.cvtColor(img_vl, cv2.COLOR_BGR2YUV)  # 如果输入是BGR,需转换
    # img_YUV = cv2.cvtColor(img_vl, cv2.COLOR_RGB2YUV)
    y, u, v = cv2.split(img_YUV)  # 分离通道,获取Y通道
    Yf = y * 0.5 + img_inf * 0.5
    Yf = Yf.astype(np.uint8)
    fusion = cv2.cvtColor(cv2.merge((Yf, u, v)), cv2.COLOR_YUV2RGB)
    return fusion


def removeBlackBorder(gray):
    """
    移除缝合后的图像的多余黑边
    输入:
        image:三维numpy矩阵,待处理图像
    输出:
        裁剪后的图像
    """
    threshold = 40  # 阈值
    nrow = gray.shape[0]  # 获取图片尺寸
    ncol = gray.shape[1]
    rowc = gray[:, int(1 / 2 * nrow)]  # 无法区分黑色区域超过一半的情况
    colc = gray[int(1 / 2 * ncol), :]
    rowflag = np.argwhere(rowc > threshold)
    colflag = np.argwhere(colc > threshold)
    left, bottom, right, top = rowflag[0, 0], colflag[-1, 0], rowflag[-1, 0], colflag[0, 0]
    # cv2.imshow('name', gray[left:right, top:bottom])  # 效果展示
    cv2.waitKey(1)
    return gray[left:right, top:bottom], left, right, top, bottom


def main(matchimg_vi, matchimg_in):
    """
    :param matchimg_vi: 可见光图像
    :param matchimg_in: 红外图像
    :return: 融合好的图像(带检测结果)
    """
    try:
        orimg_vi = matchimg_vi
        orimg_in = matchimg_in
        h, w = orimg_vi.shape[:2]  # 480 640
        # (3, 3)//获取对应的配准坐标点
        flag, H, dot = Images_matching(matchimg_vi, matchimg_in)
        if flag == 0:
            return 0, None, 0, 0.0, 0.0, 0.0, 0.0
        else:
            # 配准处理
            matched_ni = cv2.warpPerspective(orimg_in, H, (w, h))
            matched_ni, left, right, top, bottom = removeBlackBorder(matched_ni)
            
            # 裁剪可见光图像
            # fusion = fusions(orimg_vi[left:right, top:bottom], matched_ni)
            
            # 不裁剪可见光图像
            fusion = fusions(orimg_vi, matched_ni)
            
            # 转换为灰度计算指标
            fusion_gray = cv2.cvtColor(fusion, cv2.COLOR_RGB2GRAY)
            cropped_vi_gray = cv2.cvtColor(orimg_vi, cv2.COLOR_BGR2GRAY)
            matched_ni_gray = matched_ni  # 红外图已经是灰度
            
            # 计算指标
            en = calculate_en(fusion_gray)
            sf = calculate_sf(fusion_gray)
            mi_visible = calculate_mi(fusion_gray, cropped_vi_gray)
            mi_infrared = calculate_mi(fusion_gray, matched_ni_gray)
            mi_total = mi_visible + mi_infrared
            
            # 添加SSIM容错处理
            try:
                ssim_visible = calculate_ssim(fusion_gray, cropped_vi_gray)
                ssim_infrared = calculate_ssim(fusion_gray, matched_ni_gray)
                ssim_avg = (ssim_visible + ssim_infrared) / 2
            except Exception as ssim_error:
                print(f"SSIM计算错误: {ssim_error}")
                ssim_avg = -1  # 用-1表示计算失败
            
            # YOLOv8目标检测
            results = yolo_model(fusion)  # 输入融合后的图像
            annotated_image = results[0].plot()  # 绘制检测框
            
            # 返回带检测结果的图像
            return 1, annotated_image, dot, en, sf, mi_total, ssim_avg
    except Exception as e:
        print(f"Error in fusion/detection: {e}")
        return 0, None, 0, 0.0, 0.0, 0.0, 0.0


def parse_args():
    # 输入可见光和红外图像路径
    visible_image_path = "./test/visible/visibleI0195.jpg"  # 可见光图片路径
    infrared_image_path = "./test/infrared/infraredI0195.jpg"  # 红外图片路径
    # 输入可见光和红外视频路径
    visible_video_path = "./test/visible.mp4"  # 可见光视频路径
    infrared_video_path = "./test/infrared.mp4"  # 红外视频路径
    
    """解析命令行参数"""
    parser = argparse.ArgumentParser(description='图像融合与目标检测')
    
    parser.add_argument('--mode', type=str, choices=['video', 'image'], default='image',
                        help='输入模式:video(视频流) 或 image(静态图片)')
    
    # 区分摄像头或视频文件
    parser.add_argument('--source', type=str, choices=['camera', 'file'],
                        help='视频输入类型:camera(摄像头)或 file(视频文件)')
    
    # 视频模式参数
    parser.add_argument('--video1', type=str, default=visible_video_path,
                        help='可见光视频路径(仅在source=file时需要)')
    parser.add_argument('--video2', type=str, default=infrared_video_path,
                        help='红外视频路径(仅在source=file时需要)')
    
    # 摄像头模式参数
    parser.add_argument('--camera_id1', type=int, default=0,
                        help='可见光摄像头ID(仅在source=camera时需要,默认0)')
    parser.add_argument('--camera_id2', type=int, default=1,
                        help='红外摄像头ID(仅在source=camera时需要,默认1)')
    parser.add_argument('--output', type=str, default='output.mp4',
                        help='输出视频路径(仅在video模式需要)')
    
    # 图片模式参数
    parser.add_argument('--visible', type=str, default=visible_image_path,
                        help='可见光图片路径(仅在image模式需要)')
    parser.add_argument('--infrared', type=str, default=infrared_image_path,
                        help='红外图片路径(仅在image模式需要)')
    
    return parser.parse_args()


if __name__ == '__main__':
    time_all = 0
    dots = 0
    i = 0
    args = parse_args()
    
    if args.mode == 'video':
        if args.source == 'file':
            # ========== 视频流处理模式 ==========
            if not args.video1 or not args.video2:
                raise ValueError("视频模式需要指定 --video1 和 --video2 参数")
            capture = cv2.VideoCapture(args.video2)
            capture2 = cv2.VideoCapture(args.video1)
        elif args.source == 'camera':
            # ========== 摄像头处理模式 ==========
            capture = cv2.VideoCapture(args.camera_id1)
            capture2 = cv2.VideoCapture(args.camera_id2)
        else:
            raise ValueError("必须指定 --source 参数(camera 或 file)")
        
        # 公共视频处理逻辑
        fps = capture.get(cv2.CAP_PROP_FPS) if args.source == 'file' else 30
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        out = cv2.VideoWriter(args.output, fourcc, fps, (640, 480))
        
        while True:
            ret1, frame_vi = capture.read()  # 可见光帧
            ret2, frame_ir = capture2.read()  # 红外帧
            if not ret1 or not ret2:
                break
            
            # 红外图像转灰度
            frame_ir_gray = cv2.cvtColor(frame_ir, cv2.COLOR_BGR2GRAY)
            
            # 执行融合与检测
            flag, fusion, _ = main(frame_vi, frame_ir_gray)
            
            if flag == 1:
                cv2.imshow("Fusion with YOLOv8 Detection", fusion)
                out.write(fusion)
            
            if cv2.waitKey(1) == ord('q'):
                break
        
        # 释放资源
        capture.release()
        capture2.release()
        out.release()
        cv2.destroyAllWindows()
    
    elif args.mode == 'image':
        # ========= 图片处理模式 ==========
        if not args.infrared or not args.visible:
            raise ValueError("图片模式需要指定 --visible 和 --infrared 参数")
        
        # 读取图像
        img_visible = cv2.imread(args.visible)
        img_infrared = cv2.imread(args.infrared)
        
        if img_visible is None or img_infrared is None:
            print("Error: 图片加载失败,请检查路径!")
            exit()
        
        # 转换为灰度图(红外图像处理)
        img_inf_gray = cv2.cvtColor(img_infrared, cv2.COLOR_BGR2GRAY)
        
        # 执行融合与检测
        flag, fusion_result, dot, en, sf, mi, ssim_val = main(img_visible, img_inf_gray)
        
        if flag == 1:
            # 展示评价指标
            print("\n======== 融合质量评价 ========")
            print(f"信息熵(EN): {en:.2f}")
            print(f"空间频率(SF): {sf:.2f}")
            print(f"互信息(MI): {mi:.2f}")
            
            # 条件显示SSIM
            if ssim_val >= 0:
                print(f"结构相似性(SSIM): {ssim_val:.4f}")
            else:
                print("结构相似性(SSIM): 计算失败(已跳过)")
            
            print(f"配准点数: {dot}")
            # 显示并保存结果
            # cv2.imshow("Fusion with Detection", fusion_result)
            cv2.imwrite("output/fusion_result.jpg", fusion_result)
            # cv2.waitKey(0)
            # cv2.destroyAllWindows()
        else:
            print("融合失败!")