diff --git a/federated_learning/yolov8_fed.py b/federated_learning/yolov8_fed.py index 8b82916..53edff0 100644 --- a/federated_learning/yolov8_fed.py +++ b/federated_learning/yolov8_fed.py @@ -71,7 +71,7 @@ def federated_train(num_rounds, clients_data): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") global_model = YOLO("yolov8n.pt").to(device) # 设置类别数 - # global_model.model.nc = 2 + global_model.model.nc = 1 for _ in range(num_rounds): client_weights = [] @@ -86,7 +86,8 @@ def federated_train(num_rounds, clients_data): img_dir = os.path.join(yaml_dir, config.get('train', data_path)) # 从配置文件中获取图像目录 print(f"Image directory: {img_dir}") - num_samples = len(glob.glob(os.path.join(img_dir, '*.jpg'))) + num_samples = (len(glob.glob(os.path.join(img_dir, '*.jpg'))) + + len(glob.glob(os.path.join(img_dir, '*.png')))) print(f"Number of images: {num_samples}") # 克隆全局模型 @@ -127,5 +128,5 @@ if __name__ == "__main__": # assert Path("yolov8n_federated.onnx").exists(), "模型导出失败" # 检查2:验证预测功能 - # results = final_model.predict("test_data/client1/train/images/img1.jpg") + # results = final_model.predict("../dataset/val/images/VS_P65.jpg", save=True) # assert len(results[0].boxes) > 0, "预测结果异常"