Compare commits
No commits in common. "960b66a69242918a7d38dd6e2f0f730adbf60466" and "1930e1b96bd175c055396c62d1c377c399c1c4a7" have entirely different histories.
960b66a692 ... 1930e1b96b
@@ -1,16 +0,0 @@
-# Create the test directory structure
-mkdir -p ./test_data/{client1,client2}/{train,val}/images
-mkdir -p ./test_data/{client1,client2}/{train,val}/labels
-
-# Generate dummy data (each client only needs 2 images)
-for client in client1 client2; do
-    for split in train val; do
-        # Create blank images (128x128 RGB)
-        magick -size 128x128 xc:white test_data/${client}/${split}/images/img1.jpg
-        magick -size 128x128 xc:black test_data/${client}/${split}/images/img2.jpg
-
-        # Create sample label files
-        echo "0 0.5 0.5 0.2 0.2" > test_data/${client}/${split}/labels/img1.txt
-        echo "1 0.3 0.3 0.4 0.4" > test_data/${client}/${split}/labels/img2.txt
-    done
-done
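The label files written here are in YOLO txt format: one object per line as `class x_center y_center width height`, with all coordinates normalized to the image size. A minimal sketch of parsing such a line (the helper name is illustrative, not part of the repository):

def parse_yolo_label(line: str):
    """Parse one YOLO-format label line: class x_center y_center width height (normalized)."""
    cls, xc, yc, w, h = line.split()
    return int(cls), float(xc), float(yc), float(w), float(h)

# Example: the dummy label written by the script above
print(parse_yolo_label("0 0.5 0.5 0.2 0.2"))  # (0, 0.5, 0.5, 0.2, 0.2)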
@@ -1,4 +0,0 @@
-train: ../test_data/client1/train/images
-val: ../test_data/client1/val/images
-nc: 2
-names: [ 'class0', 'class1' ]
@@ -1,4 +0,0 @@
-train: ../test_data/client2/train/images
-val: ../test_data/client2/val/images
-nc: 2
-names: [ 'class0', 'class1' ]
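The two deleted YAML files are Ultralytics-style dataset configs: `train` and `val` point at image directories, `nc` is the number of classes, and `names` lists the class labels. A minimal sketch of loading one and resolving its paths relative to the YAML file, the same way the deleted federated training script later in this diff does (the `load_dataset_config` helper is illustrative, not part of the repository):

import os
import yaml

def load_dataset_config(path: str) -> dict:
    """Load a dataset YAML and resolve 'train'/'val' relative to the YAML file's directory."""
    with open(path, "r") as f:
        cfg = yaml.safe_load(f)
    base = os.path.dirname(path)
    for split in ("train", "val"):
        cfg[split] = os.path.normpath(os.path.join(base, cfg[split]))
    assert cfg["nc"] == len(cfg["names"]), "nc must match the number of class names"
    return cfg

# e.g. load_dataset_config("./config/client1_data.yaml")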
@@ -3,9 +3,9 @@ import torch
 import os
 from torch import optim
 from torch.optim import lr_scheduler
-from fed_example.utils.data_utils import get_data
-from fed_example.utils.model_utils import get_model
-from fed_example.utils.train_utils import train_model, validate_model, v3_update_model_weights
+from util.data_utils import get_data
+from util.model_utils import get_model
+from util.train_utils import train_model, validate_model, update_model_weights, v3_update_model_weights


 def main(args):
[binary image file removed, 225 B]
[binary image file removed, 225 B]
@@ -1 +0,0 @@
-0 0.5 0.5 0.2 0.2
@@ -1 +0,0 @@
-1 0.3 0.3 0.4 0.4
[binary image file removed, 225 B]
[binary image file removed, 225 B]
@@ -1 +0,0 @@
-0 0.5 0.5 0.2 0.2
@@ -1 +0,0 @@
-1 0.3 0.3 0.4 0.4
[binary image file removed, 225 B]
[binary image file removed, 225 B]
@@ -1 +0,0 @@
-0 0.5 0.5 0.2 0.2
@@ -1 +0,0 @@
-1 0.3 0.3 0.4 0.4
[binary image file removed, 225 B]
[binary image file removed, 225 B]
@@ -1 +0,0 @@
-0 0.5 0.5 0.2 0.2
@@ -1 +0,0 @@
-1 0.3 0.3 0.4 0.4
@@ -1,12 +1,13 @@
 import os
-from collections import Counter
 
-import torch
 from PIL import Image
-from sklearn.model_selection import train_test_split
-from torch.utils.data import DataLoader
-from torch.utils.data import Dataset, random_split
+import torch
+from torchvision import transforms
+from torch.utils.data import DataLoader, Dataset, random_split
+from collections import Counter
+from torch.utils.data import DataLoader, Subset
 from torchvision import transforms, datasets
+import os
+from sklearn.model_selection import train_test_split
 
 
 class CustomImageDataset(Dataset):
@@ -180,29 +181,6 @@ def get_Fourdata(train_path, val_path, batch_size, num_workers):
     return (*client_train_loaders, *client_val_loaders, global_val_loader)


-def get_federated_data(train_path, val_path, num_clients=3, batch_size=16, num_workers=8):
-    """
-    Split the dataset across multiple clients, each client with its own training and validation data.
-    """
-    # Load the full datasets
-    full_train_dataset = CustomImageDataset(root_dir=train_path, transform=get_transform("train"))
-    full_val_dataset = CustomImageDataset(root_dir=val_path, transform=get_transform("val"))
-
-    # Split the training set across the clients
-    client_train_datasets = random_split(full_train_dataset, [len(full_train_dataset) // num_clients] * num_clients)
-
-    # Create the per-client data loaders
-    client_train_loaders = [
-        DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
-        for ds in client_train_datasets
-    ]
-
-    # Global validation set
-    global_val_loader = DataLoader(full_val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
-
-    return client_train_loaders, global_val_loader
-
-
 def main():
     # Set parameters
     train_image_path = "/media/terminator/实验&代码/yhs/FF++_mask/c23/f2f/train"
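One caveat in the removed `get_federated_data`: `random_split` with `[len(dataset) // num_clients] * num_clients` raises an error whenever the dataset size is not exactly divisible by `num_clients`. A minimal remainder-tolerant split, as a sketch only (the helper name is mine, not code from the repository):

from torch.utils.data import random_split

def split_across_clients(dataset, num_clients):
    # Give each client len(dataset) // num_clients samples and spread the remainder
    base, remainder = divmod(len(dataset), num_clients)
    lengths = [base + (1 if i < remainder else 0) for i in range(num_clients)]
    return random_split(dataset, lengths)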
@@ -55,11 +55,3 @@ def get_model(name, number_class, device, backbone):
     else:
         raise ValueError(f"Model {name} is not supported.")
     return model
-
-def get_federated_model(device):
-    """Initialize the client models and the global model."""
-    client_models = [
-        get_model("resnet18_psa", 1, device, "*") for _ in range(3)
-    ]
-    global_model = get_model("resnet18_psa", 1, device, "*")
-    return client_models, global_model
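The removed `get_federated_model` built three identically configured client models plus one global model. In a FedAvg-style loop, each round typically begins by copying the current global weights into every client before local training; a minimal sketch of that step (the `broadcast_global_weights` helper and the stand-in models are illustrative, not part of the repository):

import torch.nn as nn

def broadcast_global_weights(global_model, client_models):
    """Copy the global model's parameters into every client model before local training."""
    state = global_model.state_dict()
    for client in client_models:
        client.load_state_dict(state)
    return client_models

# Toy usage with stand-in models (the real code would use get_model(...) instances)
clients = broadcast_global_weights(nn.Linear(8, 1), [nn.Linear(8, 1) for _ in range(3)])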
@@ -116,7 +116,6 @@ def test_deepmodel(device, model, loader):
     # avg_loss = running_loss / len(loader)
     # print(f'{model_name} Training Loss: {avg_loss:.4f}')
     # return avg_loss
-
 def train_model(device, model, loader, optimizer, criterion, epoch, model_name):
     model.train()
     running_loss = 0.0
@@ -332,7 +331,6 @@ def f_update_model_weights(
         updated_val_auc_threshold (float): the updated validation AUC threshold.
     """
     # Update the model weights once every `update_frequency` epochs
-
     if (epoch + 1) % update_frequency == 0:
         print(f"\n[Epoch {epoch + 1}] Updating global model weights...")
 
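The `(epoch + 1) % update_frequency == 0` test above means the global weights are only refreshed on every `update_frequency`-th epoch. A tiny illustration of which epochs that selects (the helper name is mine, not from the repository):

def should_update_global(epoch: int, update_frequency: int) -> bool:
    # True on every update_frequency-th epoch (epochs are counted from 0)
    return (epoch + 1) % update_frequency == 0

# With update_frequency=5 this selects epochs 4, 9, 14, 19, ...
print([e for e in range(20) if should_update_global(e, 5)])  # [4, 9, 14, 19]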
@@ -1,131 +0,0 @@
-import glob
-import os
-from pathlib import Path
-
-import yaml
-from ultralytics import YOLO
-import copy
-import torch
-
-
-# ------------ New federated-learning utility functions ------------
-def federated_avg(global_model, client_weights):
-    """Core federated averaging (FedAvg) algorithm."""
-    # Compute the total number of samples
-    total_samples = sum(n for _, n in client_weights)
-    if total_samples == 0:
-        raise ValueError("Total number of samples must be positive.")
-
-    # Get the parameters of the PyTorch model underlying YOLO
-    global_dict = global_model.model.state_dict()
-    # Extract every client's state_dict and its sample count
-    state_dicts, sample_counts = zip(*client_weights)
-
-    for key in global_dict:
-        # Average the parameters of each layer
-        # if global_dict[key].data.dtype == torch.float32:
-        #     global_dict[key].data = torch.stack(
-        #         [w[key].float() for w in client_weights], 0
-        #     ).mean(0)
-
-        # Weighted average
-        if global_dict[key].dtype == torch.float32:  # only aggregate floating-point parameters
-            # Skip BatchNorm running statistics
-            if any(x in key for x in ['running_mean', 'running_var', 'num_batches_tracked']):
-                continue
-            # Weighted sum by sample count
-            weighted_tensors = [sd[key].float() * (n / total_samples)
-                                for sd, n in zip(state_dicts, sample_counts)]
-            global_dict[key] = torch.stack(weighted_tensors, dim=0).sum(dim=0)
-
-    # Work around parameter-mismatch problems when loading
-    try:
-        # Load the aggregated weights back into the YOLO model
-        global_model.model.load_state_dict(global_dict)
-    except RuntimeError as e:
-        print('Ignoring "' + str(e) + '"')
-
-    # Debug output
-    print("\n=== Parameter aggregation check ===")
-
-    # Pick a representative parameter layer
-    # sample_key = list(global_dict.keys())[10]
-    # original = global_dict[sample_key].data.mean().item()
-    # aggregated = torch.stack([w[sample_key] for w in client_weights]).mean().item()
-    # print(f"Layer '{sample_key}' change: {original:.4f} → {aggregated:.4f}")
-    # print(f"Per-client parameter differences: {[w[sample_key].mean().item() for w in client_weights]}")
-
-    # Pick one non-statistics layer for comparison
-    sample_key = next(k for k in global_dict if 'running_' not in k)
-    aggregated_mean = global_dict[sample_key].mean().item()
-    client_means = [sd[sample_key].float().mean().item() for sd in state_dicts]
-    print(f"Layer '{sample_key}' mean after aggregation: {aggregated_mean:.6f}")
-    print(f"Per-client means for this layer: {client_means}")
-
-    return global_model
-
-
-# ------------ Modified training procedure ------------
-def federated_train(num_rounds, clients_data):
-    # Initialize the global model
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    global_model = YOLO("yolov8n.pt").to(device)
-    # Set the number of classes
-    global_model.model.nc = 2
-
-    for _ in range(num_rounds):
-        client_weights = []
-
-        # Local training on each client
-        for data_path in clients_data:
-            # Count the local training samples
-            with open(data_path, 'r') as f:
-                config = yaml.safe_load(f)
-            # Resolve img_dir relative to the YAML file's location
-            yaml_dir = os.path.dirname(data_path)
-            img_dir = os.path.join(yaml_dir, config.get('train', data_path))  # image directory from the config file
-
-            print(f"Image directory: {img_dir}")
-            num_samples = len(glob.glob(os.path.join(img_dir, '*.jpg')))
-            print(f"Number of images: {num_samples}")
-
-            # Clone the global model
-            local_model = copy.deepcopy(global_model)
-
-            # Local training (keeping the original parameter settings)
-            local_model.train(
-                data=data_path,
-                epochs=1,  # 1 local epoch per round
-                imgsz=128,  # image size
-                verbose=False  # suppress verbose output
-            )
-
-            # Collect the model parameters and sample count
-            client_weights.append((copy.deepcopy(local_model.model.state_dict()), num_samples))
-
-        # Aggregate the parameters and update the global model
-        global_model = federated_avg(global_model, client_weights)
-
-    return global_model
-
-
-# ------------ Usage example ------------
-if __name__ == "__main__":
-    # Federated training configuration
-    clients_config = [
-        "./config/client1_data.yaml",  # client 1 data path
-        "./config/client2_data.yaml"   # client 2 data path
-    ]
-
-    # Run federated training
-    final_model = federated_train(num_rounds=1, clients_data=clients_config)
-
-    # Save the final model
-    # final_model.export(format="onnx")  # export to ONNX format
-
-    # Check 1: confirm the model was saved
-    # assert Path("yolov8n_federated.onnx").exists(), "Model export failed"
-
-    # Check 2: verify prediction works
-    # results = final_model.predict("test_data/client1/train/images/img1.jpg")
-    # assert len(results[0].boxes) > 0, "Unexpected prediction result"
@@ -166,11 +166,11 @@ def main(matchimg_vi, matchimg_in):
 
 def parse_args():
     # Input paths for the visible-light and infrared images
-    visible_image_path = "test/visible.jpg"  # visible-light image path
-    infrared_image_path = "test/infrared.jpg"  # infrared image path
+    visible_image_path = "../test/visible.jpg"  # visible-light image path
+    infrared_image_path = "../test/infrared.jpg"  # infrared image path
     # Input paths for the visible-light and infrared videos
-    visible_video_path = "test/visible.mp4"  # visible-light video path
-    infrared_video_path = "test/infrared.mp4"  # infrared video path
+    visible_video_path = "../test/visible.mp4"  # visible-light video path
+    infrared_video_path = "../test/infrared.mp4"  # infrared video path
 
     """Parse command-line arguments."""
     parser = argparse.ArgumentParser(description='Image fusion and object detection')
@@ -277,7 +277,7 @@ if __name__ == '__main__':
     if flag == 1:
         # Display and save the result
         cv2.imshow("Fusion with Detection", fusion_result)
-        cv2.imwrite("output/fusion_result.jpg", fusion_result)
+        cv2.imwrite("../output/fusion_result.jpg", fusion_result)
         cv2.waitKey(0)
         cv2.destroyAllWindows()
     else:
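Both hunks above change paths like "test/..." and "output/..." to "../test/..." and "../output/...". These strings are resolved against the current working directory at run time, so the change presumably reflects the script now being launched from a subdirectory one level below the project root (an assumption on my part, not stated in the diff). A small sketch of making such paths independent of the working directory instead (illustrative, not code from the repository):

from pathlib import Path

# Anchor data paths to this file's location instead of the working directory
PROJECT_ROOT = Path(__file__).resolve().parent.parent
visible_image_path = PROJECT_ROOT / "test" / "visible.jpg"
output_path = PROJECT_ROOT / "output" / "fusion_result.jpg"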
[binary image file, 152 KiB before / 152 KiB after]
[binary image file, 28 KiB before / 28 KiB after]
[binary image file, 67 KiB before / 67 KiB after]