The Ray distributed computing framework can read many file formats, such as Parquet, CSV, JSON, NumPy, images, binary, and TFRecords, and it can load data from different storage systems, such as S3, HDFS, and GCS. This post covers a fairly common conversion task: moving data between Spark and Ray. Ray already ships with interfaces for reading from and writing to Spark, but they only support Spark 3.x, not Spark 2.x, so I adapted parts of the source code to convert between a Spark 2.x DataFrame and a Ray Dataset.
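For reference, on Spark 3.x the built-in path looks roughly like this (a sketch, assuming the raydp package, which ray.data.from_spark relies on; the app name and sizes are placeholders):

import ray
import raydp

ray.init()
# raydp starts a Spark 3.x session on top of the Ray cluster
spark = raydp.init_spark(app_name="demo", num_executors=2,
                         executor_cores=2, executor_memory="2GB")
sdf = spark.range(0, 1000)
ds = ray.data.from_spark(sdf)  # built-in conversion, Spark 3.x only

On Spark 2.x that call is unavailable, hence the adapted code below.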
Package versions used:
ray==2.0.0
spark==2.4.6
pickle5==0.0.12
pyarrow==6.0.1
pandas==1.3.5
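To reproduce the environment, the packages can be installed with pip (note: the Spark entry above corresponds to the pyspark package on PyPI):

pip install ray==2.0.0 pyspark==2.4.6 pickle5==0.0.12 pyarrow==6.0.1 pandas==1.3.5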
# -*- coding: utf-8 -*-
import os
import sys
import time
from distutils.version import LooseVersion
from typing import List, Optional, Union

import pyarrow as pa
import pyspark.sql
from pyspark.serializers import ArrowStreamSerializer
from pyspark.sql.types import DataType, StructType
from pyspark.traceback_utils import SCCallSiteSync

import ray
from ray.data import Dataset
from ray.data._internal.arrow_block import ArrowRow
from ray.data._internal.block_list import BlockList
from ray.data._internal.plan import ExecutionPlan
from ray.data._internal.remote_fn import cached_remote_fn
from ray.data._internal.stats import DatasetStats
from ray.data.block import BlockAccessor, BlockExecStats, BlockMetadata
from ray.types import ObjectRef


def _get_metadata(table: Union["pyarrow.Table", "pandas.DataFrame"]) -> BlockMetadata:
    """Build Ray block metadata (row count, byte size, schema, ...) for a block."""
    stats = BlockExecStats.builder()
    return BlockAccessor.for_block(table).get_metadata(
        input_files=[], exec_stats=stats.build()
    )
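
# Quick illustration (my addition): _get_metadata works on a plain pyarrow Table, e.g.
#   meta = _get_metadata(pa.table({"id": [1, 2, 3]}))
#   meta.num_rows -> 3, meta.size_bytes -> approximate in-memory size of the block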


class SparkDFToRayDataset(object):
    """Convert a PySpark DataFrame into a Ray Dataset[ArrowRow]."""

    def __init__(self, sdf: pyspark.sql.DataFrame):
        self._df = sdf
        # noinspection PyProtectedMember
        self._jdf = sdf._jdf  # underlying JVM DataFrame, needed for collectAsArrowToPython
        # noinspection PyProtectedMember
        self._sc = sdf.sql_ctx._sc  # SparkContext, needed for call-site sync

    def _collect_arrow_table_refs(self, batch_size):
        """
        Return all records as a list of Arrow table object references. pyarrow must be
        installed and available in both the driver and worker Python environments.
        :param batch_size: number of Arrow record batches to buffer before putting them
            into the Ray object store as a single pyarrow Table
        :return: list of object references pointing to pyarrow Tables
        """
        with SCCallSiteSync(self._sc):
            # noinspection PyProtectedMember
            from pyspark.rdd import _load_from_socket

            # Spark 2.4.6 returns (port, auth_secret, jsocket_auth_server) here
            port, auth_secret, jsocket_auth_server = self._jdf.collectAsArrowToPython()
            try:
                buffer_batches = []
                obj_refs = []
                for arrow_record_batch in _load_from_socket(
                    (port, auth_secret), ArrowStreamSerializer()
                ):
                    buffer_batches.append(arrow_record_batch)
                    if len(buffer_batches) >= batch_size:
                        # Flush the buffered record batches into the object store
                        _pa_table = pa.Table.from_batches(buffer_batches)
                        obj_refs.append(ray.put(_pa_table))
                        buffer_batches = []
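                # Illustration (my addition, not necessarily how this class continues):
                # once the stream is drained and any leftover batches are flushed, the
                # collected references can be wrapped into a Dataset via Ray's public API:
                #   ds = ray.data.from_arrow_refs(obj_refs)  # -> Dataset[ArrowRow]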