# -*- coding: utf-8 -*-
# @Time :
# @Author :
import cv2
import numpy as np

# Shared SIFT detector/descriptor extractor, created once at import time and
# reused by compuerSift2GetPts for every image pair.
sift = cv2.SIFT_create()


def compuerSift2GetPts(img1, img2):
    """Match SIFT features between two images and return the matched points.

    Keypoints and descriptors are extracted from both images with the
    module-level ``sift`` detector, matched with a brute-force 2-nearest-
    neighbour search and filtered with a ratio test (0.75). When at least one
    match survives, the matches are drawn and shown in the "matcher" window
    for 100 ms.

    No homography is estimated here: nothing consumes it, and
    ``cv2.findHomography`` needs at least four point pairs, so calling it
    unconditionally would fail on image pairs with only 1-3 good matches.

    Args:
        img1: Query image passed to ``sift.detectAndCompute``.
        img2: Train image passed to ``sift.detectAndCompute``.

    Returns:
        A tuple ``(ptsA, ptsB, flag)``. ``ptsA`` and ``ptsB`` are float32
        arrays of shape ``(N, 1, 2)`` holding the coordinates of the good
        matches in ``img1`` and ``img2`` respectively. ``flag`` is 1 when
        ``N > 0`` and 0 otherwise.
    """
    # Find keypoints and their descriptors in both images.
    kp1, des1 = sift.detectAndCompute(img1, None)
    kp2, des2 = sift.detectAndCompute(img2, None)

    # Descriptors can be None when an image yields no keypoints; the matcher
    # cannot handle that, so report "no matches" right away.
    if des1 is None or des2 is None:
        return (np.empty((0, 1, 2), dtype=np.float32),
                np.empty((0, 1, 2), dtype=np.float32),
                0)

    matcher = cv2.BFMatcher()
    raw_matches = matcher.knnMatch(des1, des2, k=2)

    ratio = 0.75
    good_matches = []
    for candidates in raw_matches:
        # knnMatch may return fewer than k neighbours for a query descriptor
        # (e.g. the other image has a single keypoint); the ratio test needs
        # both the nearest and the second-nearest neighbour.
        if len(candidates) < 2:
            continue
        nearest, second = candidates[0], candidates[1]
        # Keep the nearest neighbour only when it is clearly closer than the
        # second-nearest one; such a match is considered a good match.
        if nearest.distance < ratio * second.distance:
            good_matches.append([nearest])

    ptsA = np.float32([kp1[m[0].queryIdx].pt for m in good_matches]).reshape(-1, 1, 2)
    ptsB = np.float32([kp2[m[0].trainIdx].pt for m in good_matches]).reshape(-1, 1, 2)
    if len(ptsA) == 0:
        return ptsA, ptsB, 0

    # Visualise the surviving matches (flags=2 as in the original call).
    matches = cv2.drawMatchesKnn(img1, kp1, img2, kp2, good_matches, None, flags=2)
    cv2.imshow("matcher", matches)
    cv2.waitKey(100)

    return ptsA, ptsB, 1


def findBestDistanceAndPts(ptsA, ptsB):
    """Find the dominant integer x/y displacement between matched points.

    For every pair ``(ptsA[i], ptsB[i])`` the displacement
    ``ptsA[i] - ptsB[i]`` is truncated to an integer (``int()``), separately
    for x and y. The most frequent x displacement and the most frequent y
    displacement are selected independently; on a tie the value whose running
    count first reaches the maximum wins, and when every value is unique the
    first pair's displacement is used. The chosen pair is printed.

    Args:
        ptsA: Indexable sequence of points, each accessed as
            ``ptsA[i][0][0]`` (x) and ``ptsA[i][0][1]`` (y).
        ptsB: Sequence of the same layout, paired with ``ptsA`` by index.

    Returns:
        A tuple ``(pt, best_x, best_y)`` where ``pt`` is the list of
        ``[x, y]`` coordinates from ``ptsA`` whose x displacement equals
        ``best_x`` (the y displacement is not used for filtering).
        For empty input ``([], 0, 0)`` is returned.
    """
    # Nothing to vote on: return an empty selection instead of failing on
    # the first-element lookup.
    if len(ptsA) == 0:
        return [], 0, 0

    # Compute every truncated displacement once and reuse it for both the
    # voting and the filtering step.
    x_offsets = [int(a[0][0] - b[0][0]) for a, b in zip(ptsA, ptsB)]
    y_offsets = [int(a[0][1] - b[0][1]) for a, b in zip(ptsA, ptsB)]

    best_x = _dominant_offset(x_offsets)
    best_y = _dominant_offset(y_offsets)
    print(best_x, best_y)

    # Keep the ptsA coordinates that agree with the dominant x displacement.
    pt = [[a[0][0], a[0][1]]
          for a, x_off in zip(ptsA, x_offsets)
          if x_off == best_x]
    return pt, best_x, best_y


def _dominant_offset(offsets):
    """Return the value in non-empty *offsets* whose count first reaches the maximum."""
    counts = {}
    # Starting at a count of 1 means a value only takes over once it has been
    # seen at least twice, so all-unique input keeps the first value.
    best, best_count = offsets[0], 1
    for off in offsets:
        counts[off] = counts.get(off, 0) + 1
        if counts[off] > best_count:
            best, best_count = off, counts[off]
    return best


