修改参数,符合Linux路径要求
This commit is contained in:
parent
9f827af58e
commit
69482e6a3f
@ -1,4 +1,4 @@
|
|||||||
train: images
|
train: ./images
|
||||||
val: ../val
|
val: ../val
|
||||||
nc: 1
|
nc: 1
|
||||||
names: ['uav']
|
names: ['uav']
|
@ -1,4 +1,4 @@
|
|||||||
train: images
|
train: ./images
|
||||||
val: ../val
|
val: ../val
|
||||||
nc: 1
|
nc: 1
|
||||||
names: ['uav']
|
names: ['uav']
|
||||||
|
@ -59,8 +59,8 @@ def federated_avg(global_model, client_weights):
|
|||||||
sample_key = next(k for k in global_dict if 'running_' not in k)
|
sample_key = next(k for k in global_dict if 'running_' not in k)
|
||||||
aggregated_mean = global_dict[sample_key].mean().item()
|
aggregated_mean = global_dict[sample_key].mean().item()
|
||||||
client_means = [sd[sample_key].float().mean().item() for sd in state_dicts]
|
client_means = [sd[sample_key].float().mean().item() for sd in state_dicts]
|
||||||
print(f"层 '{sample_key}' 聚合后均值: {aggregated_mean:.6f}")
|
print(f"layer: '{sample_key}' Mean after aggregation: {aggregated_mean:.6f}")
|
||||||
print(f"各客户端该层均值: {client_means}")
|
print(f"The average value of the layer for each client: {client_means}")
|
||||||
|
|
||||||
return global_model
|
return global_model
|
||||||
|
|
||||||
@ -85,10 +85,10 @@ def federated_train(num_rounds, clients_data):
|
|||||||
yaml_dir = os.path.dirname(data_path)
|
yaml_dir = os.path.dirname(data_path)
|
||||||
img_dir = os.path.join(yaml_dir, config.get('train', data_path)) # 从配置文件中获取图像目录
|
img_dir = os.path.join(yaml_dir, config.get('train', data_path)) # 从配置文件中获取图像目录
|
||||||
|
|
||||||
print(f"Image directory: {img_dir}")
|
# print(f"Image directory: {img_dir}")
|
||||||
num_samples = (len(glob.glob(os.path.join(img_dir, '*.jpg'))) +
|
num_samples = (len(glob.glob(os.path.join(img_dir, '*.jpg'))) +
|
||||||
len(glob.glob(os.path.join(img_dir, '*.png'))))
|
len(glob.glob(os.path.join(img_dir, '*.png'))))
|
||||||
print(f"Number of images: {num_samples}")
|
# print(f"Number of images: {num_samples}")
|
||||||
|
|
||||||
# 克隆全局模型
|
# 克隆全局模型
|
||||||
local_model = copy.deepcopy(global_model)
|
local_model = copy.deepcopy(global_model)
|
||||||
@ -96,9 +96,11 @@ def federated_train(num_rounds, clients_data):
|
|||||||
# 本地训练(保持你的原有参数设置)
|
# 本地训练(保持你的原有参数设置)
|
||||||
local_model.train(
|
local_model.train(
|
||||||
data=data_path,
|
data=data_path,
|
||||||
epochs=1, # 每轮本地训练1个epoch
|
epochs=16, # 每轮本地训练1个epoch
|
||||||
|
save_period=16,
|
||||||
imgsz=640, # 图像大小
|
imgsz=640, # 图像大小
|
||||||
verbose=False # 关闭冗余输出
|
verbose=False, # 关闭冗余输出
|
||||||
|
batch=-1
|
||||||
)
|
)
|
||||||
|
|
||||||
# 收集模型参数及样本数
|
# 收集模型参数及样本数
|
||||||
@ -106,7 +108,7 @@ def federated_train(num_rounds, clients_data):
|
|||||||
|
|
||||||
# 聚合参数更新全局模型
|
# 聚合参数更新全局模型
|
||||||
global_model = federated_avg(global_model, client_weights)
|
global_model = federated_avg(global_model, client_weights)
|
||||||
|
print(f"Round {_ + 1}/{num_rounds} completed.")
|
||||||
return global_model
|
return global_model
|
||||||
|
|
||||||
|
|
||||||
@ -114,14 +116,15 @@ def federated_train(num_rounds, clients_data):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 联邦训练配置
|
# 联邦训练配置
|
||||||
clients_config = [
|
clients_config = [
|
||||||
"../dataset/train1/train1.yaml", # 客户端1数据路径
|
"/root/autodl-tmp/dataset/train1/train1.yaml", # 客户端1数据路径
|
||||||
"../dataset/train2/train2.yaml" # 客户端2数据路径
|
"/root/autodl-tmp/dataset/train2/train2.yaml" # 客户端2数据路径
|
||||||
]
|
]
|
||||||
|
|
||||||
# 运行联邦训练
|
# 运行联邦训练
|
||||||
final_model = federated_train(num_rounds=1, clients_data=clients_config)
|
final_model = federated_train(num_rounds=10, clients_data=clients_config)
|
||||||
|
|
||||||
# 保存最终模型
|
# 保存最终模型
|
||||||
|
final_model.save("yolov8n_federated.pt")
|
||||||
# final_model.export(format="onnx") # 导出为ONNX格式
|
# final_model.export(format="onnx") # 导出为ONNX格式
|
||||||
|
|
||||||
# 检查1:确认模型保存
|
# 检查1:确认模型保存
|
||||||
|
Loading…
Reference in New Issue
Block a user