Commit ce29d4d5 authored by Mohammad

working on refactoring text generation

parent a0bcee94
+174 −198

File changed; preview size limit exceeded, so its changes are collapsed.

+1 −23
@@ -357,29 +357,7 @@ def _add_gpt2_args(parser):



-def add_text_generate_args(parser):
-    """Text generate arguments."""
-
-    group = parser.add_argument_group('Text generation', 'configurations')
-    group.add_argument("--temperature", type=float, default=1.0)
-    group.add_argument("--greedy", action='store_true', default=False)
-    group.add_argument("--top_p", type=float, default=0.0)
-    group.add_argument("--top_k", type=int, default=0)
-    group.add_argument("--out-seq-length", type=int, default=1024)
-    group.add_argument("--sample-input-file", type=str, default="",
-                       help='get input from file instead of interactive mode, '
-                            'each line is an input')
-    group.add_argument("--sample-output-file", type=str, default="",
-                       help='output file got from --sample-input-file')
-    group.add_argument("--num-samples", type=int, default=0,
-                       help='number of samples to generate unconditionally, '
-                       'defaults to 0 and interactive conditional sampling')
-    group.add_argument("--genfile", type=str,
-                       help='output file when generating unconditionally')
-    group.add_argument("--recompute", action='store_true',
-                       help='during generation recompute all attention '
-                       'instead of using previously computed keys/values.')
-    return parser



def add_data_args_(parser):
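
The removed group registered the sampling flags for text generation; since this commit is mid-refactor, they presumably move into the collapsed file at the top. A minimal, hypothetical sketch of how such a group is consumed (only the flag names and defaults come from the diff; the standalone parser is illustrative):

import argparse

# Hypothetical standalone parser; in Megatron these flags attach to the main
# training parser. Names and defaults are taken from the removed group above.
parser = argparse.ArgumentParser(description='text generation sketch')
group = parser.add_argument_group('Text generation', 'configurations')
group.add_argument("--temperature", type=float, default=1.0)
group.add_argument("--greedy", action='store_true', default=False)
group.add_argument("--top_p", type=float, default=0.0)
group.add_argument("--top_k", type=int, default=0)
group.add_argument("--out-seq-length", type=int, default=1024)

# Top-k sampling at a softened temperature:
args = parser.parse_args(["--temperature", "0.9", "--top_k", "40"])
assert args.top_k == 40 and not args.greedy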
+1 −2
@@ -137,8 +137,7 @@ class BertModel(MegatronModule):
            self._binary_head_key = 'binary_head'


-    def forward(self, input_ids, attention_mask,
-                tokentype_ids=None):
+    def forward(self, input_ids, attention_mask, tokentype_ids=None):

        extended_attention_mask = bert_extended_attention_mask(
            attention_mask, next(self.language_model.parameters()).dtype)
+6 −2
@@ -51,7 +51,8 @@ class GPT2Model(MegatronModule):


    def forward(self, input_ids, position_ids, attention_mask,
-                tokentype_ids=None, layer_past=None, get_key_value=False):
+                tokentype_ids=None, layer_past=None, get_key_value=False,
+                forward_method_parallel_output=None):

        # Language model.
        lm_output = self.language_model(input_ids,
@@ -65,10 +66,13 @@ class GPT2Model(MegatronModule):
            lm_output, presents = lm_output

        # Output.
+        parallel_output = self.parallel_output
+        if forward_method_parallel_output is not None:
+            parallel_output = forward_method_parallel_output
        output = parallel_lm_logits(
            lm_output,
            self.language_model.embedding.word_embeddings.weight,
-            self.parallel_output)
+            parallel_output)

        if get_key_value:
            output = [output, presents]
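
The new forward_method_parallel_output argument lets a caller override the model-level parallel_output setting for a single forward call. A minimal sketch of the override pattern in isolation (ModelStub is hypothetical; only the parallel_output attribute and the forward_method_parallel_output argument come from the diff):

class ModelStub:
    """Stand-in illustrating the per-call override added to GPT2Model.forward."""

    def __init__(self, parallel_output=True):
        self.parallel_output = parallel_output  # model-level default

    def forward(self, forward_method_parallel_output=None):
        # Use the model default unless the caller overrides it for this call.
        parallel_output = self.parallel_output
        if forward_method_parallel_output is not None:
            parallel_output = forward_method_parallel_output
        return parallel_output

model = ModelStub(parallel_output=True)
assert model.forward() is True          # training path keeps the default
assert model.forward(False) is False    # a caller can flip it per call

Presumably generation code passes False so that parallel_lm_logits returns full, gathered logits for sampling, while training keeps the vocab-partitioned output that feeds the parallel cross-entropy loss.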
+7 −0
@@ -91,6 +91,10 @@ class AbstractTokenizer(ABC):
    def tokenize(self, text):
        pass

+    def detokenize(self, token_ids):
+        raise NotImplementedError('detokenizer is not implemented for {} '
+                                  'tokenizer'.format(self.name))
+
    @property
    def cls(self):
        raise NotImplementedError('CLS is not provided for {} '
@@ -190,6 +194,9 @@ class _GPT2BPETokenizer(AbstractTokenizer):
    def tokenize(self, text):
        return self.tokenizer.encode(text)

+    def detokenize(self, token_ids):
+        return self.tokenizer.decode(token_ids)
+
    @property
    def eod(self):
        return self.eod_id
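
The base class now raises NotImplementedError for detokenize by default, while _GPT2BPETokenizer implements it via self.tokenizer.decode, enabling an encode/decode round trip. A small sketch of that contract (WhitespaceTokenizerStub is hypothetical; the real class wraps a GPT-2 BPE tokenizer built from vocab and merges files):

class WhitespaceTokenizerStub:
    """Hypothetical stand-in obeying the tokenize/detokenize contract."""

    def __init__(self, corpus):
        self.vocab = {w: i for i, w in enumerate(sorted(set(corpus.split())))}
        self.inv_vocab = {i: w for w, i in self.vocab.items()}

    def tokenize(self, text):
        return [self.vocab[w] for w in text.split()]

    def detokenize(self, token_ids):
        return ' '.join(self.inv_vocab[i] for i in token_ids)

tok = WhitespaceTokenizerStub("the quick brown fox")
ids = tok.tokenize("quick fox")
assert tok.detokenize(ids) == "quick fox"   # round trip recovers the text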