def get_split(args):
    """
    Get normalized [train, valid, test] dataset split fractions.

    ``args.split`` is a string holding one or more numeric weights separated
    by ``,`` or ``/`` (a string containing a comma is split on commas only).
    A bare number is treated as the train weight.

    Processing steps:
      1. Parse the weights; empty fields produced by a trailing or doubled
         separator (e.g. ``"0.9,0.1,"``) are ignored.
      2. If the weights sum to less than 1, the remainder is appended as one
         extra weight.
      3. Pad with zeros / truncate so exactly three weights remain.
      4. Zero the valid weight when ``args.valid_data`` is not None and the
         test weight when ``args.test_data`` is not None.
      5. Rescale so the three returned fractions sum to 1.

    Args:
        args: object exposing ``split`` (str), ``valid_data`` and
            ``test_data`` attributes.

    Returns:
        A list of three floats ``[train, valid, test]`` that sum to 1.

    Raises:
        ValueError: if a field is not a valid number, if ``args.split``
            contains no values, or if every remaining weight is zero so the
            split cannot be normalized.
    """
    spec = args.split
    if ',' in spec:
        fields = spec.split(',')
    elif '/' in spec:
        fields = spec.split('/')
    else:
        fields = [spec]

    # Skip blank fields so a stray separator does not crash float('').
    splits = [float(field) for field in fields if field.strip()]
    if not splits:
        raise ValueError(
            'no split values found in split specification %r' % (spec,))

    # Weights given as fractions that do not cover the whole dataset leave
    # the remainder to the next split in line.
    split_total = sum(splits)
    if split_total < 1.:
        splits.append(1 - split_total)

    # Always work with exactly three entries: train, valid, test.
    while len(splits) < 3:
        splits.append(0.)
    splits = splits[:3]

    # Explicitly supplied valid/test datasets replace the corresponding split.
    if args.valid_data is not None:
        splits[1] = 0.
    if args.test_data is not None:
        splits[2] = 0.

    final_sum = sum(splits)
    if final_sum == 0:
        # Previously surfaced as an unexplained ZeroDivisionError.
        raise ValueError(
            'split specification %r leaves no data in any split after '
            'accounting for valid_data/test_data' % (spec,))
    return [s / final_sum for s in splits]


def get_ext(path):
    """gets path extension (including the leading dot, '' if there is none)"""
    return os.path.splitext(path)[1]