From 338a5e07e86e628b4b94bf900f98ca2821f008e5 Mon Sep 17 00:00:00 2001 From: myh Date: Tue, 22 Apr 2025 00:19:43 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E5=8F=82=E6=95=B0=EF=BC=8C?= =?UTF-8?q?=E4=BD=BF=E5=85=B6=E7=AC=A6=E5=90=88=E8=AE=AD=E7=BB=83=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E9=9B=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- federated_learning/yolov8_fed.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/federated_learning/yolov8_fed.py b/federated_learning/yolov8_fed.py index 8b82916..53edff0 100644 --- a/federated_learning/yolov8_fed.py +++ b/federated_learning/yolov8_fed.py @@ -71,7 +71,7 @@ def federated_train(num_rounds, clients_data): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") global_model = YOLO("yolov8n.pt").to(device) # 设置类别数 - # global_model.model.nc = 2 + global_model.model.nc = 1 for _ in range(num_rounds): client_weights = [] @@ -86,7 +86,8 @@ def federated_train(num_rounds, clients_data): img_dir = os.path.join(yaml_dir, config.get('train', data_path)) # 从配置文件中获取图像目录 print(f"Image directory: {img_dir}") - num_samples = len(glob.glob(os.path.join(img_dir, '*.jpg'))) + num_samples = (len(glob.glob(os.path.join(img_dir, '*.jpg'))) + + len(glob.glob(os.path.join(img_dir, '*.png')))) print(f"Number of images: {num_samples}") # 克隆全局模型 @@ -127,5 +128,5 @@ if __name__ == "__main__": # assert Path("yolov8n_federated.onnx").exists(), "模型导出失败" # 检查2:验证预测功能 - # results = final_model.predict("test_data/client1/train/images/img1.jpg") + # results = final_model.predict("../dataset/val/images/VS_P65.jpg", save=True) # assert len(results[0].boxes) > 0, "预测结果异常"