Commit c14d4fe4 authored by Laanait, Nouamane's avatar Laanait, Nouamane

removing stale imports and forcing averaging of validation errors across ranks

parent ef675861
# Not supporting for __all__
from . import inputs
#from .inputs import *
from . import io_utils
#from .io_utils import *
from . import network
#from .network import *
from . import runtime
#from .runtime import *
from . import ops
#from .ops import *
from . import network_utils
#from .network_utils import *
from . import optimizers
from . import mp_wrapper
from . import lr_policies
from . import automatic_loss_scaler
#__all__ = ['inputs', 'io_utils', 'network', 'runtime', 'network_utils', 'ops']
......@@ -820,7 +820,7 @@ def validate(network_config, hyper_params, params, sess, dset, num_batches=10):
loss_label= 'ABS_DIFF'
errors = tf.losses.absolute_difference(tf.cast(labels, tf.float32), tf.cast(n_net.model_output, tf.float32), reduction=tf.losses.Reduction.MEAN)
errors = tf.expand_dims(errors,axis=0)
error_averaging = hvd.allreduce(errors)
error_averaging = hvd.allreduce(errors, average=True)
if num_batches is not None:
num_samples = num_batches
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment