Loading generate_samples.py +4 −0 Original line number Diff line number Diff line Loading @@ -98,6 +98,10 @@ def get_batch(context_tokens, args): args.reset_attention_mask, False) # Fp16 conversion. if args.fp16: attention_mask = attention_mask.half() return tokens, attention_mask, position_ids def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): Loading Loading
generate_samples.py +4 −0 Original line number Diff line number Diff line Loading @@ -98,6 +98,10 @@ def get_batch(context_tokens, args): args.reset_attention_mask, False) # Fp16 conversion. if args.fp16: attention_mask = attention_mask.half() return tokens, attention_mask, position_ids def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): Loading