megatron/arguments.py  (+5 −0)

```diff
@@ -635,6 +635,9 @@ def _add_data_args(parser):
     group.add_argument('--retriever-seq-length', type=int, default=256,
                        help='Maximum sequence length for the biencoder model '
                        ' for retriever')
+    group.add_argument('--sample-rate', type=float, default=1.0,
+                       help='Sample rate for training data. Supposed to be '
+                            '0 < sample_rate < 1')
     group.add_argument('--mask-prob', type=float, default=0.15,
                        help='Probability of replacing a token with mask.')
     group.add_argument('--short-seq-prob', type=float, default=0.1,
@@ -704,6 +707,8 @@ def _add_biencoder_args(parser):
                        'ICT dataset')
     group.add_argument('--use-one-sent-docs', action='store_true',
                        help='Whether to use one sentence documents in ICT')
+    group.add_argument('--evidence-data-path', type=str, default=None,
+                       help='Path to Wikipedia Evidence from DPR paper')

     # training
     group.add_argument('--retriever-report-topk-accuracies', nargs='+',
                        type=int,
```

megatron/checkpointing.py  (+13 −13)

```diff
@@ -383,42 +383,42 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True):
     return iteration


-def load_ict_checkpoint(model, only_query_model=False,
-                        only_context_model=False, from_realm_chkpt=False):
-    """selectively load ICT models for indexing/retrieving from ICT or REALM checkpoints"""
+def load_biencoder_checkpoint(model, only_query_model=False,
+                              only_context_model=False, custom_load_path=None):
+    """
+    selectively load retrieval models for indexing/retrieving
+    from saved checkpoints
+    """

     args = get_args()

     model = utils.unwrap_model(model)

-    load_path = args.load if from_realm_chkpt else args.ict_load
+    load_path = custom_load_path if custom_load_path is not None else args.load

     tracker_filename = get_checkpoint_tracker_filename(load_path)
     with open(tracker_filename, 'r') as f:
         iteration = int(f.read().strip())

     # assert iteration > 0
     checkpoint_name = get_checkpoint_name(load_path, iteration, False)
     if mpu.get_data_parallel_rank() == 0:
         print('global rank {} is loading checkpoint {}'.format(
             torch.distributed.get_rank(), checkpoint_name))

     state_dict = torch.load(checkpoint_name, map_location='cpu')
-    ict_state_dict = state_dict['model']
-    print(ict_state_dict)
-    sys.exit()
-    if from_realm_chkpt and mpu.get_data_parallel_rank() == 0:
-        print(" loading ICT state dict from REALM", flush=True)
-        ict_state_dict = ict_state_dict['retriever']['ict_model']
+    ret_state_dict = state_dict['model']

     if only_query_model:
-        ict_state_dict.pop('context_model')
+        ret_state_dict.pop('context_model')
     if only_context_model:
-        ict_state_dict.pop('query_model')
+        ret_state_dict.pop('query_model')

-    model.load_state_dict(ict_state_dict)
+    assert len(model) == 1
+    model[0].load_state_dict(ret_state_dict)
     torch.distributed.barrier()

     if mpu.get_data_parallel_rank() == 0:
         print(' successfully loaded {}'.format(checkpoint_name))

     return model
```
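For context, a hypothetical sketch of how an indexing job might call the renamed loader. The model list, Megatron initialization, and argument parsing are assumed to already be set up; none of this is part of the diff, and the checkpoint path is made up:

```python
# Hypothetical usage, assuming `model` is the usual Megatron list of model
# modules holding a biencoder (hence the `assert len(model) == 1` above).
from megatron.checkpointing import load_biencoder_checkpoint

# For an evidence-indexing job only the context encoder is needed;
# the query-encoder weights are popped from the loaded state dict.
model = load_biencoder_checkpoint(model, only_context_model=True)

# custom_load_path overrides args.load, replacing the old
# from_realm_chkpt / args.ict_load switch:
# model = load_biencoder_checkpoint(model,
#                                   custom_load_path='/path/to/checkpoints')
```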
megatron/data/biencoder_dataset_utils.py  (+25 −12)

```diff
@@ -4,10 +4,21 @@
 import time

 import numpy as np
 import torch

-from megatron import mpu, print_rank_0
-from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
-from megatron import get_args, get_tokenizer, print_rank_0, mpu
+from megatron import get_args, get_tokenizer, mpu, print_rank_0
+from megatron.data.dataset_utils import create_masked_lm_predictions, \
+    pad_and_convert_to_numpy
+from megatron.data.data_samplers import MegatronPretrainingSampler
+
+
+def make_attention_mask(source_block, target_block):
+    """
+    Returns a 2-dimensional (2-D) attention mask
+    :param source_block: 1-D array
+    :param target_block: 1-D array
+    """
+    mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
+    mask = mask.astype(np.int64)  # (source_length, target_length)
+    return mask


 def get_one_epoch_dataloader(dataset, micro_batch_size=None):
     """Specifically one epoch to be used in an indexing job."""
@@ -20,15 +31,17 @@ def get_one_epoch_dataloader(dataset, micro_batch_size=None):
     global_batch_size = micro_batch_size * world_size
     num_workers = args.num_workers

-    sampler = torch.utils.data.SequentialSampler(dataset)
-    # importantly, drop_last must be False to get all the data.
-    assert False, 'DistributedBatchSampler deprecated, change the implementation'
-    from megatron.data.samplers import DistributedBatchSampler
-    batch_sampler = DistributedBatchSampler(sampler,
-                                            batch_size=global_batch_size,
-                                            drop_last=False,
-                                            rank=rank,
-                                            world_size=world_size)
+    # Use megatron's sampler with consumed_samples set to 0, as this is
+    # only for evaluation and we don't intend to resume halfway.
+    # Also set drop_last to False, as we don't intend to drop
+    # the last batch.
+    batch_sampler = MegatronPretrainingSampler(
+        total_samples=len(dataset),
+        consumed_samples=0,
+        micro_batch_size=args.micro_batch_size,
+        data_parallel_rank=mpu.get_data_parallel_rank(),
+        data_parallel_size=mpu.get_data_parallel_world_size(),
+        drop_last=False)

     return torch.utils.data.DataLoader(dataset,
                                        batch_sampler=batch_sampler,
```
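To make the mask semantics concrete, here is a quick standalone check. The function body is copied verbatim from the diff above so the snippet runs without Megatron; the token ids are made up (101/102 stand in for [CLS]/[SEP], 0 for padding):

```python
import numpy as np

def make_attention_mask(source_block, target_block):
    # Copied from the diff: ids >= 1 are real tokens, 0 is padding; the
    # outer product keeps only (real source, real target) positions.
    mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
    return mask.astype(np.int64)  # (source_length, target_length)

source = np.array([101, 7592, 102, 0, 0])  # 3 real tokens, 2 pads
target = np.array([101, 2088, 102, 0])     # 3 real tokens, 1 pad

mask = make_attention_mask(source, target)
print(mask.shape)  # (5, 4) == (source_length, target_length)
print(mask)
# [[1 1 1 0]
#  [1 1 1 0]
#  [1 1 1 0]
#  [0 0 0 0]
#  [0 0 0 0]]
```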
megatron/data/data_samplers.py  (+14 −4)

```diff
@@ -57,7 +57,7 @@ def build_pretraining_data_loader(dataset, consumed_samples):
 class MegatronPretrainingSampler:

     def __init__(self, total_samples, consumed_samples, micro_batch_size,
-                 data_parallel_rank, data_parallel_size):
+                 data_parallel_rank, data_parallel_size, drop_last=True):
         # Keep a copy of input params for later use.
         self.total_samples = total_samples
         self.consumed_samples = consumed_samples
@@ -65,6 +65,7 @@ class MegatronPretrainingSampler:
         self.data_parallel_rank = data_parallel_rank
         self.micro_batch_times_data_parallel_size = \
             self.micro_batch_size * data_parallel_size
+        self.drop_last = drop_last

         # Sanity checks.
         assert self.total_samples > 0, \
@@ -81,17 +82,26 @@ class MegatronPretrainingSampler:
     def __len__(self):
         return self.total_samples

+    def get_start_end_idx(self):
+        start_idx = self.data_parallel_rank * self.micro_batch_size
+        end_idx = start_idx + self.micro_batch_size
+        return start_idx, end_idx
+
     def __iter__(self):
         batch = []
-        # Last batch if not complete will be dropped.
+        # Last batch will be dropped if drop_last is not set to False.
         for idx in range(self.consumed_samples, self.total_samples):
             batch.append(idx)
             if len(batch) == self.micro_batch_times_data_parallel_size:
-                start_idx = self.data_parallel_rank * self.micro_batch_size
-                end_idx = start_idx + self.micro_batch_size
+                start_idx, end_idx = self.get_start_end_idx()
                 yield batch[start_idx:end_idx]
                 batch = []

+        # Check the last partial batch and see if drop_last is set.
+        if len(batch) > 0 and not self.drop_last:
+            start_idx, end_idx = self.get_start_end_idx()
+            yield batch[start_idx:end_idx]
+

 class MegatronPretrainingRandomSampler:
```
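A small sketch of the new `drop_last` behavior, assuming the `megatron` package is importable (the constructor only does arithmetic and sanity checks, so no distributed setup is needed; the sizes are made up):

```python
from megatron.data.data_samplers import MegatronPretrainingSampler

# 10 samples, micro-batch size 4, two data-parallel ranks: one full global
# batch of 8 indices, then a partial batch of 2 that only survives with
# drop_last=False.
for rank in (0, 1):
    sampler = MegatronPretrainingSampler(total_samples=10,
                                         consumed_samples=0,
                                         micro_batch_size=4,
                                         data_parallel_rank=rank,
                                         data_parallel_size=2,
                                         drop_last=False)
    print(rank, list(sampler))
# Expected:
# 0 [[0, 1, 2, 3], [8, 9]]
# 1 [[4, 5, 6, 7], []]
```

Note that the partial batch is still sliced by rank offset, so higher ranks can receive a short or even empty micro-batch; the evaluation-only caller in `get_one_epoch_dataloader` above presumably tolerates that.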
megatron/data/orqa_wiki_dataset.py  (new file, +205 −0)

```python
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Wikipedia dataset from DPR code for ORQA."""

from abc import ABC
import csv
import random

import numpy as np
import torch
from torch.utils.data import Dataset

from megatron import print_rank_0, get_args, get_tokenizer, mpu
from megatron.data.biencoder_dataset_utils import make_attention_mask


def get_open_retrieval_wiki_dataset():
    args = get_args()
    tokenizer = get_tokenizer()

    dataset = OpenRetrievalEvidenceDataset('2018 Wikipedia from DPR codebase',
                                           'evidence',
                                           args.evidence_data_path,
                                           tokenizer,
                                           args.retriever_seq_length)
    return dataset


def get_open_retrieval_batch(data_iterator):
    # Items and their type.
    keys = ['row_id', 'context', 'context_mask', 'context_types',
            'context_pad_mask']
    datatype = torch.int64

    # Broadcast data.
    data = None if data_iterator is None else next(data_iterator)
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    row_id = data_b['row_id'].long()
    context = data_b['context'].long()

    # TODO: make the context mask a binary one
    context_mask = (data_b['context_mask'] < 0.5)

    context_types = data_b['context_types'].long()
    context_pad_mask = data_b['context_pad_mask'].long()

    return row_id, context, context_mask, context_types, context_pad_mask


def build_tokens_types_paddings_from_text(row, tokenizer, max_seq_length):
    """Build token types and paddings, trim if needed, and pad if needed."""

    title_ids = tokenizer.tokenize(row['title'])
    context_ids = tokenizer.tokenize(row['text'])

    # Prepend the title to the context.
    extended_context_ids = title_ids + [tokenizer.sep_id] + context_ids

    context_ids, context_types, context_pad_mask = \
        build_tokens_types_paddings_from_ids(extended_context_ids,
                                             max_seq_length,
                                             tokenizer.cls,
                                             tokenizer.sep,
                                             tokenizer.pad)
    return context_ids, context_types, context_pad_mask


# noinspection DuplicatedCode
def build_tokens_types_paddings_from_ids(text_ids, max_seq_length,
                                         cls_id, sep_id, pad_id):
    """Build token types and paddings, trim if needed, and pad if needed."""
    enc_ids = []
    tokentypes_enc = []

    # [CLS].
    enc_ids.append(cls_id)
    tokentypes_enc.append(0)

    # A.
    len_src = len(text_ids)
    enc_ids.extend(text_ids)
    tokentypes_enc.extend([0] * len_src)

    # Cap the size.
    if len(enc_ids) > max_seq_length - 1:
        enc_ids = enc_ids[0: max_seq_length - 1]
        tokentypes_enc = tokentypes_enc[0: max_seq_length - 1]

    # [SEP].
    enc_ids.append(sep_id)
    tokentypes_enc.append(0)

    num_tokens_enc = len(enc_ids)

    # Padding.
    padding_length = max_seq_length - len(enc_ids)
    if padding_length > 0:
        enc_ids.extend([pad_id] * padding_length)
        tokentypes_enc.extend([pad_id] * padding_length)

    pad_mask = ([1] * num_tokens_enc) + ([0] * padding_length)
    pad_mask = np.array(pad_mask, dtype=np.int64)

    return enc_ids, tokentypes_enc, pad_mask


def build_sample(row_id, context_ids, context_types, context_pad_mask):
    """Convert to numpy and return a sample consumed by the batch producer."""

    context_ids = np.array(context_ids, dtype=np.int64)
    context_types = np.array(context_types, dtype=np.int64)
    context_mask = make_attention_mask(context_ids, context_ids)

    sample = ({
        'row_id': row_id,
        'context': context_ids,
        'context_mask': context_mask,
        'context_types': context_types,
        'context_pad_mask': context_pad_mask
    })
    return sample


class OpenRetrievalEvidenceDataset(ABC, Dataset):
    """Open Retrieval Evidence dataset class."""

    def __init__(self, task_name, dataset_name, datapath, tokenizer,
                 max_seq_length):
        # Store inputs.
        self.task_name = task_name
        self.dataset_name = dataset_name
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        print_rank_0(' > building {} dataset for {}:'.format(
            self.task_name, self.dataset_name))

        # Process the files.
        print_rank_0(datapath)
        self.samples, self.id2text = self.process_samples_from_single_path(
            datapath)

        args = get_args()
        if args.sample_rate < 1:  # subsample
            k = int(len(self.samples) * args.sample_rate)
            self.samples = random.sample(self.samples, k)

        print_rank_0('  >> total number of samples: {}'.format(
            len(self.samples)))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        row = self.samples[idx]

        context_ids, context_types, context_pad_mask = \
            build_tokens_types_paddings_from_text(row, self.tokenizer,
                                                  self.max_seq_length)

        sample = build_sample(row['doc_id'],
                              context_ids,
                              context_types,
                              context_pad_mask)
        return sample

    @staticmethod
    def process_samples_from_single_path(filename):
        print_rank_0(' > Processing {} ...'.format(filename))
        total = 0

        rows = []
        id2text = {}

        with open(filename) as tsvfile:
            reader = csv.reader(tsvfile, delimiter='\t')
            next(reader, None)  # skip the headers
            for row in reader:
                # file format: doc_id, doc_text, title
                doc_id = int(row[0])
                text = row[1]
                title = row[2]

                rows.append({'doc_id': doc_id,
                             'text': text,
                             'title': title})

                assert doc_id not in id2text
                id2text[doc_id] = (text, title)

                total += 1
                if total % 100000 == 0:
                    print_rank_0('  > processed {} rows so far ...'.format(
                        total))

        print_rank_0(' >> processed {} samples.'.format(len(rows)))
        return rows, id2text
```
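A worked example of the trim-and-pad logic above, with made-up special-token ids (cls=1, sep=2, pad=0) and assuming `megatron` is importable (the function itself is pure Python plus numpy):

```python
from megatron.data.orqa_wiki_dataset import \
    build_tokens_types_paddings_from_ids

# Six content tokens but max_seq_length=6: [CLS] + 6 ids exceeds the
# max_seq_length - 1 cap, so the tail is truncated before [SEP] is added.
ids, types, pad_mask = build_tokens_types_paddings_from_ids(
    [11, 12, 13, 14, 15, 16], max_seq_length=6,
    cls_id=1, sep_id=2, pad_id=0)
print(ids)             # [1, 11, 12, 13, 14, 2]  -- capped, then [SEP]
print(list(pad_mask))  # [1, 1, 1, 1, 1, 1]      -- no padding needed

# max_seq_length=10 exercises the padding branch instead.
ids, types, pad_mask = build_tokens_types_paddings_from_ids(
    [11, 12, 13, 14, 15, 16], max_seq_length=10,
    cls_id=1, sep_id=2, pad_id=0)
print(ids)             # [1, 11, 12, 13, 14, 15, 16, 2, 0, 0]
print(list(pad_mask))  # [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]
```

Per `process_samples_from_single_path`, the evidence file these samples come from is a tab-separated table of (doc_id, doc_text, title) rows with a single header line.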