[SPARK-54337][PS] Add support for PyCapsule to Pyspark · apache/spark@ecf179c

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Iterator, Optional

import pyarrow as pa

import pyspark.sql
from pyspark.sql.types import StructType, StructField, BinaryType
from pyspark.sql.pandas.types import to_arrow_schema


def _get_arrow_array_partition_stream(df: pyspark.sql.DataFrame) -> Iterator[pa.RecordBatch]:
    """Return all the partitions as Arrow arrays in an Iterator."""
    # We will be using mapInArrow to convert each partition to Arrow RecordBatches.
    # The return type of the function will be a single binary column containing
    # the serialized RecordBatch in Arrow IPC format.
    binary_schema = StructType([StructField("arrow_ipc_bytes", BinaryType(), nullable=False)])

    def batch_to_bytes_iter(batch_iter: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
        """
        A generator function that converts RecordBatches to serialized Arrow IPC format.

        Spark sends each partition as an iterator of RecordBatches. In order to return
        the entire partition as a stream of Arrow RecordBatches, we need to serialize
        each RecordBatch to Arrow IPC format and yield it as a single binary blob.
        """
        # The size of the batch can be controlled by the Spark config
        # `spark.sql.execution.arrow.maxRecordsPerBatch`.
        for arrow_batch in batch_iter:
            # We create an in-memory byte stream to hold the serialized batch.
            sink = pa.BufferOutputStream()
            # Write the batch to the stream using Arrow IPC format.
            with pa.ipc.new_stream(sink, arrow_batch.schema) as writer:
                writer.write_batch(arrow_batch)
            buf = sink.getvalue()
            # The second buffer contains the offsets we are manually creating.
            offset_buf = pa.array([0, len(buf)], type=pa.int32()).buffers()[1]
            null_bitmap = None
            # Wrap the bytes in a new 1-row, 1-column RecordBatch to satisfy the mapInArrow
            # return signature. This packs the whole serialized batch into a single binary cell.
            storage_arr = pa.Array.from_buffers(
                type=pa.binary(), length=1, buffers=[null_bitmap, offset_buf, buf]
            )
            yield pa.RecordBatch.from_arrays([storage_arr], names=["arrow_ipc_bytes"])

    # Convert all partitions to Arrow RecordBatches and map to binary blobs.
    byte_df = df.mapInArrow(batch_to_bytes_iter, binary_schema)
    # A row is actually a batch of data in Arrow IPC format. Fetch the batches one by one.
    for row in byte_df.toLocalIterator():
        with pa.ipc.open_stream(row.arrow_ipc_bytes) as reader:
            for batch in reader:
                # Each batch corresponds to a chunk of data in the partition.
                yield batch

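(Aside, not part of the patch: the serialize-and-wrap trick above can be exercised with pyarrow alone. The sketch below follows the same round trip, serializing a RecordBatch to IPC bytes, wrapping them in a one-row binary batch, and unwrapping them again, but uses the simpler copy-based pa.array construction rather than the zero-copy pa.Array.from_buffers path used in the patch.)

# Illustrative sketch only: the IPC round trip in isolation.
import pyarrow as pa

batch = pa.RecordBatch.from_pydict({"id": [1, 2, 3], "name": ["a", "b", "c"]})

# Serialize the batch to Arrow IPC stream format, as batch_to_bytes_iter does.
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, batch.schema) as writer:
    writer.write_batch(batch)
ipc_bytes = sink.getvalue().to_pybytes()

# Wrap the bytes in a 1-row, 1-column batch, mirroring binary_schema above.
wrapped = pa.RecordBatch.from_arrays(
    [pa.array([ipc_bytes], type=pa.binary())], names=["arrow_ipc_bytes"]
)

# Unwrap and deserialize, as the driver-side toLocalIterator loop does.
with pa.ipc.open_stream(wrapped.column(0)[0].as_py()) as reader:
    restored = next(iter(reader))

assert restored.equals(batch)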

class SparkArrowCStreamer:
    """
    A class that implements the __arrow_c_stream__ protocol for Spark partitions.

    This class is implemented in a way that allows consumers to consume each partition
    one at a time without materializing all partitions at once on the driver side.
    """

    def __init__(self, df: pyspark.sql.DataFrame):
        self._df = df
        self._schema = to_arrow_schema(df.schema)

    def __arrow_c_stream__(self, requested_schema: Optional[object] = None) -> object:
        """
        Return the Arrow C stream for the dataframe partitions.
        """
        reader: pa.RecordBatchReader = pa.RecordBatchReader.from_batches(
            self._schema, _get_arrow_array_partition_stream(self._df)
        )
        return reader.__arrow_c_stream__(requested_schema=requested_schema)
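
For context, a minimal usage sketch: any PyCapsule-aware consumer can ingest this stream, and pyarrow itself is one such consumer via RecordBatchReader.from_stream (available in pyarrow 14+). The SparkSession and placeholder dataframe below are assumptions for illustration; the eventual public PySpark entry point may differ, but the class above can be driven directly like this.

import pyarrow as pa
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.range(1_000_000)  # placeholder dataframe; any Spark DataFrame works

# SparkArrowCStreamer exposes __arrow_c_stream__, so PyCapsule-aware consumers
# (pyarrow itself, and other libraries that accept the Arrow C stream interface)
# can ingest the partitions one at a time instead of collecting them all at once.
streamer = SparkArrowCStreamer(df)
reader = pa.RecordBatchReader.from_stream(streamer)

total_rows = 0
for record_batch in reader:
    # Batches arrive partition by partition via toLocalIterator on the driver.
    total_rows += record_batch.num_rows
print(total_rows)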