Loading tasks/knwl_dialo/evaluate.py +14 −0 Original line number Diff line number Diff line # coding=utf-8 # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Model evaluation""" Loading tasks/knwl_dialo/preprocessing.py +14 −0 Original line number Diff line number Diff line # coding=utf-8 # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Preprocessing for Wizard of Wikipedia and Wizard of Internet datasets""" Loading tasks/knwl_dialo/prompt.py +29 −1 Original line number Diff line number Diff line # coding=utf-8 # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Prompting the pretrained language model to generate knowledge/response""" Loading @@ -13,6 +27,7 @@ from megatron.training import get_model from megatron.checkpointing import load_checkpoint from megatron.initialize import initialize_megatron from tasks.knwl_dialo.utils import get_token_stream # from megatron.text_generation import generate_and_post_process def model_provider(pre_process=True, post_process=True): Loading Loading @@ -137,6 +152,7 @@ def generate_samples_by_prompting_input_from_file(model): else: context_tokens = tokenizer.tokenize("EMPTY TEXT") # raw_text = "EMPTY TEXT" if input_pos % 100 == 0: print_rank_0("input_pos: %d" % input_pos) Loading @@ -145,6 +161,12 @@ def generate_samples_by_prompting_input_from_file(model): token_stream = get_token_stream(model, [context_tokens]) for _, decode_tokens in enumerate(token_stream): pass # outputs = generate_and_post_process( # model=model, # prompts=[raw_text], # tokens_to_generate=args.out_seq_length, # top_k_sampling=1) # prompts_plus_generations = outputs[0] # write the generated output to the output file if mpu.get_tensor_model_parallel_rank() == 0: Loading @@ -159,6 +181,12 @@ def generate_samples_by_prompting_input_from_file(model): fname_out.write(generated_output) fname_out.write("\n") # generations = prompts_plus_generations[raw_text_len:] # generations = generations.split("\n")[0] # generations = generations.strip() # fname_out.write(generations) # fname_out.write("\n") raw_text = None context_count += 1 if input_pos == input_count: Loading tasks/knwl_dialo/utils.py +14 −0 Original line number Diff line number Diff line # coding=utf-8 # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Utils (functions) for both prompting and finetuning""" Loading Loading
tasks/knwl_dialo/evaluate.py +14 −0 Original line number Diff line number Diff line # coding=utf-8 # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Model evaluation""" Loading
tasks/knwl_dialo/preprocessing.py +14 −0 Original line number Diff line number Diff line # coding=utf-8 # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Preprocessing for Wizard of Wikipedia and Wizard of Internet datasets""" Loading
tasks/knwl_dialo/prompt.py +29 −1 Original line number Diff line number Diff line # coding=utf-8 # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Prompting the pretrained language model to generate knowledge/response""" Loading @@ -13,6 +27,7 @@ from megatron.training import get_model from megatron.checkpointing import load_checkpoint from megatron.initialize import initialize_megatron from tasks.knwl_dialo.utils import get_token_stream # from megatron.text_generation import generate_and_post_process def model_provider(pre_process=True, post_process=True): Loading Loading @@ -137,6 +152,7 @@ def generate_samples_by_prompting_input_from_file(model): else: context_tokens = tokenizer.tokenize("EMPTY TEXT") # raw_text = "EMPTY TEXT" if input_pos % 100 == 0: print_rank_0("input_pos: %d" % input_pos) Loading @@ -145,6 +161,12 @@ def generate_samples_by_prompting_input_from_file(model): token_stream = get_token_stream(model, [context_tokens]) for _, decode_tokens in enumerate(token_stream): pass # outputs = generate_and_post_process( # model=model, # prompts=[raw_text], # tokens_to_generate=args.out_seq_length, # top_k_sampling=1) # prompts_plus_generations = outputs[0] # write the generated output to the output file if mpu.get_tensor_model_parallel_rank() == 0: Loading @@ -159,6 +181,12 @@ def generate_samples_by_prompting_input_from_file(model): fname_out.write(generated_output) fname_out.write("\n") # generations = prompts_plus_generations[raw_text_len:] # generations = generations.split("\n")[0] # generations = generations.strip() # fname_out.write(generations) # fname_out.write("\n") raw_text = None context_count += 1 if input_pos == input_count: Loading
tasks/knwl_dialo/utils.py +14 −0 Original line number Diff line number Diff line # coding=utf-8 # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Utils (functions) for both prompting and finetuning""" Loading