generate_samples.py (+8 −7)

@@ -366,12 +366,13 @@ def switch(val1, val2, boolean):
     return (1-boolean)*val1 + boolean*val2
 
 def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask, position_ids, tokenizer, args, maxlen=None, type_ids=None):
-    if isinstance(model, DDP):
-        model = model.module
-    if isinstance(model, FP16_Module):
-        model = model.module
-    original_output_parallel = model.parallel_output
-    model.parallel_output = False
+    actual_model = model
+    if isinstance(actual_model, DDP):
+        actual_model = actual_model.module
+    if isinstance(actual_model, FP16_Module):
+        actual_model = actual_model.module
+    original_output_parallel = actual_model.parallel_output
+    actual_model.parallel_output = False
     model.eval()
     with torch.no_grad():
         context_length = context_lengths.min().item()
@@ -438,7 +439,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask
             yield tokens, lengths
             if done:
                 break
-    model.parallel_output = original_output_parallel
+    actual_model.parallel_output = original_output_parallel
 
 
 def prepare_tokenizer(args):
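For reference, here is a minimal sketch of the save/unwrap/restore pattern this diff introduces. `ToyModel`, `ToyWrapper`, and `generate` below are hypothetical stand-ins, not the repo's real DDP, FP16_Module, or sampling loop. The point the diff appears to make: unwrap the wrappers on a separate reference (`actual_model`) so that `model` stays bound to the outermost wrapper for the forward pass, while the `parallel_output` flag is read, flipped, and later restored on the innermost module that actually owns it.

```python
class ToyModel:
    """Stands in for the underlying model that owns the flag."""
    def __init__(self):
        self.parallel_output = True  # flag the diff toggles during sampling

    def __call__(self, x):
        return x + 1


class ToyWrapper:
    """Stands in for DDP / FP16_Module: delegates and exposes `.module`."""
    def __init__(self, module):
        self.module = module

    def __call__(self, x):
        # Real wrappers add behavior here (gradient sync, fp16 casts, ...),
        # which is why the forward pass should still go through them.
        return self.module(x)


def generate(model, x):
    # Unwrap on a separate name, mirroring the diff's `actual_model`.
    # (The diff checks DDP then FP16_Module explicitly; a loop is used
    # here only to keep the toy generic.)
    actual_model = model
    while isinstance(actual_model, ToyWrapper):
        actual_model = actual_model.module

    original_output_parallel = actual_model.parallel_output
    actual_model.parallel_output = False
    try:
        # Inference still goes through the outer wrapper(s) via `model`.
        return model(x)
    finally:
        # Restore on the inner module, as the diff's second hunk does.
        actual_model.parallel_output = original_output_parallel


out = generate(ToyWrapper(ToyWrapper(ToyModel())), 41)
assert out == 42
```

The `try/finally` here is a defensive choice for the sketch; the diff itself restores the flag after its sampling loop finishes. Rebinding `model` in place, as the old code did, meant `model.eval()` and the forward pass ran on the bare module, bypassing whatever behavior the wrappers add.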