#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Iterator, Optional

import pyarrow as pa

import pyspark.sql
from pyspark.sql.types import StructType, StructField, BinaryType
from pyspark.sql.pandas.types import to_arrow_schema


def _get_arrow_array_partition_stream(df: pyspark.sql.DataFrame) -> Iterator[pa.RecordBatch]:
    """Yield all of the DataFrame's partitions as an iterator of Arrow RecordBatches."""
    # We use mapInArrow to convert each partition to Arrow RecordBatches. The
    # return type of the mapped function is a single binary column containing
    # the serialized RecordBatch in Arrow IPC format.
    binary_schema = StructType([StructField("arrow_ipc_bytes", BinaryType(), nullable=False)])
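    # Each output row of the mapped DataFrame built below carries one whole
    # serialized RecordBatch:
    #     arrow_ipc_bytes: binary   (one row per serialized RecordBatch)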

    def batch_to_bytes_iter(batch_iter: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
        """
        A generator function that converts RecordBatches to serialized Arrow IPC format.

        Spark sends each partition as an iterator of RecordBatches. In order to return
        the entire partition as a stream of Arrow RecordBatches, we need to serialize
        each RecordBatch to Arrow IPC format and yield it as a single binary blob.
        """
        # The size of each batch can be controlled by the Spark config
        # `spark.sql.execution.arrow.maxRecordsPerBatch`.
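        # For example (illustrative value only; tune it for your workload):
        #     spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "10000")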
43+for arrow_batch in batch_iter:
44+# We create an in-memory byte stream to hold the serialized batch
45+sink = pa.BufferOutputStream()
46+# Write the batch to the stream using Arrow IPC format
47+with pa.ipc.new_stream(sink, arrow_batch.schema) as writer:
48+writer.write_batch(arrow_batch)
49+buf = sink.getvalue()
50+# The second buffer contains the offsets we are manually creating.
51+offset_buf = pa.array([0, len(buf)], type=pa.int32()).buffers()[1]
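            # (Layout note: a pa.binary() array is backed by three buffers, a
            # validity bitmap, int32 offsets, and data. With offsets
            # [0, len(buf)], the single cell spans the whole IPC payload
            # without copying it.)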
            null_bitmap = None
            # Wrap the bytes in a new one-row, one-column RecordBatch to satisfy
            # the mapInArrow return signature: the whole serialized batch becomes
            # a single binary cell.
            storage_arr = pa.Array.from_buffers(
                type=pa.binary(), length=1, buffers=[null_bitmap, offset_buf, buf]
            )
            yield pa.RecordBatch.from_arrays([storage_arr], names=["arrow_ipc_bytes"])

    # Convert all partitions to Arrow RecordBatches and map them to binary blobs.
    byte_df = df.mapInArrow(batch_to_bytes_iter, binary_schema)
    # Each row is a batch of data in Arrow IPC format. Fetch the rows one by one
    # so the driver only needs to hold one partition's batches at a time.
    for row in byte_df.toLocalIterator():
        with pa.ipc.open_stream(row.arrow_ipc_bytes) as reader:
            for batch in reader:
                # Each batch corresponds to a chunk of data in the partition.
                yield batch
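
# A self-contained sketch of the IPC round trip used above (plain pyarrow, no
# Spark involved), for reference:
#
#     batch = pa.RecordBatch.from_pydict({"x": [1, 2, 3]})
#     sink = pa.BufferOutputStream()
#     with pa.ipc.new_stream(sink, batch.schema) as writer:
#         writer.write_batch(batch)
#     with pa.ipc.open_stream(sink.getvalue()) as reader:
#         assert reader.read_next_batch().equals(batch)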


class SparkArrowCStreamer:
    """
    A class that implements the __arrow_c_stream__ protocol for Spark partitions.

    This class is implemented so that consumers can read partitions one at a
    time without materializing the whole DataFrame on the driver.
    """

    def __init__(self, df: pyspark.sql.DataFrame):
        self._df = df
        # Convert the Spark schema to an Arrow schema up front so the stream
        # can advertise its schema before any data is fetched.
        self._schema = to_arrow_schema(df.schema)

    def __arrow_c_stream__(self, requested_schema: Optional[object] = None) -> object:
        """
        Return the Arrow C stream (an ArrowArrayStream PyCapsule) for the
        DataFrame partitions.
        """
        reader: pa.RecordBatchReader = pa.RecordBatchReader.from_batches(
            self._schema, _get_arrow_array_partition_stream(self._df)
        )
        return reader.__arrow_c_stream__(requested_schema=requested_schema)
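

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module API). Assuming an active
# SparkSession named `spark`, any PyCapsule-aware consumer can read the
# partitions through the stream protocol; RecordBatchReader.from_stream in
# recent pyarrow versions is one such consumer:
#
#     df = spark.range(100)
#     streamer = SparkArrowCStreamer(df)
#     reader = pa.RecordBatchReader.from_stream(streamer)
#     for batch in reader:
#         ...  # process one RecordBatch at a time
# ---------------------------------------------------------------------------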