Commit a87777bf authored by zihanl

delete finetune part

parent 5f4e63fc

tasks/knwl_dialo/data.py

deleted 100644 → 0
+0 −243

"""Build Dataset for Controllable Coversational Model"""

import os
import json

import torch
import numpy as np

from megatron import get_tokenizer
from megatron import print_rank_0


def read_data_for_finetuning(tokenizer, data_path, module):
    """
    Data Format: topic \t dialog context \t knowledge \t response.
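    A hypothetical example line (turns inside the dialog context are joined by
    " [SEP] ", and the topic field may carry a " [CTRL] " suffix):
    Paris [CTRL] chitchat \t Hi! [SEP] Tell me about Paris. \t Paris is the capital of France. \t It is the capital and largest city of France.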
    """
    
    data_list = []
    with open(data_path, "r") as f:
        for i, line in enumerate(f):
            line = line.rstrip()
            splits = line.split("\t")
            assert len(splits) == 4

            topic = splits[0].split(" [CTRL] ")[0]
            dialog_context = splits[1]
            knowledge = splits[2]
            response = splits[3]

            turns = dialog_context.split(" [SEP] ")
            # keep only the most recent three turns as context
            turns = turns[-3:]

            if module == "response":
                # input_ids
                input_ids = tokenizer.tokenize("( " + topic + " )")
                if knowledge != "no_passages_used":
                    # cap the knowledge passage at 256 tokens
                    input_ids.extend(tokenizer.tokenize("( " + knowledge + " )")[:256])
                
                for turn in turns:
                    turn = "<< " + turn + " >>"
                    input_ids.extend(tokenizer.tokenize(turn))
                input_ids.extend(tokenizer.tokenize(":"))

                # output_ids
                output_ids = tokenizer.tokenize(response)

                data_list.append({"input_ids": input_ids, "output_ids": output_ids})
                
            elif module == "knowledge":
                # skip example without knowledge sentences
                if knowledge == "no_passages_used":
                    continue

                input_ids = []
                input_ids.extend(tokenizer.tokenize("( " + topic + " )"))
                
                for turn in turns:
                    turn = "<< " + turn + " >>"
                    input_ids.extend(tokenizer.tokenize(turn))
                input_ids.extend(tokenizer.tokenize(":"))

                output_ids = tokenizer.tokenize(knowledge)

                data_list.append({"input_ids": input_ids, "output_ids": output_ids})

            else:
                raise ValueError("Please input a correct module name! "
                                 "(either knowledge or response)")
    
    return data_list


def read_data_for_prompting(tokenizer, test_data_path, prompt_file, 
                            module, num_prompt_examples, dynamic_prompt):
    
    # get prompts
    if dynamic_prompt:
        prompt_examples_dict = {}
        with open(prompt_file, "r") as f:
            for i, line in enumerate(f):
                line = line.strip()
                line_dict = json.loads(line)
                key = list(line_dict.keys())[0]
                
                if key not in prompt_examples_dict:
                    prompt_examples = line_dict[key]
                    prompt_examples = prompt_examples[:num_prompt_examples]
                    prompt = ""
                    for instance in prompt_examples:
                        instance = instance.strip()
                        prompt += instance + " \n"

                    prompt_examples_dict[key] = prompt

    else:
        with open(prompt_file, "r") as f:
            prompt_examples = f.readlines()

        prompt_examples = prompt_examples[:num_prompt_examples]
        prompt = ""
        for instance in prompt_examples:
            instance = instance.strip()
            prompt += instance + " \n"

    data_list = []
    with open(test_data_path, "r") as f:
        for i, line in enumerate(f):
            line = line.strip()
            splits = line.split("\t")

            topic = splits[0].split(" [CTRL] ")[0]
            turns = splits[1].split(" [SEP] ")[-3:]
            last_turn = turns[-1]
            ctrl_sent = splits[2]
            response = splits[3]

            if dynamic_prompt:
                prompt = prompt_examples_dict[topic]

            if module == "response":
                # input seq
                input_seq = prompt

                input_seq += "Topic: " + topic + ". "
                input_seq += "User says: " + last_turn + " "
                input_seq += "We know that: " + ctrl_sent + " "
                input_seq += "System replies:"

                # output seq
                output_seq = response

                input_ids = tokenizer.tokenize(input_seq)
                output_ids = tokenizer.tokenize(output_seq)
                data_list.append({"input_ids": input_ids, "output_ids": output_ids})

            elif module == "knowledge":
                # input seq
                input_seq = prompt
                input_seq += "( " + last_turn + " ) " + topic + " =>"

                # output seq
                output_seq = ctrl_sent

                input_ids = tokenizer.tokenize(input_seq)
                output_ids = tokenizer.tokenize(output_seq)
                data_list.append({"input_ids": input_ids, "output_ids": output_ids})

            else:
                raise ValueError("Please input a correct module name! "
                                 "(either knowledge or response)")

    return data_list


def data_shuffle(data, seed):
    # set random seed to make the shuffling reproducible
    np.random.seed(seed)
    np.random.shuffle(data)
    return data


class KnwlDialoDataset(torch.utils.data.Dataset):

    def __init__(self, data, max_seq_len, pad_id, eod_id):
        # need to deal with padding, label masking
        self.data = data
        self.max_seq_len = max_seq_len
        self.pad_id = pad_id
        self.eod_id = eod_id

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        data_dict = self.data[idx]
        input_ids, output_ids = data_dict["input_ids"], data_dict["output_ids"]
        
        # The trainer shifts tokens/labels by one position, so the mask is one
        # element shorter than text: supervise the output tokens plus the
        # trailing EOD, but not the input context.
        text = input_ids + output_ids + [self.eod_id]
        loss_mask = [0]*(len(input_ids)-1) + [1]*(len(output_ids)+1)

        text_len = len(text)
        if text_len > self.max_seq_len+1:
            text = text[:self.max_seq_len+1]
            loss_mask = loss_mask[:self.max_seq_len]
        else:
            text += [self.pad_id] * (self.max_seq_len+1 - text_len)
            loss_mask += [0] * (self.max_seq_len+1 - text_len)

        return {"text": np.array(text, dtype=np.int64), \
                "loss_mask": np.array(loss_mask, dtype=np.int64)}


def build_train_valid_datasets(train_data_path, valid_data_path, module,
                               max_seq_len, seed):
    """Build train, valid, and test datasets."""

    tokenizer = get_tokenizer()
    train_data_list = read_data_for_finetuning(tokenizer, train_data_path, module)
    valid_data_list = read_data_for_finetuning(tokenizer, valid_data_path, module)

    # shuffle the training data
    train_data_list = data_shuffle(train_data_list, seed)

    # build train, valid datasets
    train_dataset = KnwlDialoDataset(train_data_list, 
                                     max_seq_len, 
                                     pad_id=tokenizer.pad_id, 
                                     eod_id=tokenizer.eod_id)

    valid_dataset = KnwlDialoDataset(valid_data_list, 
                                     max_seq_len, 
                                     pad_id=tokenizer.pad_id, 
                                     eod_id=tokenizer.eod_id)

    return train_dataset, valid_dataset


def build_test_dataset(test_data_path, module, max_seq_len):
    tokenizer = get_tokenizer()

    test_data_list = read_data_for_finetuning(tokenizer, test_data_path, module)

    test_dataset = KnwlDialoDataset(test_data_list, 
                                    max_seq_len, 
                                    pad_id=tokenizer.pad_id, 
                                    eod_id=tokenizer.eod_id)

    return test_dataset


def build_test_dataset_for_prompting(test_data_path, prompt_file, module, max_seq_len, 
                                     num_prompt_examples, dynamic_prompt):
    tokenizer = get_tokenizer()

    test_data_list = read_data_for_prompting(tokenizer, test_data_path, prompt_file, module, \
                                             num_prompt_examples, dynamic_prompt)

    test_dataset = KnwlDialoDataset(test_data_list,
                                    max_seq_len,
                                    pad_id=tokenizer.pad_id, 
                                    eod_id=tokenizer.eod_id)

    return test_dataset

tasks/knwl_dialo/finetune.py

deleted 100644 → 0
+0 −210

"""Finetuning a pretrained language model for knowledge/response generation"""

import torch
from functools import partial
from megatron import mpu
from megatron import get_args
from megatron import get_timers
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron.model import GPTModel
from megatron.training import evaluate_and_print_results
from megatron.training import get_model
from megatron.checkpointing import load_checkpoint
from megatron.utils import average_losses_across_data_parallel_group
from megatron.initialize import initialize_megatron
from tasks.finetune_utils import finetune
from tasks.knwl_dialo.data import build_train_valid_datasets
from tasks.knwl_dialo.utils import get_ltor_attention_masks_and_position_ids
from tasks.knwl_dialo.utils import get_token_stream


def model_provider(pre_process=True, post_process=True):
    """Build the model."""

    print_rank_0('building GPT model ...')
    model = GPTModel(
        num_tokentypes=0,
        parallel_output=True,
        pre_process=pre_process,
        post_process=post_process
    )
    return model


def train_valid_datasets_provider():
    """Build train, valid, and test datasets for dialog/control module"""
    args = get_args()

    print_rank_0('> building train and validation datasets for %s module ...' % args.module)
    
    train_ds, valid_ds = build_train_valid_datasets(
        train_data_path=args.train_data_path,
        valid_data_path=args.test_data_path,
        module=args.module,
        max_seq_len=args.seq_length,
        seed=args.seed)
        
    print_rank_0("> finished creating datasets for %s module ..." % args.module)
    print_rank_0('> Train size: %d' % len(train_ds))
    print_rank_0('> Validation size: %d' % len(valid_ds))

    args.eval_interval = len(train_ds) // args.global_batch_size
    print_rank_0('> evaluation interval: %d' % args.eval_interval)

    args.eval_iters = len(valid_ds) // args.global_batch_size
    print_rank_0('> evaluation iterations: %d' % args.eval_iters)

    return train_ds, valid_ds


def process_batch(batch):
    """Generate a batch"""
    args = get_args()
    tokenizer = get_tokenizer()

    # Items and their type.
    keys = ['text', 'loss_mask']
    datatype = torch.int64

    data_b = mpu.broadcast_data(keys, batch, datatype)

    # Standard left-to-right LM shift: inputs are text[:, :-1] and targets are
    # text[:, 1:], matching the loss mask built in KnwlDialoDataset.
    tokens_ = data_b['text'].long()
    labels = tokens_[:, 1:].contiguous()
    tokens = tokens_[:, :-1].contiguous()

    loss_mask = data_b['loss_mask'].float()

    # Get the attention_mask and position ids.
    attention_mask, position_ids = \
        get_ltor_attention_masks_and_position_ids(tokens, tokenizer.eod_id)

    return tokens, labels, loss_mask, attention_mask, position_ids


def loss_func(loss_mask, output_tensor):
    losses = output_tensor.float()
    loss_mask = loss_mask.view(-1).float()
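    # Masked mean: average the per-token loss over supervised positions only.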
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

    # Reduce loss for logging.
    averaged_loss = average_losses_across_data_parallel_group([loss])

    return loss, {'lm loss': averaged_loss[0]}


def forward_step(batch, model):
    """Forward step."""
    args = get_args()
    timers = get_timers()
    
    # `batch` may be a data iterator (training) or an already-fetched batch
    # (evaluation); fall back to using it directly if next() is not applicable.
    try:
        batch_ = next(batch)
    except BaseException:
        batch_ = batch

    tokens, labels, loss_mask, attention_mask, position_ids = process_batch(batch_)

    output_tensor = model(tokens, position_ids, attention_mask,
                          labels=labels)

    return output_tensor, partial(loss_func, loss_mask)


def generate_samples_input_from_file(model):

    args = get_args()
    tokenizer = get_tokenizer()

    # Read the sample file and open the output file.
    assert args.sample_input_file is not None, \
        'sample input file is not provided.'
    if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
        with open(args.sample_input_file, "r") as f:
            all_raw_text = f.readlines()
        input_count = len(all_raw_text)
        input_pos = 0
        if args.sample_output_file is None:
            sample_output_file = args.sample_input_file + ".out"
            print('`sample-output-file` not specified, setting '
                  'it to {}'.format(sample_output_file))
        else:
            sample_output_file = args.sample_output_file

        fname_out = open(sample_output_file, "w")

    model.eval()
    # start the generation process
    with torch.no_grad():
        while True:
            raw_text_len = 0
            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                raw_text = all_raw_text[input_pos]
                input_pos += 1
                raw_text_len = len(raw_text)
                context_tokens = tokenizer.tokenize(raw_text)
            else:
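                # Ranks that do not read the input still join the collective
                # generation call below, so give them placeholder tokens.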
                context_tokens = tokenizer.tokenize("EMPTY TEXT")

            if input_pos % 100 == 0:
                print_rank_0("input_pos: %d" % input_pos)

            # drain the token stream; when the loop finishes, decode_tokens
            # holds the final generated sequence
            token_stream = get_token_stream(model, [context_tokens])
            for _, decode_tokens in enumerate(token_stream):
                pass

            # write the generation to the output file
            if mpu.get_tensor_model_parallel_rank() == 0:
                if mpu.is_pipeline_first_stage():
                    decode_tokens, _ = decode_tokens
                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                    trim_decode_tokens = tokenizer.detokenize(
                        decode_tokens)[raw_text_len:]

                    if "\r" in trim_decode_tokens:
                        trim_decode_tokens = trim_decode_tokens.replace("\r", "")
                    if "\n" in trim_decode_tokens:
                        trim_decode_tokens = trim_decode_tokens.replace("\n", "")
                    fname_out.write(trim_decode_tokens)
                    fname_out.write("\n")


            if input_pos == input_count:
                return


def run_generation(model_provider):

    args = get_args()
    if args.num_layers_per_virtual_pipeline_stage is not None:
        print("Interleaved pipeline schedule is not yet supported for text generation.")
        exit()

    # Set up model and load checkpoint.
    model = get_model(model_provider)

    if args.load is not None:
        _ = load_checkpoint(model, None, None)

    assert len(model) == 1, "the interleaved-schedule check above should have caught this"
    model = model[0]

    # run generation
    generate_samples_input_from_file(model)


def main():
    args = get_args()

    if "FINETUNE" in args.task:
        # finetune
        finetune(train_valid_datasets_provider, model_provider, \
                 forward_step=forward_step)
    else:
        # generate
        run_generation(model_provider)