def minDistanceHasXy(ptsA, ptsB):
    """Print a histogram of rounded (dx, dy) displacements between point pairs.

    Each displacement ``ptsA[i] - ptsB[i]`` is rounded to the nearest integer
    per axis and the pair is used as a joint key ``"dx,dy"``. While scanning,
    a key is printed every time its count (of at least two) ties or beats the
    highest count seen so far. Afterwards every key with its count is
    printed, followed by the last key that tied or beat the running maximum.
    If no displacement occurs more than once, the placeholder ``'s'`` is
    printed as the final line.

    Args:
        ptsA: Indexable sequence of points, each accessed as
            ``ptsA[i][0][0]`` (x) and ``ptsA[i][0][1]`` (y).
        ptsB: Sequence of the same layout, paired with ``ptsA`` by index.

    Returns:
        None. All results are written to standard output.
    """
    import math  # local import so this block does not rely on file-level imports

    dct = {}
    cnt = 0
    best = 's'
    for a, b in zip(ptsA, ptsB):
        # floor(d + 0.5) rounds half up for both signs; int(d + 0.5) would
        # truncate toward zero and misround negative displacements.
        disx = math.floor(a[0][0] - b[0][0] + 0.5)
        disy = math.floor(a[0][1] - b[0][1] + 0.5)
        s = str(disx) + ',' + str(disy)
        if s in dct:
            dct[s] += 1
            if dct[s] >= cnt:
                cnt = dct[s]
                best = s
                print(s)
        else:
            dct[s] = 1
    for key, count in dct.items():
        print(key, count)
    print(best)


def detectImg(img1, img2, pta, best_x, best_y):
    """Crop the bounding box of the selected points out of both images.

    The bounding box of ``pta`` (truncated to integers) is cut out of
    ``img1``. The same box shifted by ``(-best_x, -best_y)`` is cut out of
    ``img2``. Slice bounds are clamped to 0 so that a box shifted past the
    top/left edge yields a smaller or empty crop instead of wrapping around
    through negative indices and returning an unrelated region. Callers can
    detect a clipped box by comparing the two returned shapes.

    Args:
        img1: Array-like image supporting ``img[y0:y1, x0:x1]`` slicing.
        img2: Second image with the same slicing support.
        pta: Sequence of ``[x, y]`` points located in ``img1``.
        best_x: Horizontal displacement between ``img1`` and ``img2``.
        best_y: Vertical displacement between ``img1`` and ``img2``.

    Returns:
        A tuple ``(newimg1, newimg2)`` with the two crops. Both crops are
        empty when ``pta`` is empty.
    """
    # No points means no bounding box; hand back empty crops rather than
    # letting min()/max() fail on an empty sequence.
    if len(pta) == 0:
        return img1[0:0, 0:0], img2[0:0, 0:0]

    xs = [p[0] for p in pta]
    ys = [p[1] for p in pta]
    min_x, max_x = int(min(xs)), int(max(xs))
    min_y, max_y = int(min(ys)), int(max(ys))

    def _clamp(value):
        # Negative slice bounds would index from the far edge of the image.
        return max(value, 0)

    newimg1 = img1[_clamp(min_y): _clamp(max_y), _clamp(min_x): _clamp(max_x)]
    newimg2 = img2[_clamp(min_y - best_y): _clamp(max_y - best_y),
                   _clamp(min_x - best_x): _clamp(max_x - best_x)]
    return newimg1, newimg2


if __name__ == '__main__':
    # Walk over the numbered gray/color frame pairs, align each pair via SIFT
    # matching and show the aligned crops that are large enough.
    kept = 0  # number of crop pairs that passed every check
    for frame in range(20, 4771):
        print(frame)
        gray_path = f'./data/907dat/gray/camera1-{frame}.png'
        color_path = f'./data/907dat/color/camera0-{frame}.png'
        img1 = cv2.imread(gray_path)
        img2 = cv2.imread(color_path)
        # Skip frames for which either image could not be loaded.
        if img1 is None or img2 is None:
            continue

        PtsA, PtsB, found = compuerSift2GetPts(img1, img2)
        if found == 0:
            continue

        pt, best_x, best_y = findBestDistanceAndPts(PtsA, PtsB)
        newimg1, newimg2 = detectImg(img1, img2, pt, best_x, best_y)

        height, width = newimg1.shape[0], newimg1.shape[1]
        if height < 10 or width < 10:
            continue
        print(newimg1.shape, newimg2.shape)
        # newimg1 = cv2.resize(newimg1, (320, 240))
        # newimg2 = cv2.resize(newimg2, (320, 240))

        # Output locations for the (currently disabled) imwrite calls below.
        write_path1 = f'./result/dat_result_2/gray/camera1-{kept}.png'
        write_path2 = f'./result/dat_result_2/color/camera0-{kept}.png'

        # Only crops larger than 255 px on both axes and of identical shape
        # are counted and displayed.
        if height <= 255 or width <= 255 or newimg1.shape != newimg2.shape:
            continue
        # cv2.imwrite(write_path1, newimg1)
        # cv2.imwrite(write_path2, newimg2)
        kept += 1
        cv2.imshow("newimg1", newimg1)
        cv2.imshow("newimg2", newimg2)
        cv2.waitKey()
    print(kept)