From 3778850d029307e97d17a7ffbc5acfc803132b2b Mon Sep 17 00:00:00 2001 From: "hongliang.yuan" Date: Mon, 1 Sep 2025 09:32:58 +0800 Subject: [PATCH] sync 2025-09-01 resnet code --- .../pytorch/dataloader/dali_classification.py | 24 ++++++++++++------- .../pytorch/dataloader/dali_classification.py | 2 +- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/cv/classification/resnet101/pytorch/dataloader/dali_classification.py b/cv/classification/resnet101/pytorch/dataloader/dali_classification.py index faf9c8cbe..3147aa400 100644 --- a/cv/classification/resnet101/pytorch/dataloader/dali_classification.py +++ b/cv/classification/resnet101/pytorch/dataloader/dali_classification.py @@ -15,10 +15,12 @@ import nvidia.dali.types as types from nvidia.dali.pipeline import Pipeline from nvidia.dali.plugin.pytorch import DALIClassificationIterator, DALIGenericIterator from nvidia.dali.plugin.base_iterator import LastBatchPolicy +import torch.distributed as dist + class HybridTrainPipe(Pipeline): - def __init__(self, batch_size, num_threads, device_id, data_dir, size): + def __init__(self, batch_size, num_threads, device_id, data_dir, size, shard_id, num_shards): super(HybridTrainPipe, self).__init__(batch_size, num_threads, device_id) - self.input = ops.FileReader(file_root=data_dir, random_shuffle=True) + self.input = ops.FileReader(file_root=data_dir, random_shuffle=True, shard_id=shard_id, num_shards=num_shards) self.decode = ops.ImageDecoder(device="cpu", output_type=types.RGB) self.res = ops.RandomResizedCrop(device="gpu", size=size, random_area=[0.08, 1.25]) self.cmnp = ops.CropMirrorNormalize(device="gpu", @@ -38,9 +40,9 @@ class HybridTrainPipe(Pipeline): class HybridValPipe(Pipeline): - def __init__(self, batch_size, num_threads, device_id, data_dir, size): + def __init__(self, batch_size, num_threads, device_id, data_dir, size, shard_id, num_shards): super(HybridValPipe, self).__init__(batch_size, num_threads, device_id) - self.input = ops.FileReader(file_root=data_dir, random_shuffle=False) + self.input = ops.FileReader(file_root=data_dir, random_shuffle=False, shard_id=shard_id, num_shards=num_shards) self.decode = ops.ImageDecoder(device="cpu", output_type=types.RGB) self.res = ops.Resize(device="gpu", resize_x=size, resize_y=size) self.cmnp = ops.CropMirrorNormalize(device="gpu", @@ -61,19 +63,25 @@ class HybridValPipe(Pipeline): def get_imagenet_iter_dali(type, image_dir, batch_size, num_threads, device_id, size): + if dist.is_initialized(): + shard_id = dist.get_rank() + num_shards = dist.get_world_size() + else: + shard_id = 0 + num_shards = 1 if type == 'train': pip_train = HybridTrainPipe(batch_size=batch_size, num_threads=num_threads, device_id=device_id, data_dir = os.path.join(image_dir, "train"), - size=size) + size=size, shard_id=shard_id, num_shards=num_shards) pip_train.build() - dali_iter_train = DALIClassificationIterator(pip_train, size=pip_train.epoch_size("Reader"), last_batch_policy = LastBatchPolicy.DROP) + dali_iter_train = DALIClassificationIterator(pip_train, size=pip_train.epoch_size("Reader")//num_shards, last_batch_policy = LastBatchPolicy.DROP) return dali_iter_train elif type == 'val': pip_val = HybridValPipe(batch_size=batch_size, num_threads=num_threads, device_id=device_id, data_dir = os.path.join(image_dir, "val"), - size=size) + size=size, shard_id=shard_id, num_shards=num_shards) pip_val.build() - dali_iter_val = DALIClassificationIterator(pip_val, size=pip_val.epoch_size("Reader"), last_batch_policy = LastBatchPolicy.DROP) + dali_iter_val = DALIClassificationIterator(pip_val, size=pip_val.epoch_size("Reader")//num_shards, last_batch_policy = LastBatchPolicy.DROP) return dali_iter_val diff --git a/cv/classification/resnet50/pytorch/dataloader/dali_classification.py b/cv/classification/resnet50/pytorch/dataloader/dali_classification.py index b10beb753..3147aa400 100644 --- a/cv/classification/resnet50/pytorch/dataloader/dali_classification.py +++ b/cv/classification/resnet50/pytorch/dataloader/dali_classification.py @@ -126,4 +126,4 @@ def main(arguments): if __name__ == '__main__': import os, time, sys import argparse - sys.exit(main(sys.argv[1:])) + sys.exit(main(sys.argv[1:])) \ No newline at end of file -- Gitee