更改最小测试示例
This commit is contained in:
@@ -71,7 +71,7 @@ def federated_train(num_rounds, clients_data):
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
global_model = YOLO("yolov8n.pt").to(device)
|
||||
# 设置类别数
|
||||
global_model.model.nc = 2
|
||||
# global_model.model.nc = 2
|
||||
|
||||
for _ in range(num_rounds):
|
||||
client_weights = []
|
||||
@@ -96,7 +96,7 @@ def federated_train(num_rounds, clients_data):
|
||||
local_model.train(
|
||||
data=data_path,
|
||||
epochs=1, # 每轮本地训练1个epoch
|
||||
imgsz=128, # 图像大小
|
||||
imgsz=640, # 图像大小
|
||||
verbose=False # 关闭冗余输出
|
||||
)
|
||||
|
||||
|
Reference in New Issue
Block a user