Commit e3e5ea89 authored by Deepak Narayanan

Compute tensor chunk size more cleanly, and add assertion for global batch size

parent 27fc4689
+5 −1
@@ -70,7 +70,7 @@ def parse_args(extra_args_provider=None, defaults={},
    model_parallel_size = args.pipeline_model_parallel_size * \
                          args.tensor_model_parallel_size
    assert args.world_size % model_parallel_size == 0, 'world size is not'\
-        ' divisible by tensor parallel size ({}) times pipeline paralle ' \
+        ' divisible by tensor parallel size ({}) times pipeline parallel ' \
        'size ({})'.format(args.world_size, args.tensor_model_parallel_size,
                           args.pipeline_model_parallel_size)
    args.data_parallel_size = args.world_size // model_parallel_size
@@ -116,6 +116,10 @@ def parse_args(extra_args_provider=None, defaults={},
            print('setting global batch size to {}'.format(
                args.global_batch_size), flush=True)
    assert args.global_batch_size > 0
+    if args.virtual_pipeline_model_parallel_size is not None:
+        assert args.global_batch_size % args.pipeline_model_parallel_size == 0, \
+            'global batch size is not divisible by pipeline parallel size when '\
+            'using interleaved schedule'

    # Parameters dtype.
    args.params_dtype = torch.float
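
For context, a minimal standalone sketch of the new check in this hunk, using hypothetical values (none of the numbers below come from the commit): when the interleaved schedule is enabled, the global batch size must be evenly divisible by the pipeline parallel size.

# Hypothetical configuration; argument names mirror the diff.
pipeline_model_parallel_size = 4          # number of pipeline stages
virtual_pipeline_model_parallel_size = 2  # interleaved schedule enabled
global_batch_size = 512

if virtual_pipeline_model_parallel_size is not None:
    # 512 % 4 == 0, so this passes; global_batch_size = 510 would raise.
    assert global_batch_size % pipeline_model_parallel_size == 0, \
        'global batch size is not divisible by pipeline parallel size when ' \
        'using interleaved schedule'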
+4 −3
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+from functools import reduce
+import operator
import torch

from megatron import get_args
@@ -30,8 +32,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
    tensor_recv_next = None
    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
    if args.scatter_gather_tensors_in_pipeline:
-        tensor_chunk_shape = (
-            args.seq_length * args.micro_batch_size * args.hidden_size) // \
+        tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) // \
            mpu.get_tensor_model_parallel_world_size()
    else:
        tensor_chunk_shape = tensor_shape
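
For reference, a minimal standalone sketch of the refactor in the hunk above, with hypothetical sizes and a plain integer standing in for mpu.get_tensor_model_parallel_world_size(): reduce(operator.mul, tensor_shape, 1) takes the product of whatever dimensions tensor_shape holds, so it yields the same chunk size as the explicit multiplication it replaces.

from functools import reduce
import operator

# Hypothetical sizes, chosen only to illustrate the change.
seq_length, micro_batch_size, hidden_size = 1024, 8, 2048
tensor_model_parallel_world_size = 4

tensor_shape = (seq_length, micro_batch_size, hidden_size)

# Old form: spell out the product of the three dimensions.
old_chunk = (seq_length * micro_batch_size * hidden_size) // \
    tensor_model_parallel_world_size

# New form: take the product directly from tensor_shape.
new_chunk = reduce(operator.mul, tensor_shape, 1) // \
    tensor_model_parallel_world_size

assert old_chunk == new_chunk  # both equal 4194304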