Converting between Ray Dataset and Spark 2.x DataFrame

The Ray distributed computing framework can read many file formats, such as Parquet, CSV, JSON, NumPy, images, binary, and TFRecords, and it can pull data from different storage systems such as S3, HDFS, and GCS. This post walks through a common conversion task: moving data between Ray and Spark. Ray already ships read/write interfaces for Spark, but they only support Spark 3.x, not Spark 2.x, so I modified parts of the source code to handle conversion between Spark 2.x DataFrames and Ray Datasets.
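For example, a couple of the built-in readers look like this (the paths are illustrative, not from a real cluster):

import ray

# Ray picks the filesystem from the URI scheme (s3://, hdfs://, local path, ...)
ds_parquet = ray.data.read_parquet("s3://my-bucket/events/")
ds_csv = ray.data.read_csv("hdfs://namenode:8020/data/users.csv")
print(ds_parquet.schema())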

The package versions used:

ray==2.0.0
spark==2.4.6
pickle5==0.0.12
pyarrow==6.0.1
pandas==1.3.5

# -*- coding: utf-8 -*-
import os
import sys
import time
from distutils.version import LooseVersion
from typing import Union, List, Optional

import pyspark.sql
import ray
from pyspark.serializers import ArrowStreamSerializer
from pyspark.sql.types import DataType, StructType
from pyspark.traceback_utils import SCCallSiteSync

import pyarrow as pa

from ray.types import ObjectRef
from ray.data import Dataset
from ray.data._internal.arrow_block import ArrowRow
from ray.data._internal.block_list import BlockList
from ray.data._internal.plan import ExecutionPlan
from ray.data._internal.remote_fn import cached_remote_fn
from ray.data._internal.stats import DatasetStats
from ray.data.block import BlockExecStats, BlockAccessor, BlockMetadata


def _get_metadata(table: Union["pyarrow.Table", "pandas.DataFrame"]) -> BlockMetadata:
    """get pyarrow table metadata"""
    stats = BlockExecStats.builder()
    return BlockAccessor.for_block(table).get_metadata(input_files=[], exec_stats=stats.build())


class SparkDFToRayDataset(object):
    """Convert Pyspark DataFrame to Ray dataset[ArrowRow]"""
    def __init__(self, sdf: pyspark.sql.DataFrame):
        self._df = sdf
        # noinspection PyProtectedMember
        self._jdf = sdf._jdf
        # noinspection PyProtectedMember
        self._sc = sdf.sql_ctx._sc

    def _collect_arrow_table_refs(self, batch_size):
        """
        Returns all records as a list of Arrow table object references, pyarrow must be installed
        and available on driver and worker Python environments.
        :param batch_size: batch_size of pyarrow table to transfer to ray object reference store
        :return:
        """
        with SCCallSiteSync(self._sc):
            # noinspection PyProtectedMember
            from pyspark.rdd import _load_from_socket
            port, auth_secret, jsocket_auth_server = self._jdf.collectAsArrowToPython()
            try:
                buffer_batches = []
                obj_refs = []
                for arrow_record_batch in _load_from_socket((port, auth_secret), ArrowStreamSerializer()):
                    buffer_batches.append(arrow_record_batch)
                    if len(buffer_batches) >= batch_size:
                        _pa_table = pa.Table.from_batches(buffer_batches)
                        obj_refs.append(ray.put(_pa_table))
                        buffer_batches = []
                if buffer_batches:
                    # flush whatever record batches remain after the loop
                    obj_refs.append(ray.put(pa.Table.from_batches(buffer_batches)))
                return obj_refs
            finally:
                # join the JVM serving thread and surface any exception
                # raised by collectAsArrowToPython
                jsocket_auth_server.getResult()
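    # With the table references collected, assembling the Dataset itself can
    # lean on Ray's public ray.data.from_arrow_refs API, which wires up the
    # BlockList / ExecutionPlan / DatasetStats imported above internally.
    # A minimal sketch (the method name and the batch_size default are mine,
    # not from the original interface):
    def to_ray_dataset(self, batch_size: int = 5) -> Dataset:
        """Collect the Spark DataFrame and wrap it as a Ray Dataset[ArrowRow]."""
        obj_refs = self._collect_arrow_table_refs(batch_size)
        return ray.data.from_arrow_refs(obj_refs)


# Usage sketch, assuming an existing SparkSession `spark`:
#   sdf = spark.range(0, 1000).selectExpr("id", "id * 2 AS doubled")
#   ds = SparkDFToRayDataset(sdf).to_ray_dataset(batch_size=10)
#   print(ds.count(), ds.schema())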
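For the opposite direction, turning a Ray Dataset back into a Spark 2.x DataFrame, the simplest route is through pandas on the driver. The sketch below is a minimal version under that assumption: the helper name is mine, `spark` is an existing SparkSession, and the whole dataset must fit in driver memory (for bigger data, write the Dataset to Parquet and read it back with Spark instead).

def ray_dataset_to_spark_df(ds: Dataset, spark: "pyspark.sql.SparkSession") -> pyspark.sql.DataFrame:
    """Materialize a Ray Dataset on the driver and hand it to Spark."""
    # to_pandas enforces a row limit by default, so pass the actual count
    pdf = ds.to_pandas(limit=ds.count())
    # Spark 2.4 builds a DataFrame from pandas directly; setting
    # spark.sql.execution.arrow.enabled=true speeds up the conversion
    return spark.createDataFrame(pdf)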