generate_samples.py (+8 −0)

```diff
@@ -362,6 +362,12 @@ def switch(val1, val2, boolean):
     return (1-boolean)*val1 + boolean*val2
 
 def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask, position_ids, tokenizer, args, maxlen=None, type_ids=None):
+    if isinstance(model, DDP):
+        model = model.module
+    if isinstance(model, FP16_Module):
+        model = model.module
+    original_output_parallel = model.parallel_output
+    model.parallel_output = False
     model.eval()
     with torch.no_grad():
         context_length = context_lengths.min().item()
@@ -404,6 +410,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask
             if args.greedy:
                 prev = torch.argmax(logits, dim=-1).view(-1)
             else:
+                logits = logits.float()
                 logits /= args.temperature
                 logits = top_k_logits(logits, top_k=args.top_k, top_p=args.top_p)
                 log_probs = F.softmax(logits, dim=-1)
@@ -427,6 +434,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask
                 yield tokens, lengths
             if done:
                 break
+    model.parallel_output = original_output_parallel
 
 
 def prepare_tokenizer(args):
```
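Two things happen in the first hunk: the DDP and FP16_Module wrappers are peeled off the model (both keep the wrapped network in a `.module` attribute, with DDP assumed outermost, as in the diff), and the model's `parallel_output` flag is saved and forced to `False` so sampling sees gathered, full-vocabulary logits; the last hunk restores the saved flag after the sampling loop. A minimal sketch of the same setup/teardown as a context manager (`non_parallel_output` is a hypothetical name; `DDP` and `FP16_Module` are assumed to be in scope, as they are inside generate_samples.py):

```python
from contextlib import contextmanager

@contextmanager
def non_parallel_output(model):
    """Hypothetical helper mirroring the diff's setup/teardown around sampling.

    Assumes DDP and FP16_Module are importable/in scope, as they are in
    generate_samples.py (the diff relies on both names).
    """
    # Unwrap the distributed and mixed-precision wrappers; both store the
    # wrapped network in `.module`, and the diff treats DDP as the outer layer.
    if isinstance(model, DDP):
        model = model.module
    if isinstance(model, FP16_Module):
        model = model.module
    # Temporarily force full-vocabulary (non-parallel) logits...
    original_output_parallel = model.parallel_output
    model.parallel_output = False
    try:
        yield model
    finally:
        # ...and restore the saved flag even if the consumer stops early.
        # (The diff restores it in straight-line code after the loop, which
        # only runs if the generator is driven to completion.)
        model.parallel_output = original_output_parallel
```

Since `sample_sequence_batch` is a generator, a `try`/`finally` restore like this would also cover a caller that stops iterating before `done`; the straight-line restore added in the diff runs only when the sampling loop finishes.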
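The one-line `logits = logits.float()` matters when the model runs in half precision (the FP16_Module path): fp16 saturates near 6.5e4, so dividing by a small temperature (and exponentiating inside softmax) can overflow or underflow before sampling. A minimal sketch of the sampling branch with the upcast (`sample_from_logits` is a hypothetical name, and its top-k masking is a simplified stand-in for the repo's `top_k_logits`, which also handles top-p):

```python
import torch
import torch.nn.functional as F

def sample_from_logits(logits, temperature=1.0, top_k=0):
    """Sketch of the non-greedy branch with the float32 upcast from the diff."""
    # Upcast before any rescaling: fp16 logits divided by a small temperature
    # can overflow half precision's ~6.5e4 range.
    logits = logits.float() / temperature
    if top_k > 0:
        # Keep only the k largest logits per row; everything else goes to -inf
        # so softmax assigns it zero probability.
        kth_largest = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth_largest, float("-inf"))
    probs = F.softmax(logits, dim=-1)
    # multinomial expects 1-D or 2-D probabilities: (batch, vocab) here.
    return torch.multinomial(probs, num_samples=1).view(-1)

# Usage, e.g.: sample_from_logits(torch.randn(4, 50257, dtype=torch.float16),
#                                 temperature=0.7, top_k=40)
```

One reading note on the surrounding code: the repo binds the `F.softmax` output to a variable named `log_probs`, but it holds probabilities, which is what `torch.multinomial` expects.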