Tsaris, Aristeidis (aris) / pytorch_tutorial / Commits

Commit 558ffac0
authored Oct 08, 2021 by Aristeidis Tsaris

adding nvtx

parent 6aff9d3d

Changes 2
ascent/sub_prof.lsf
@@ -4,7 +4,7 @@
 #BSUB -J sc21
 #BSUB -o logs/sc21.o%J
 #BSUB -W 0:30
-#BSUB -nnodes 1
+#BSUB -nnodes 2
 #BSUB -alloc_flags "nvme smt4"
 ####BSUB -N
 # End LSF directives and begin shell commands
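A quick, hedged way to confirm the effect of the -nnodes change from inside the training script, assuming the job launches one rank per GPU and torch.distributed is initialized before this runs (Ascent nodes carrying 6 GPUs each is an assumption about the machine, under which 2 nodes means a world size of 12):

# Sketch only: sanity-check the world size after moving from 1 to 2 nodes.
# Assumes dist.init_process_group() has already run in the training script.
import torch.distributed as dist

def report_world_size():
    if dist.is_initialized() and dist.get_rank() == 0:
        # With 6 GPUs per node (assumed), -nnodes 2 should print 12.
        print(f"world size: {dist.get_world_size()}", flush=True)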
imagenet/image_classification/training.py
@@ -44,6 +44,8 @@ from .models.common import EMA
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.cuda.amp import autocast
+import torch.cuda.nvtx as nvtx
+
 
 ACC_METADATA = {"unit": "%", "format": ":.2f"}
 IPS_METADATA = {"unit": "img/s", "format": ":.2f"}
 TIME_METADATA = {"unit": "s", "format": ":.5f"}
@@ -52,6 +54,16 @@ LOSS_METADATA = {"format": ":.5f"}
 def count_parameters(model):
     return sum(p.numel() for p in model.parameters() if p.requires_grad)
 
+def nvtx_range_push(name, enabled):
+    if enabled:
+        torch.cuda.synchronize()
+        torch.cuda.nvtx.range_push(name)
+
+def nvtx_range_pop(enabled):
+    if enabled:
+        torch.cuda.synchronize()
+        torch.cuda.nvtx.range_pop()
+
 
 class ModelAndLoss(nn.Module):
     def __init__(
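The two helpers above wrap torch.cuda.nvtx.range_push/range_pop behind a flag and call torch.cuda.synchronize() first, so a range only opens or closes once previously launched GPU work has drained and its boundaries reflect real device time rather than asynchronous kernel launches. A minimal usage sketch, with model and batch as hypothetical stand-ins:

# Usage sketch for the helpers above; `model` and `batch` are placeholders,
# not names from this repository.
def profiled_forward(model, batch, enabled=True):
    nvtx_range_push("forward", enabled)   # sync, then open the NVTX range
    out = model(batch)
    nvtx_range_pop(enabled)               # sync, then close the range
    return out

The synchronize calls make the timeline honest but add host/device stalls, which is why the enabled flag exists to switch the whole mechanism off.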
@@ -87,7 +99,7 @@ class ModelAndLoss(nn.Module):
         if self.noDDP:
             self.model = self.model
         else:
-            self.model = DDP(self.model, device_ids=[gpu_id], output_device=gpu_id)
+            self.model = DDP(self.model, device_ids=[gpu_id], output_device=gpu_id, bucket_cap_mb=1)
 
     def load_model_state(self, state):
         if not state is None:
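bucket_cap_mb=1 shrinks DDP's gradient buckets from the default 25 MB down to 1 MB, so each allreduce carries far fewer gradients and the communication shows up at much finer granularity on an NVTX timeline, at the cost of many more, higher-latency collectives. Rough arithmetic, with the parameter count an illustrative assumption:

# Back-of-the-envelope bucket count; 25M parameters (roughly ResNet-50
# sized) is an assumption for illustration, not a measurement.
params = 25_000_000
bytes_per_grad = 4                       # fp32 gradients
buckets_at_1mb = params * bytes_per_grad / (1 * 2**20)
buckets_at_25mb = params * bytes_per_grad / (25 * 2**20)
print(f"~{buckets_at_1mb:.0f} allreduces/step vs ~{buckets_at_25mb:.0f} at the default")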
@@ -202,7 +214,6 @@ def get_train_step(
         if torch.distributed.is_initialized() and not noDDP:
             reduced_loss = utils.reduce_tensor(loss.data)
         else:
-            print("I am here", flush=True)
             reduced_loss = loss.data
 
         scaler.scale(loss).backward()
@@ -283,11 +294,14 @@ def train(
     optimizer.zero_grad()
+    torch.cuda.synchronize()
 
     data_iter = enumerate(train_loader)
     if logger is not None:
         data_iter = logger.iteration_generator_wrapper(data_iter, mode='train')
 
     for i, (input, target) in data_iter:
+        nvtx_range_push('step {}'.format(i), True)
+
         bs = input.size(0)
         lr_scheduler(optimizer, i, epoch)
         data_time = time.time() - end
 
@@ -315,6 +329,9 @@ def train(
                 interrupted = True
                 break
 
+        torch.cuda.synchronize()
+        nvtx_range_pop(True)
+
 
     return interrupted
 
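With the push at the top of the loop body (previous hunk) and the synchronize-then-pop here, every iteration appears as its own "step N" range on the NVTX row of the timeline. NVTX ranges nest, so finer phases can be bracketed inside a step if needed; a hedged sketch, with the phase names and the model_and_loss call shape invented for illustration:

# Sketch of nested ranges inside one step; phase names and the
# model_and_loss call shape are illustrative, not from this repository.
nvtx_range_push('step {}'.format(i), True)

nvtx_range_push('forward', True)
loss, output = model_and_loss(input, target)
nvtx_range_pop(True)

nvtx_range_push('backward', True)
scaler.scale(loss).backward()
nvtx_range_pop(True)

nvtx_range_pop(True)   # closes 'step {i}'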
@@ -539,23 +556,26 @@ def train_loop(
         if logger is not None:
             logger.start_epoch()
         if not skip_training:
-            interrupted = train(
-                train_loader,
-                model_and_loss,
-                optimizer,
-                scaler,
-                lr_scheduler,
-                logger,
-                epoch,
-                steps_per_epoch,
-                timeout_handler,
-                ema=ema,
-                use_amp=use_amp,
-                prof=prof,
-                register_metrics=epoch == start_epoch,
-                batch_size_multiplier=batch_size_multiplier,
-                noDDP=noDDP,
-            )
+            nvtx_range_push('epoch {}'.format(epoch), True)
+            with torch.autograd.profiler.emit_nvtx(enabled=False):
+                interrupted = train(
+                    train_loader,
+                    model_and_loss,
+                    optimizer,
+                    scaler,
+                    lr_scheduler,
+                    logger,
+                    epoch,
+                    steps_per_epoch,
+                    timeout_handler,
+                    ema=ema,
+                    use_amp=use_amp,
+                    prof=prof,
+                    register_metrics=epoch == start_epoch,
+                    batch_size_multiplier=batch_size_multiplier,
+                    noDDP=noDDP,
+                )
+            nvtx_range_pop(True)
 
         if not skip_validation:
             prec1, nimg = validate(
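The epoch-level range brackets the whole train(...) call, while emit_nvtx(enabled=False) is a switch left in the off position: set to True, autograd emits an NVTX range for every operator it executes, which only pays off when the process is captured by Nsight Systems (something like `nsys profile -t cuda,nvtx -o report ...` — a plausible invocation, not one taken from this repo). A minimal sketch of the composition:

# Sketch: composing the epoch-level range with autograd's per-op NVTX
# emission. `run_one_epoch` is a hypothetical stand-in for the full
# train(...) call above; enabling emit_nvtx is an assumed profiling mode.
import torch

def profiled_epoch(run_one_epoch, epoch, emit_ops=False):
    with torch.autograd.profiler.emit_nvtx(enabled=emit_ops):
        nvtx_range_push('epoch {}'.format(epoch), True)
        interrupted = run_one_epoch()
        nvtx_range_pop(True)
    return interrupted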