修改参数,符合Linux路径要求
This commit is contained in:
		| @@ -1,4 +1,4 @@ | |||||||
| train: images | train: ./images | ||||||
| val:   ../val | val:   ../val | ||||||
| nc:    1 | nc:    1 | ||||||
| names: ['uav'] | names: ['uav'] | ||||||
| @@ -1,4 +1,4 @@ | |||||||
| train: images | train: ./images | ||||||
| val:   ../val | val:   ../val | ||||||
| nc:    1 | nc:    1 | ||||||
| names: ['uav'] | names: ['uav'] | ||||||
|   | |||||||
| @@ -59,8 +59,8 @@ def federated_avg(global_model, client_weights): | |||||||
|     sample_key = next(k for k in global_dict if 'running_' not in k) |     sample_key = next(k for k in global_dict if 'running_' not in k) | ||||||
|     aggregated_mean = global_dict[sample_key].mean().item() |     aggregated_mean = global_dict[sample_key].mean().item() | ||||||
|     client_means = [sd[sample_key].float().mean().item() for sd in state_dicts] |     client_means = [sd[sample_key].float().mean().item() for sd in state_dicts] | ||||||
|     print(f"层 '{sample_key}' 聚合后均值: {aggregated_mean:.6f}") |     print(f"layer: '{sample_key}' Mean after aggregation: {aggregated_mean:.6f}") | ||||||
|     print(f"各客户端该层均值: {client_means}") |     print(f"The average value of the layer for each client: {client_means}") | ||||||
|      |      | ||||||
|     return global_model |     return global_model | ||||||
|  |  | ||||||
| @@ -85,10 +85,10 @@ def federated_train(num_rounds, clients_data): | |||||||
|             yaml_dir = os.path.dirname(data_path) |             yaml_dir = os.path.dirname(data_path) | ||||||
|             img_dir = os.path.join(yaml_dir, config.get('train', data_path))  # 从配置文件中获取图像目录 |             img_dir = os.path.join(yaml_dir, config.get('train', data_path))  # 从配置文件中获取图像目录 | ||||||
|              |              | ||||||
|             print(f"Image directory: {img_dir}") |             # print(f"Image directory: {img_dir}") | ||||||
|             num_samples = (len(glob.glob(os.path.join(img_dir, '*.jpg'))) + |             num_samples = (len(glob.glob(os.path.join(img_dir, '*.jpg'))) + | ||||||
|                            len(glob.glob(os.path.join(img_dir, '*.png')))) |                            len(glob.glob(os.path.join(img_dir, '*.png')))) | ||||||
|             print(f"Number of images: {num_samples}") |             # print(f"Number of images: {num_samples}") | ||||||
|              |              | ||||||
|             # 克隆全局模型 |             # 克隆全局模型 | ||||||
|             local_model = copy.deepcopy(global_model) |             local_model = copy.deepcopy(global_model) | ||||||
| @@ -96,9 +96,11 @@ def federated_train(num_rounds, clients_data): | |||||||
|             # 本地训练(保持你的原有参数设置) |             # 本地训练(保持你的原有参数设置) | ||||||
|             local_model.train( |             local_model.train( | ||||||
|                 data=data_path, |                 data=data_path, | ||||||
|                 epochs=1,  # 每轮本地训练1个epoch |                 epochs=16,  # 每轮本地训练1个epoch | ||||||
|  |                 save_period=16, | ||||||
|                 imgsz=640,  # 图像大小 |                 imgsz=640,  # 图像大小 | ||||||
|                 verbose=False  # 关闭冗余输出 |                 verbose=False,  # 关闭冗余输出 | ||||||
|  |                 batch=-1 | ||||||
|             ) |             ) | ||||||
|              |              | ||||||
|             # 收集模型参数及样本数 |             # 收集模型参数及样本数 | ||||||
| @@ -106,7 +108,7 @@ def federated_train(num_rounds, clients_data): | |||||||
|          |          | ||||||
|         # 聚合参数更新全局模型 |         # 聚合参数更新全局模型 | ||||||
|         global_model = federated_avg(global_model, client_weights) |         global_model = federated_avg(global_model, client_weights) | ||||||
|      |         print(f"Round {_ + 1}/{num_rounds} completed.") | ||||||
|     return global_model |     return global_model | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -114,14 +116,15 @@ def federated_train(num_rounds, clients_data): | |||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     # 联邦训练配置 |     # 联邦训练配置 | ||||||
|     clients_config = [ |     clients_config = [ | ||||||
|         "../dataset/train1/train1.yaml",  # 客户端1数据路径 |         "/root/autodl-tmp/dataset/train1/train1.yaml",  # 客户端1数据路径 | ||||||
|         "../dataset/train2/train2.yaml"  # 客户端2数据路径 |         "/root/autodl-tmp/dataset/train2/train2.yaml"  # 客户端2数据路径 | ||||||
|     ] |     ] | ||||||
|      |      | ||||||
|     # 运行联邦训练 |     # 运行联邦训练 | ||||||
|     final_model = federated_train(num_rounds=1, clients_data=clients_config) |     final_model = federated_train(num_rounds=10, clients_data=clients_config) | ||||||
|      |      | ||||||
|     # 保存最终模型 |     # 保存最终模型 | ||||||
|  |     final_model.save("yolov8n_federated.pt") | ||||||
|     # final_model.export(format="onnx")  # 导出为ONNX格式 |     # final_model.export(format="onnx")  # 导出为ONNX格式 | ||||||
|      |      | ||||||
|     # 检查1:确认模型保存 |     # 检查1:确认模型保存 | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user