修改参数,使其符合训练数据集

This commit is contained in:
myh 2025-04-22 00:19:43 +08:00
parent 9d99b00e55
commit 338a5e07e8

View File

@ -71,7 +71,7 @@ def federated_train(num_rounds, clients_data):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
global_model = YOLO("yolov8n.pt").to(device) global_model = YOLO("yolov8n.pt").to(device)
# 设置类别数 # 设置类别数
# global_model.model.nc = 2 global_model.model.nc = 1
for _ in range(num_rounds): for _ in range(num_rounds):
client_weights = [] client_weights = []
@ -86,7 +86,8 @@ def federated_train(num_rounds, clients_data):
img_dir = os.path.join(yaml_dir, config.get('train', data_path)) # 从配置文件中获取图像目录 img_dir = os.path.join(yaml_dir, config.get('train', data_path)) # 从配置文件中获取图像目录
print(f"Image directory: {img_dir}") print(f"Image directory: {img_dir}")
num_samples = len(glob.glob(os.path.join(img_dir, '*.jpg'))) num_samples = (len(glob.glob(os.path.join(img_dir, '*.jpg'))) +
len(glob.glob(os.path.join(img_dir, '*.png'))))
print(f"Number of images: {num_samples}") print(f"Number of images: {num_samples}")
# 克隆全局模型 # 克隆全局模型
@ -127,5 +128,5 @@ if __name__ == "__main__":
# assert Path("yolov8n_federated.onnx").exists(), "模型导出失败" # assert Path("yolov8n_federated.onnx").exists(), "模型导出失败"
# 检查2验证预测功能 # 检查2验证预测功能
# results = final_model.predict("test_data/client1/train/images/img1.jpg") # results = final_model.predict("../dataset/val/images/VS_P65.jpg", save=True)
# assert len(results[0].boxes) > 0, "预测结果异常" # assert len(results[0].boxes) > 0, "预测结果异常"