Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
Tsaris, Aristeidis (aris)
pytorch_tutorial
Commits
9dc09f0d
Commit
9dc09f0d
authored
Sep 24, 2021
by
Tsaris, Aristeidis
Browse files
adding some code
parent
b5250fc3
Changes
19
Hide whitespace changes
Inline
Side-by-side
README.md
View file @
9dc09f0d
#
pytorch
_tutorial
#
imagenet
_tutorial
This code is from
[
NVIDIA-DeepLearningExamples
](
https://github.com/NVIDIA/DeepLearningExamples
)
with some modifications.
\ No newline at end of file
imagenet/image_classification/__init__.py
0 → 100644
View file @
9dc09f0d
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#from . import logger
#from . import dataloaders
#from . import training
#from . import utils
#from . import mixup
#from . import smoothing
from
.
import
models
imagenet/image_classification/autoaugment.py
0 → 100644
View file @
9dc09f0d
from
PIL
import
Image
,
ImageEnhance
,
ImageOps
import
numpy
as
np
import
random
class
AutoaugmentImageNetPolicy
(
object
):
"""
Randomly choose one of the best 24 Sub-policies on ImageNet.
Reference: https://arxiv.org/abs/1805.09501
"""
def
__init__
(
self
):
self
.
policies
=
[
SubPolicy
(
0.8
,
"equalize"
,
1
,
0.8
,
"shearY"
,
4
),
SubPolicy
(
0.4
,
"color"
,
9
,
0.6
,
"equalize"
,
3
),
SubPolicy
(
0.4
,
"color"
,
1
,
0.6
,
"rotate"
,
8
),
SubPolicy
(
0.8
,
"solarize"
,
3
,
0.4
,
"equalize"
,
7
),
SubPolicy
(
0.4
,
"solarize"
,
2
,
0.6
,
"solarize"
,
2
),
SubPolicy
(
0.2
,
"color"
,
0
,
0.8
,
"equalize"
,
8
),
SubPolicy
(
0.4
,
"equalize"
,
8
,
0.8
,
"solarizeadd"
,
3
),
SubPolicy
(
0.2
,
"shearX"
,
9
,
0.6
,
"rotate"
,
8
),
SubPolicy
(
0.6
,
"color"
,
1
,
1.0
,
"equalize"
,
2
),
SubPolicy
(
0.4
,
"invert"
,
9
,
0.6
,
"rotate"
,
0
),
SubPolicy
(
1.0
,
"equalize"
,
9
,
0.6
,
"shearY"
,
3
),
SubPolicy
(
0.4
,
"color"
,
7
,
0.6
,
"equalize"
,
0
),
SubPolicy
(
0.4
,
"posterize"
,
6
,
0.4
,
"autocontrast"
,
7
),
SubPolicy
(
0.6
,
"solarize"
,
8
,
0.6
,
"color"
,
9
),
SubPolicy
(
0.2
,
"solarize"
,
4
,
0.8
,
"rotate"
,
9
),
SubPolicy
(
1.0
,
"rotate"
,
7
,
0.8
,
"translateY"
,
9
),
SubPolicy
(
0.0
,
"shearX"
,
0
,
0.8
,
"solarize"
,
4
),
SubPolicy
(
0.8
,
"shearY"
,
0
,
0.6
,
"color"
,
4
),
SubPolicy
(
1.0
,
"color"
,
0
,
0.6
,
"rotate"
,
2
),
SubPolicy
(
0.8
,
"equalize"
,
4
,
0.0
,
"equalize"
,
8
),
SubPolicy
(
1.0
,
"equalize"
,
4
,
0.6
,
"autocontrast"
,
2
),
SubPolicy
(
0.4
,
"shearY"
,
7
,
0.6
,
"solarizeadd"
,
7
),
SubPolicy
(
0.8
,
"posterize"
,
2
,
0.6
,
"solarize"
,
10
),
SubPolicy
(
0.6
,
"solarize"
,
8
,
0.6
,
"equalize"
,
1
),
SubPolicy
(
0.8
,
"color"
,
6
,
0.4
,
"rotate"
,
5
),
]
def
__call__
(
self
,
img
):
policy_idx
=
random
.
randint
(
0
,
len
(
self
.
policies
)
-
1
)
return
self
.
policies
[
policy_idx
](
img
)
def
__repr__
(
self
):
return
"AutoAugment ImageNet Policy"
class
SubPolicy
(
object
):
def
__init__
(
self
,
p1
,
method1
,
magnitude_idx1
,
p2
,
method2
,
magnitude_idx2
):
operation_factory
=
OperationFactory
()
self
.
p1
=
p1
self
.
p2
=
p2
self
.
operation1
=
operation_factory
.
get_operation
(
method1
,
magnitude_idx1
)
self
.
operation2
=
operation_factory
.
get_operation
(
method2
,
magnitude_idx2
)
def
__call__
(
self
,
img
):
if
random
.
random
()
<
self
.
p1
:
img
=
self
.
operation1
(
img
)
if
random
.
random
()
<
self
.
p2
:
img
=
self
.
operation2
(
img
)
return
img
class
OperationFactory
:
def
__init__
(
self
):
fillcolor
=
(
128
,
128
,
128
)
self
.
ranges
=
{
"shearX"
:
np
.
linspace
(
0
,
0.3
,
11
),
"shearY"
:
np
.
linspace
(
0
,
0.3
,
11
),
"translateX"
:
np
.
linspace
(
0
,
250
,
11
),
"translateY"
:
np
.
linspace
(
0
,
250
,
11
),
"rotate"
:
np
.
linspace
(
0
,
30
,
11
),
"color"
:
np
.
linspace
(
0.1
,
1.9
,
11
),
"posterize"
:
np
.
round
(
np
.
linspace
(
0
,
4
,
11
),
0
).
astype
(
np
.
int
),
"solarize"
:
np
.
linspace
(
0
,
256
,
11
),
"solarizeadd"
:
np
.
linspace
(
0
,
110
,
11
),
"contrast"
:
np
.
linspace
(
0.1
,
1.9
,
11
),
"sharpness"
:
np
.
linspace
(
0.1
,
1.9
,
11
),
"brightness"
:
np
.
linspace
(
0.1
,
1.9
,
11
),
"autocontrast"
:
[
0
]
*
10
,
"equalize"
:
[
0
]
*
10
,
"invert"
:
[
0
]
*
10
}
def
rotate_with_fill
(
img
,
magnitude
):
magnitude
*=
random
.
choice
([
-
1
,
1
])
rot
=
img
.
convert
(
"RGBA"
).
rotate
(
magnitude
)
return
Image
.
composite
(
rot
,
Image
.
new
(
"RGBA"
,
rot
.
size
,
(
128
,)
*
4
),
rot
).
convert
(
img
.
mode
)
def
solarize_add
(
image
,
addition
=
0
,
threshold
=
128
):
lut
=
[]
for
i
in
range
(
256
):
if
i
<
threshold
:
res
=
i
+
addition
if
i
+
addition
<=
255
else
255
res
=
res
if
res
>=
0
else
0
lut
.
append
(
res
)
else
:
lut
.
append
(
i
)
from
PIL.ImageOps
import
_lut
return
_lut
(
image
,
lut
)
self
.
operations
=
{
"shearX"
:
lambda
img
,
magnitude
:
img
.
transform
(
img
.
size
,
Image
.
AFFINE
,
(
1
,
magnitude
*
random
.
choice
([
-
1
,
1
]),
0
,
0
,
1
,
0
),
Image
.
BICUBIC
,
fillcolor
=
fillcolor
),
"shearY"
:
lambda
img
,
magnitude
:
img
.
transform
(
img
.
size
,
Image
.
AFFINE
,
(
1
,
0
,
0
,
magnitude
*
random
.
choice
([
-
1
,
1
]),
1
,
0
),
Image
.
BICUBIC
,
fillcolor
=
fillcolor
),
"translateX"
:
lambda
img
,
magnitude
:
img
.
transform
(
img
.
size
,
Image
.
AFFINE
,
(
1
,
0
,
magnitude
*
random
.
choice
([
-
1
,
1
]),
0
,
1
,
0
),
fillcolor
=
fillcolor
),
"translateY"
:
lambda
img
,
magnitude
:
img
.
transform
(
img
.
size
,
Image
.
AFFINE
,
(
1
,
0
,
0
,
0
,
1
,
magnitude
*
random
.
choice
([
-
1
,
1
])),
fillcolor
=
fillcolor
),
"rotate"
:
lambda
img
,
magnitude
:
rotate_with_fill
(
img
,
magnitude
),
"color"
:
lambda
img
,
magnitude
:
ImageEnhance
.
Color
(
img
).
enhance
(
magnitude
),
"posterize"
:
lambda
img
,
magnitude
:
ImageOps
.
posterize
(
img
,
magnitude
),
"solarize"
:
lambda
img
,
magnitude
:
ImageOps
.
solarize
(
img
,
magnitude
),
"solarizeadd"
:
lambda
img
,
magnitude
:
solarize_add
(
img
,
magnitude
),
"contrast"
:
lambda
img
,
magnitude
:
ImageEnhance
.
Contrast
(
img
).
enhance
(
magnitude
),
"sharpness"
:
lambda
img
,
magnitude
:
ImageEnhance
.
Sharpness
(
img
).
enhance
(
magnitude
),
"brightness"
:
lambda
img
,
magnitude
:
ImageEnhance
.
Brightness
(
img
).
enhance
(
magnitude
),
"autocontrast"
:
lambda
img
,
_
:
ImageOps
.
autocontrast
(
img
),
"equalize"
:
lambda
img
,
_
:
ImageOps
.
equalize
(
img
),
"invert"
:
lambda
img
,
_
:
ImageOps
.
invert
(
img
)
}
def
get_operation
(
self
,
method
,
magnitude_idx
):
magnitude
=
self
.
ranges
[
method
][
magnitude_idx
]
return
lambda
img
:
self
.
operations
[
method
](
img
,
magnitude
)
imagenet/image_classification/dataloaders.py
0 → 100644
View file @
9dc09f0d
# Copyright (c) 2018-2019, NVIDIA CORPORATION
# Copyright (c) 2017- Facebook, Inc
#
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import
os
import
torch
import
numpy
as
np
import
torchvision.datasets
as
datasets
import
torchvision.transforms
as
transforms
from
PIL
import
Image
from
functools
import
partial
from
image_classification.autoaugment
import
AutoaugmentImageNetPolicy
DATA_BACKEND_CHOICES
=
[
"pytorch"
,
"syntetic"
]
try
:
from
nvidia.dali.plugin.pytorch
import
DALIClassificationIterator
from
nvidia.dali.pipeline
import
Pipeline
import
nvidia.dali.ops
as
ops
import
nvidia.dali.types
as
types
DATA_BACKEND_CHOICES
.
append
(
"dali-gpu"
)
DATA_BACKEND_CHOICES
.
append
(
"dali-cpu"
)
except
ImportError
:
print
(
"Please install DALI from https://www.github.com/NVIDIA/DALI to run this example."
)
def
load_jpeg_from_file
(
path
,
cuda
=
True
):
img_transforms
=
transforms
.
Compose
(
[
transforms
.
Resize
(
256
),
transforms
.
CenterCrop
(
224
),
transforms
.
ToTensor
()]
)
img
=
img_transforms
(
Image
.
open
(
path
))
with
torch
.
no_grad
():
# mean and std are not multiplied by 255 as they are in training script
# torch dataloader reads data into bytes whereas loading directly
# through PIL creates a tensor with floats in [0,1] range
mean
=
torch
.
tensor
([
0.485
,
0.456
,
0.406
]).
view
(
1
,
3
,
1
,
1
)
std
=
torch
.
tensor
([
0.229
,
0.224
,
0.225
]).
view
(
1
,
3
,
1
,
1
)
if
cuda
:
mean
=
mean
.
cuda
()
std
=
std
.
cuda
()
img
=
img
.
cuda
()
img
=
img
.
float
()
input
=
img
.
unsqueeze
(
0
).
sub_
(
mean
).
div_
(
std
)
return
input
class
HybridTrainPipe
(
Pipeline
):
def
__init__
(
self
,
batch_size
,
num_threads
,
device_id
,
data_dir
,
interpolation
,
crop
,
dali_cpu
=
False
,
):
super
(
HybridTrainPipe
,
self
).
__init__
(
batch_size
,
num_threads
,
device_id
,
seed
=
12
+
device_id
)
interpolation
=
{
"bicubic"
:
types
.
INTERP_CUBIC
,
"bilinear"
:
types
.
INTERP_LINEAR
,
"triangular"
:
types
.
INTERP_TRIANGULAR
,
}[
interpolation
]
if
torch
.
distributed
.
is_initialized
():
rank
=
torch
.
distributed
.
get_rank
()
world_size
=
torch
.
distributed
.
get_world_size
()
else
:
rank
=
0
world_size
=
1
self
.
input
=
ops
.
FileReader
(
file_root
=
data_dir
,
shard_id
=
rank
,
num_shards
=
world_size
,
random_shuffle
=
True
,
pad_last_batch
=
True
,
)
if
dali_cpu
:
dali_device
=
"cpu"
self
.
decode
=
ops
.
ImageDecoder
(
device
=
dali_device
,
output_type
=
types
.
RGB
)
else
:
dali_device
=
"gpu"
# This padding sets the size of the internal nvJPEG buffers to be able to handle all images from full-sized ImageNet
# without additional reallocations
self
.
decode
=
ops
.
ImageDecoder
(
device
=
"mixed"
,
output_type
=
types
.
RGB
,
device_memory_padding
=
211025920
,
host_memory_padding
=
140544512
,
)
self
.
res
=
ops
.
RandomResizedCrop
(
device
=
dali_device
,
size
=
[
crop
,
crop
],
interp_type
=
interpolation
,
random_aspect_ratio
=
[
0.75
,
4.0
/
3.0
],
random_area
=
[
0.08
,
1.0
],
num_attempts
=
100
,
)
self
.
cmnp
=
ops
.
CropMirrorNormalize
(
device
=
"gpu"
,
dtype
=
types
.
FLOAT
,
output_layout
=
types
.
NCHW
,
crop
=
(
crop
,
crop
),
mean
=
[
0.485
*
255
,
0.456
*
255
,
0.406
*
255
],
std
=
[
0.229
*
255
,
0.224
*
255
,
0.225
*
255
],
)
self
.
coin
=
ops
.
CoinFlip
(
probability
=
0.5
)
def
define_graph
(
self
):
rng
=
self
.
coin
()
self
.
jpegs
,
self
.
labels
=
self
.
input
(
name
=
"Reader"
)
images
=
self
.
decode
(
self
.
jpegs
)
images
=
self
.
res
(
images
)
output
=
self
.
cmnp
(
images
.
gpu
(),
mirror
=
rng
)
return
[
output
,
self
.
labels
]
class
HybridValPipe
(
Pipeline
):
def
__init__
(
self
,
batch_size
,
num_threads
,
device_id
,
data_dir
,
interpolation
,
crop
,
size
):
super
(
HybridValPipe
,
self
).
__init__
(
batch_size
,
num_threads
,
device_id
,
seed
=
12
+
device_id
)
interpolation
=
{
"bicubic"
:
types
.
INTERP_CUBIC
,
"bilinear"
:
types
.
INTERP_LINEAR
,
"triangular"
:
types
.
INTERP_TRIANGULAR
,
}[
interpolation
]
if
torch
.
distributed
.
is_initialized
():
rank
=
torch
.
distributed
.
get_rank
()
world_size
=
torch
.
distributed
.
get_world_size
()
else
:
rank
=
0
world_size
=
1
self
.
input
=
ops
.
FileReader
(
file_root
=
data_dir
,
shard_id
=
rank
,
num_shards
=
world_size
,
random_shuffle
=
False
,
pad_last_batch
=
True
,
)
self
.
decode
=
ops
.
ImageDecoder
(
device
=
"mixed"
,
output_type
=
types
.
RGB
)
self
.
res
=
ops
.
Resize
(
device
=
"gpu"
,
resize_shorter
=
size
,
interp_type
=
interpolation
)
self
.
cmnp
=
ops
.
CropMirrorNormalize
(
device
=
"gpu"
,
dtype
=
types
.
FLOAT
,
output_layout
=
types
.
NCHW
,
crop
=
(
crop
,
crop
),
mean
=
[
0.485
*
255
,
0.456
*
255
,
0.406
*
255
],
std
=
[
0.229
*
255
,
0.224
*
255
,
0.225
*
255
],
)
def
define_graph
(
self
):
self
.
jpegs
,
self
.
labels
=
self
.
input
(
name
=
"Reader"
)
images
=
self
.
decode
(
self
.
jpegs
)
images
=
self
.
res
(
images
)
output
=
self
.
cmnp
(
images
)
return
[
output
,
self
.
labels
]
class
DALIWrapper
(
object
):
def
gen_wrapper
(
dalipipeline
,
num_classes
,
one_hot
,
memory_format
):
for
data
in
dalipipeline
:
input
=
data
[
0
][
"data"
].
contiguous
(
memory_format
=
memory_format
)
target
=
torch
.
reshape
(
data
[
0
][
"label"
],
[
-
1
]).
cuda
().
long
()
if
one_hot
:
target
=
expand
(
num_classes
,
torch
.
float
,
target
)
yield
input
,
target
dalipipeline
.
reset
()
def
__init__
(
self
,
dalipipeline
,
num_classes
,
one_hot
,
memory_format
):
self
.
dalipipeline
=
dalipipeline
self
.
num_classes
=
num_classes
self
.
one_hot
=
one_hot
self
.
memory_format
=
memory_format
def
__iter__
(
self
):
return
DALIWrapper
.
gen_wrapper
(
self
.
dalipipeline
,
self
.
num_classes
,
self
.
one_hot
,
self
.
memory_format
)
def
get_dali_train_loader
(
dali_cpu
=
False
):
def
gdtl
(
data_path
,
image_size
,
batch_size
,
num_classes
,
one_hot
,
interpolation
=
"bilinear"
,
augmentation
=
None
,
start_epoch
=
0
,
workers
=
5
,
_worker_init_fn
=
None
,
memory_format
=
torch
.
contiguous_format
,
):
if
torch
.
distributed
.
is_initialized
():
rank
=
torch
.
distributed
.
get_rank
()
world_size
=
torch
.
distributed
.
get_world_size
()
else
:
rank
=
0
world_size
=
1
traindir
=
os
.
path
.
join
(
data_path
,
"train"
)
if
augmentation
is
not
None
:
raise
NotImplementedError
(
f
"Augmentation
{
augmentation
}
for dali loader is not supported"
)
pipe
=
HybridTrainPipe
(
batch_size
=
batch_size
,
num_threads
=
workers
,
device_id
=
rank
%
torch
.
cuda
.
device_count
(),
data_dir
=
traindir
,
interpolation
=
interpolation
,
crop
=
image_size
,
dali_cpu
=
dali_cpu
,
)
pipe
.
build
()
train_loader
=
DALIClassificationIterator
(
pipe
,
reader_name
=
"Reader"
,
fill_last_batch
=
False
)
return
(
DALIWrapper
(
train_loader
,
num_classes
,
one_hot
,
memory_format
),
int
(
pipe
.
epoch_size
(
"Reader"
)
/
(
world_size
*
batch_size
)),
)
return
gdtl
def
get_dali_val_loader
():
def
gdvl
(
data_path
,
image_size
,
batch_size
,
num_classes
,
one_hot
,
interpolation
=
"bilinear"
,
crop_padding
=
32
,
workers
=
5
,
_worker_init_fn
=
None
,
memory_format
=
torch
.
contiguous_format
,
):
if
torch
.
distributed
.
is_initialized
():
rank
=
torch
.
distributed
.
get_rank
()
world_size
=
torch
.
distributed
.
get_world_size
()
else
:
rank
=
0
world_size
=
1
valdir
=
os
.
path
.
join
(
data_path
,
"val"
)
pipe
=
HybridValPipe
(
batch_size
=
batch_size
,
num_threads
=
workers
,
device_id
=
rank
%
torch
.
cuda
.
device_count
(),
data_dir
=
valdir
,
interpolation
=
interpolation
,
crop
=
image_size
,
size
=
image_size
+
crop_padding
,
)
pipe
.
build
()
val_loader
=
DALIClassificationIterator
(
pipe
,
reader_name
=
"Reader"
,
fill_last_batch
=
False
)
return
(
DALIWrapper
(
val_loader
,
num_classes
,
one_hot
,
memory_format
),
int
(
pipe
.
epoch_size
(
"Reader"
)
/
(
world_size
*
batch_size
)),
)
return
gdvl
def
fast_collate
(
memory_format
,
batch
):
imgs
=
[
img
[
0
]
for
img
in
batch
]
targets
=
torch
.
tensor
([
target
[
1
]
for
target
in
batch
],
dtype
=
torch
.
int64
)
w
=
imgs
[
0
].
size
[
0
]
h
=
imgs
[
0
].
size
[
1
]
tensor
=
torch
.
zeros
((
len
(
imgs
),
3
,
h
,
w
),
dtype
=
torch
.
uint8
).
contiguous
(
memory_format
=
memory_format
)
for
i
,
img
in
enumerate
(
imgs
):
nump_array
=
np
.
asarray
(
img
,
dtype
=
np
.
uint8
)
if
nump_array
.
ndim
<
3
:
nump_array
=
np
.
expand_dims
(
nump_array
,
axis
=-
1
)
nump_array
=
np
.
rollaxis
(
nump_array
,
2
)
tensor
[
i
]
+=
torch
.
from_numpy
(
nump_array
.
copy
())
return
tensor
,
targets
def
expand
(
num_classes
,
dtype
,
tensor
):
e
=
torch
.
zeros
(
tensor
.
size
(
0
),
num_classes
,
dtype
=
dtype
,
device
=
torch
.
device
(
"cuda"
)
)
e
=
e
.
scatter
(
1
,
tensor
.
unsqueeze
(
1
),
1.0
)
return
e
class
PrefetchedWrapper
(
object
):
def
prefetched_loader
(
loader
,
num_classes
,
one_hot
):
mean
=
(
torch
.
tensor
([
0.485
*
255
,
0.456
*
255
,
0.406
*
255
])
.
cuda
()
.
view
(
1
,
3
,
1
,
1
)
)
std
=
(
torch
.
tensor
([
0.229
*
255
,
0.224
*
255
,
0.225
*
255
])
.
cuda
()
.
view
(
1
,
3
,
1
,
1
)
)
stream
=
torch
.
cuda
.
Stream
()
first
=
True
for
next_input
,
next_target
in
loader
:
with
torch
.
cuda
.
stream
(
stream
):
next_input
=
next_input
.
cuda
(
non_blocking
=
True
)
next_target
=
next_target
.
cuda
(
non_blocking
=
True
)
next_input
=
next_input
.
float
()
if
one_hot
:
next_target
=
expand
(
num_classes
,
torch
.
float
,
next_target
)
next_input
=
next_input
.
sub_
(
mean
).
div_
(
std
)
if
not
first
:
yield
input
,
target
else
:
first
=
False
torch
.
cuda
.
current_stream
().
wait_stream
(
stream
)
input
=
next_input
target
=
next_target
yield
input
,
target
def
__init__
(
self
,
dataloader
,
start_epoch
,
num_classes
,
one_hot
):
self
.
dataloader
=
dataloader
self
.
epoch
=
start_epoch
self
.
one_hot
=
one_hot
self
.
num_classes
=
num_classes
def
__iter__
(
self
):
if
self
.
dataloader
.
sampler
is
not
None
and
isinstance
(
self
.
dataloader
.
sampler
,
torch
.
utils
.
data
.
distributed
.
DistributedSampler
):
self
.
dataloader
.
sampler
.
set_epoch
(
self
.
epoch
)
self
.
epoch
+=
1
return
PrefetchedWrapper
.
prefetched_loader
(
self
.
dataloader
,
self
.
num_classes
,
self
.
one_hot
)
def
__len__
(
self
):
return
len
(
self
.
dataloader
)
def
get_pytorch_train_loader
(
data_path
,
image_size
,
batch_size
,
num_classes
,
one_hot
,
interpolation
=
"bilinear"
,
augmentation
=
None
,
start_epoch
=
0
,
workers
=
5
,
_worker_init_fn
=
None
,
memory_format
=
torch
.
contiguous_format
,
):
interpolation
=
{
"bicubic"
:
Image
.
BICUBIC
,
"bilinear"
:
Image
.
BILINEAR
}[
interpolation
]