Commit cfc6924b authored by Mostofa Patwary's avatar Mostofa Patwary
Browse files

Fixed an fp16 bug in sample generation: the attention mask was not converted to half precision when `args.fp16` was set.

parent 4ad5717e
Loading
Loading
Loading
Loading
+4 −0
Original line number Diff line number Diff line
@@ -98,6 +98,10 @@ def get_batch(context_tokens, args):
        args.reset_attention_mask,
        False)

    # Fp16 conversion.
    if args.fp16:
        attention_mask = attention_mask.half()

    return tokens, attention_mask, position_ids

def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):