19 lines
600 B
Python
19 lines
600 B
Python
![]() |
import argparse
|
||
|
import os
|
||
|
|
||
|
|
||
|
def args_parser(argv=None):
    """Parse command-line options and attach distributed-launch settings.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, matching the previous behavior; pass an
            explicit list to call this from tests or other code.

    Returns:
        argparse.Namespace holding ``epochs``, ``input_size`` and ``config``
        from the command line, plus:
          * ``local_rank`` -- int from ``LOCAL_RANK`` (default 0)
          * ``world_size`` -- int from ``WORLD_SIZE`` (default 1)
          * ``distributed`` -- True when ``world_size`` > 1

    Raises:
        SystemExit: from argparse on invalid command-line arguments.
        ValueError: if ``LOCAL_RANK`` / ``WORLD_SIZE`` are set but not integers.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--epochs", type=int, default=10, help="number of rounds of local training")
    parser.add_argument("--input_size", type=int, default=640, help="image input size")
    parser.add_argument("--config", type=str, default="./config/uav_cfg.yaml", help="Path to YAML config")

    args = parser.parse_args(argv)

    # Rank / world size come from the environment, presumably exported by a
    # distributed launcher (e.g. torchrun) -- confirm against the launch scripts.
    args.local_rank = int(os.getenv("LOCAL_RANK", 0))
    args.world_size = int(os.getenv("WORLD_SIZE", 1))
    # Derive from the value already parsed instead of re-reading WORLD_SIZE.
    args.distributed = args.world_size > 1

    return args
|