load_dataset能一次性导入train, val, test

@dc.dataclass
class DataConfig(object):
    train_file: Optional[str] = None
    val_file: Optional[str] = None
    test_file: Optional[str] = None
    num_proc: Optional[int] = None

    @property
    def data_format(self) -> str:
        return Path(self.train_file).suffix

    @property
    def data_files(self) -> dict[NamedSplit, str]:
        return {
            split: data_file
            for split, data_file in zip(
                [Split.TRAIN, Split.VALIDATION, Split.TEST], # 可以写成["a", "b", "c"]只是一个标记的key罢了
                [self.train_file, self.val_file, self.test_file],
            )
            if data_file is not None
        }

def _load_datasets(
        data_dir: str,
        data_format: str,
        data_files: dict[NamedSplit, str],
        num_proc: Optional[int],
) -> DatasetDict:
    if data_format == '.jsonl':
        dataset_dct = load_dataset(
            data_dir,
            data_files=data_files,
            split=None,
            num_proc=num_proc,
        )
    else:
        raise NotImplementedError(f"Cannot load dataset in the '{data_format}' format.")
    return dataset_dct


class DataManager(object):
    def __init__(self, data_dir: str, data_config: DataConfig):
        self._num_proc = data_config.num_proc

        self._dataset_dct = _load_datasets(
            data_dir,
            data_config.data_format,
            data_config.data_files, ## 这边其实是一个字典
            self._num_proc,
        )

    def _get_dataset(self, split) -> Optional[Dataset]:
        return self._dataset_dct.get(split, None)

    def get_dataset(
            self,
            split,
            process_fn: Callable[[dict[str, Any]], dict[str, Any]],
            batched: bool = True,
            remove_orig_columns: bool = True,
    ) -> Optional[Dataset]:
        orig_dataset = self._get_dataset(split)
        if orig_dataset is None:
            return

        if remove_orig_columns:
            remove_columns = orig_dataset.column_names
        else:
            remove_columns = None
        return orig_dataset.map(
            process_fn,
            batched=batched,
            remove_columns=remove_columns,
            num_proc=self._num_proc,
        )

train_dataset = data_manager.get_dataset(
    Split.TRAIN,
    functools.partial(
        process_batch,
        tokenizer=tokenizer,
        combine=ft_config.combine,
        max_input_length=ft_config.max_input_length,
        max_output_length=ft_config.max_output_length,
    ),
    batched=True,
)

再数据导入后,可以使用get来进行获得对于对应key的value,即对应的dataset,然后再使用map函数来进行后续的数据处理,不同的数据集,例如train,val用不同的 map中指定的函数来进行数据处理

data = load_dataset(
    'json',
    data_dir='/home/qyj/code/train_llava/data/LLaVA-CC3M-Pretrain-595K',
    data_files='chat.json',
    split='train[:80%]' # 加载特定的比例的数据
)

data = load_dataset(
    'json',
    data_dir='/home/qyj/code/train_llava/data/LLaVA-CC3M-Pretrain-595K',
    data_files='chat.json',
    split=None 
)
data

 

当你使用 split=None 来加载数据集时,load_dataset 函数会尝试加载整个数据集而不进行任何分割。在某些情况下,即使你没有明确指定数据集的分割方式,load_dataset 可能会自动识别并加载数据集的默认分割。这取决于数据集的结构以及它的文件命名或目录结构。

data = load_dataset(
    'json',
    data_dir='/home/qyj/code/train_llava/data/LLaVA-CC3M-Pretrain-595K',
    data_files='chat.json',
    split=Split.TRAIN # 加载特定的比例的数据
)
data

区别就是唯一的就是要通过一个字典来获得数据集,一个不需要。

### 使用 `load_dataset` 函数加载数据集 #### 加载 Hugging Face 上的数据集 对于来自 Hugging Face 数据库中的公开可用数据集,可以直接通过名称参数调用来获取。例如要加载名为 `cmrc2018` 的中文阅读理解评测数据集: ```python from datasets import load_dataset dataset_dict = load_dataset('cmrc2018') ``` 此方法适用于大多数预构建的数据集集合[^1]。 #### 加载本地文件作为数据集 当处理个人电脑上保存的定制化或私有数据时,则需指明具体文件位置以及其格式(CSV, JSONL, TXT等),并可能涉及额外配置项以适应特定结构需求。下面给出几种常见情形下的实现方式: ##### CSV 文件示例 如果拥有逗号分隔值形式存储的信息表单,那么可以通过如下代码片段读取它成为 Dataset 对象的一部分: ```python import pandas as pd from datasets import DatasetDict, Dataset df_train = pd.read_csv('./local/path/to/train.csv') df_val = pd.read_csv('./local/path/to/validation.csv') ds_train = Dataset.from_pandas(df_train) ds_validation = Dataset.from_pandas(df_val) datasets = DatasetDict({ 'train': ds_train, 'validation': ds_validation }) ``` 上述过程首先利用 Pandas 库解析原始表格资料,接着转换成兼容于 Transformers 生态系统的 Datasets 类型实例[^4]。 ##### 处理分离式的JSON Lines (jsonl) 文件 针对那些已经按照训练/验证阶段划分好并且各自存放在不同 .jsonl 文档里的案例来说,应当采用映射机制来关联各部分至相应标签之下: ```python from datasets import load_dataset datasets = load_dataset( 'json', data_files={ 'train': './path_to_local_data/train.jsonl', 'validation': './path_to_local_data/validation.jsonl' } ) ``` 这段脚本里,“train” 和 “validation”的键名不可更改,因为它们被硬编码到了内部逻辑之中用于区分不同的子集类别[^3]。 ##### 单一多分区 JSON 文件的情况 假定所有样本都被打包在一个单独的大括号包裹着多个数组成员的对象内——其中每个数组对应一种用途(比如 traintestvalidation),则无需再做任何特殊安排即可顺利完成导入操作: ```python from datasets import load_dataset # 假设 json 文件中有三个 key-value 键值对分别为 "train", "validation" 和 "test" datasets = load_dataset('json', data_files='./single_file_with_multiple_splits.json') ``` 在这种情况下,只要确保输入源遵循预期模式就不会遇到困难。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值