Commit c14d4fe4 authored by Laanait, Nouamane's avatar Laanait, Nouamane
Browse files

removing stale imports and forcing averaging of validation errors across ranks

parent ef675861
Loading
Loading
Loading
Loading
+0 −8
Original line number Diff line number Diff line
# Not supporting for __all__
from . import inputs
#from .inputs import *
from . import io_utils
#from .io_utils import *
from . import network
#from .network import *
from . import runtime
#from .runtime import *
from . import ops
#from .ops import *
from . import network_utils
#from .network_utils import *
from . import optimizers
from . import mp_wrapper
from . import lr_policies
from . import automatic_loss_scaler
#__all__ = ['inputs', 'io_utils', 'network', 'runtime', 'network_utils', 'ops']
+1 −1
Original line number Diff line number Diff line
@@ -820,7 +820,7 @@ def validate(network_config, hyper_params, params, sess, dset, num_batches=10):
            loss_label= 'ABS_DIFF'
            errors = tf.losses.absolute_difference(tf.cast(labels, tf.float32), tf.cast(n_net.model_output, tf.float32), reduction=tf.losses.Reduction.MEAN)
        errors = tf.expand_dims(errors,axis=0)
        error_averaging = hvd.allreduce(errors)
        error_averaging = hvd.allreduce(errors, average=True)

        if num_batches is not None:
            num_samples = num_batches