修改参数,使其符合训练数据集
This commit is contained in:
parent
9d99b00e55
commit
338a5e07e8
@ -71,7 +71,7 @@ def federated_train(num_rounds, clients_data):
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
global_model = YOLO("yolov8n.pt").to(device)
|
||||
# 设置类别数
|
||||
# global_model.model.nc = 2
|
||||
global_model.model.nc = 1
|
||||
|
||||
for _ in range(num_rounds):
|
||||
client_weights = []
|
||||
@ -86,7 +86,8 @@ def federated_train(num_rounds, clients_data):
|
||||
img_dir = os.path.join(yaml_dir, config.get('train', data_path)) # 从配置文件中获取图像目录
|
||||
|
||||
print(f"Image directory: {img_dir}")
|
||||
num_samples = len(glob.glob(os.path.join(img_dir, '*.jpg')))
|
||||
num_samples = (len(glob.glob(os.path.join(img_dir, '*.jpg'))) +
|
||||
len(glob.glob(os.path.join(img_dir, '*.png'))))
|
||||
print(f"Number of images: {num_samples}")
|
||||
|
||||
# 克隆全局模型
|
||||
@ -127,5 +128,5 @@ if __name__ == "__main__":
|
||||
# assert Path("yolov8n_federated.onnx").exists(), "模型导出失败"
|
||||
|
||||
# 检查2:验证预测功能
|
||||
# results = final_model.predict("test_data/client1/train/images/img1.jpg")
|
||||
# results = final_model.predict("../dataset/val/images/VS_P65.jpg", save=True)
|
||||
# assert len(results[0].boxes) > 0, "预测结果异常"
|
||||
|
Loading…
Reference in New Issue
Block a user