Spark toPandas() with Arrow, a Detailed Look

The upcoming release of Apache Spark 2.3 will include Apache Arrow as a dependency. For those that do not know, Arrow is an in-memory columnar data format with APIs in Java, C++, and Python. Since Spark does a lot of data transfer between the JVM and Python, this is particularly useful and can really help optimize the performance of PySpark. In my post on the Arrow blog, I showed a basic example on how to enable Arrow for a much more efficient conversion of a Spark DataFrame to Pandas. Following that, this post will take a more detailed look at how this is done internally in Spark, why it leads to such a dramatic speedup, and what else can be improved upon in the future.

Where are the bottlenecks in Pandas Conversion?

Let’s start by looking at the simple example code that makes a Spark distributed DataFrame and then converts it to a local Pandas DataFrame without using Arrow:

from pyspark.sql.functions import rand
df = spark.range(1 << 22).toDF("id").withColumn("x", rand())
pandas_df = df.toPandas()

Running this locally on my laptop completes with a wall time of ~20.5s. The initial command spark.range() will create partitions of data in the JVM where each record is a Row consisting of a long “id” and a double “x”. The next command toPandas() will kick off the entire process on the distributed data and convert it to a pandas.DataFrame. Before any conversion happens, Spark will simply collect all the partitioned data onto the driver, which yields a huge array of Rows. This is not necessarily a slow process because the collection of Rows is optimized with compression and the data still remains in the JVM, but it raises the first big question: how is all of this collected data going to end up in my Python process?

Spark communicates to Python over sockets with serializers/deserializers at each end. The Python deserializer pyspark.serializers.PickleSerializer uses the cPickle module with the standard pickle format. On the Java side, Spark uses the Pyrolite library in org.apache.spark.api.python.SerDeUtil.AutoBatchedPickler which can serialize Java objects into the pickle format. Before this can happen though, the data must first pass through an initial conversion to massage out any incompatibilities between Scala and Java. The raw data is batched up to not overload the memory, and these batches of pickled data are then served to the Python process. It is important to note that while this is done in batches, each individual scalar value must be processed and serialized.
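The batching described above can be sketched in pure Python. This is a simplified illustration and not Spark's actual implementation: the helper names pickle_batches, write_framed and read_framed, the batch size, and the 4-byte length prefix are all assumptions made for the example, and an in-memory buffer stands in for the socket.

```python
import io
import pickle
import struct

def pickle_batches(rows, batch_size):
    """Yield one pickled payload per batch of row tuples (hypothetical helper)."""
    for start in range(0, len(rows), batch_size):
        # The pickler still has to visit every scalar in every row.
        yield pickle.dumps(rows[start:start + batch_size], protocol=2)

def write_framed(stream, payloads):
    """Write each payload prefixed with its length as a 4-byte big-endian int."""
    for payload in payloads:
        stream.write(struct.pack("!i", len(payload)))
        stream.write(payload)

def read_framed(stream):
    """Read length-prefixed payloads until the stream is exhausted."""
    while True:
        header = stream.read(4)
        if len(header) < 4:
            return
        (length,) = struct.unpack("!i", header)
        yield pickle.loads(stream.read(length))

rows = [(i, i * 0.5) for i in range(10)]  # stand-in for (id, x) records
buffer = io.BytesIO()                     # stand-in for the socket
write_framed(buffer, pickle_batches(rows, batch_size=4))
buffer.seek(0)
batches = list(read_framed(buffer))       # one list of tuples per batch
```

Batching bounds how much memory is in flight at once, but it does not reduce the per-value work done by the pickler and unpickler.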

Once Python receives a batch of pickled data, it is deserialized into a list of Python Rows, which are basically tuples of data. This is done for all batches and then the lists are concatenated together into a single huge list. This single list is fed into the function pandas.DataFrame.from_records that will then produce the final Pandas DataFrame. Since this data is in pure Python objects and Pandas data is based in Numpy arrays, it must again go through another conversion, which requires a full iteration and is also costly.
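As a rough sketch of this last step, the snippet below concatenates a few hand-written batches of tuples and passes them to pandas.DataFrame.from_records. The batch contents are made up for illustration, and plain tuples stand in for Spark Row objects.

```python
import pandas as pd

# Stand-in for the deserialized batches: lists of (id, x) tuples.
batches = [[(0, 0.1), (1, 0.2)], [(2, 0.3)], [(3, 0.4), (4, 0.5)]]

# Concatenate the per-batch lists into one list of row tuples.
all_rows = [row for batch in batches for row in batch]

# from_records walks the Python objects to build NumPy-backed columns.
pandas_df = pd.DataFrame.from_records(all_rows, columns=["id", "x"])
```

Every value here starts out as a boxed Python object, so building the NumPy-backed columns means touching each one again.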

With all that is going on in this process, it raises a number of performance questions. The Python pickle format is known to be rather slow, so how much of the time spent is simply doing SerDe? When such a large volume of data is being transferred, is there any way to work with large chunks at a time instead of individual scalars? Is there a more efficient way to produce a Pandas DataFrame? To help answer these questions, let’s first look at a profile of the Python driver process, sorted by cumulative time showing the top time-consuming calls:

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000   23.013   23.013 <string>:1(<module>)
        1    0.456    0.456   23.013   23.013 dataframe.py:1712(toPandas)
        1    0.092    0.092   21.222   21.222 dataframe.py:439(collect)
       81    0.000    0.000   20.195    0.249 serializers.py:141(load_stream)
       81    0.001    0.000   20.194    0.249 serializers.py:160(_read_with_length)
       80    0.000    0.000   20.167    0.252 serializers.py:470(loads)
       80    3.280    0.041   20.167    0.252 {cPickle.loads}
  4194304    1.024    0.000   16.295    0.000 types.py:1532(<lambda>)
  4194304    2.048    0.000   15.270    0.000 types.py:610(fromInternal)
  4194304    9.956    0.000   12.552    0.000 types.py:1535(_create_row)
  4194304    1.105    0.000    1.807    0.000 types.py:1583(__new__)
        1    0.000    0.000    1.335    1.335 frame.py:969(from_records)
        1    0.047    0.047    1.321    1.321 frame.py:5693(_to_arrays)
        1    0.000    0.000    1.274    1.274 frame.py:5787(_list_to_arrays)
      165    0.958    0.006    0.958    0.006 {method 'recv' of '_socket.socket' objects}
        4    0.000    0.000    0.935    0.234 java_gateway.py:1150(__call__)
        4    0.000    0.000    0.935    0.234 java_gateway.py:885(send_command)
        4    0.000    0.000    0.935    0.234 java_gateway.py:1033(send_command)
        4    0.000    0.000    0.934    0.234 socket.py:410(readline)
  4194304    0.789    0.000    0.789    0.000 types.py:1667(__setattr__)
        1    0.000    0.000    0.759    0.759 frame.py:5846(_convert_object_array)
        2    0.000    0.000    0.759    0.380 frame.py:5856(convert)
        2    0.759    0.380    0.759    0.380 {pandas._libs.lib.maybe_convert_objects}
  4194308    0.702    0.000    0.702    0.000 {built-in method __new__ of type object at 0x7fa547e394c0}
  4195416    0.671    0.000    0.671    0.000 {isinstance}
  4194304    0.586    0.000    0.586    0.000 types.py:1531(_create_row_inbound_converter)
        1    0.515    0.515    0.515    0.515 {pandas._libs.lib.to_object_array_tuples}

This shows that a large part of the time is spent in the Python deserializer. There are 80 calls to serializers.py:470(loads), which correspond to 80 batches of pickled data. Most of the time goes to the creation of Python Rows at types.py:1535(_create_row), which is invoked by the Row.__reduce__ method 4,194,304 times as each row is deserialized. The remaining time is made up of bringing the data into a Pandas DataFrame and some low-level IO. Now that we know what is causing this poor performance, let’s see what Arrow can do to improve it.
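A profile in this style can be collected with the standard cProfile and pstats modules. The sketch below is one possible way to do it, not necessarily how the listing above was produced; the workload function is a hypothetical stand-in so that the snippet runs without a Spark session.

```python
import cProfile
import io
import pstats

def workload():
    """Stand-in for df.toPandas(); builds many small tuples."""
    return [(i, float(i)) for i in range(100000)]

profiler = cProfile.Profile()
profiler.enable()
result = workload()
profiler.disable()

# Sort by cumulative time and keep only the most expensive calls.
report = io.StringIO()
stats = pstats.Stats(profiler, stream=report)
stats.sort_stats("cumulative").print_stats(20)
profile_text = report.getvalue()
```

In a PySpark shell, the same idea applies with df.toPandas() in place of workload(), for example via cProfile.run("df.toPandas()", sort="cumulative").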

Using Arrow to Optimize Conversion

Because Arrow defines a common data format across different language implementations, it is possible to transfer data from Java to Python without any conversions or processing. This means that a Spark DataFrame, which resides in the JVM, can be easily made into Arrow data in Java and then sent as a whole to Python where it is directly consumed. This eliminates the need for any of the costly serialization we saw before and allows transferring of large chunks of data at a time. To make matters sweeter, the Python implementation of Arrow, pyarrow, has built-in conversions for Pandas that will efficiently create a Numpy based DataFrame from the Arrow data, utilizing zero-copy methods when possible. Let’s take a look at the profile for the same conversion process with Arrow. First, enable Arrow in Spark with SQLConf using the command:

spark.conf.set("spark.sql.execution.arrow.enabled", "true")

Or by adding “spark.sql.execution.arrow.enabled=true” to your Spark configuration at conf/spark-defaults.conf

Now taking the profile from the same example code above:

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.001    0.001    0.457    0.457 <string>:1(<module>)
        1    0.000    0.000    0.456    0.456 dataframe.py:1712(toPandas)
        1    0.000    0.000    0.442    0.442 dataframe.py:1754(_collectAsArrow)
       53    0.404    0.008    0.404    0.008 {method 'recv' of '_socket.socket' objects}
        4    0.000    0.000    0.389    0.097 java_gateway.py:1150(__call__)
        4    0.000    0.000    0.389    0.097 java_gateway.py:885(send_command)
        4    0.000    0.000    0.389    0.097 java_gateway.py:1033(send_command)
        4    0.000    0.000    0.389    0.097 socket.py:410(readline)
        9    0.000    0.000    0.053    0.006 serializers.py:141(load_stream)
        9    0.000    0.000    0.053    0.006 serializers.py:160(_read_with_length)
       17    0.001    0.000    0.052    0.003 socket.py:340(read)
       48    0.022    0.000    0.022    0.000 {method 'write' of 'cStringIO.StringO' objects}
       13    0.014    0.001    0.014    0.001 {method 'getvalue' of 'cStringIO.StringO' objects}
        1    0.000    0.000    0.013    0.013 {method 'to_pandas' of 'pyarrow.lib.Table' objects}
        1    0.000    0.000    0.013    0.013 pandas_compat.py:107(table_to_blockmanager)
        1    0.013    0.013    0.013    0.013 {pyarrow.lib.table_to_blocks}

This now completes on my laptop with a wall time of 692ms, which is far more reasonable! With the serialization and per-row processing out of the way, the remaining time is mostly IO. Much of what we see here under java_gateway.py:885(send_command) comes from Spark’s use of Py4J and is not the data transfer to pyarrow itself. The process works slightly differently than it does without Arrow. Since we are creating Arrow data in Java, that conversion can be pushed down into the executors to be done in parallel, and then the Arrow data is collected instead of Rows. Once the Arrow data is received by the Python driver process, it is concatenated into one pyarrow.Table, although memory is not copied: the batches are just appended as chunks. From this, pyarrow outputs a single Pandas DataFrame. Getting back to the Py4J commands, they do not take long themselves; the time is spent waiting for the Arrow data to be produced in the JVM. Even though this is a drastic speedup over the previous approach, there is still room for improvement.

Possible Future Improvements

The majority of the time spent creating Arrow data in Java is because Spark internally stores data in row form, which must be converted to column form by iterating over each row. There has been some discussion about adding a columnar engine to Spark in SPARK-15687, which would be the best way to improve this, and I have started to add support for using columnar batches with Arrow in SPARK-21583.
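To make the cost of that row-to-column step concrete, here is a small pure-Python sketch of the transposition. It is only an analogy for what happens in the JVM: rows_to_columns is a hypothetical helper, and array.array stands in for Arrow's column buffers.

```python
from array import array

def rows_to_columns(rows):
    """Transpose (id, x) row tuples into two typed columns."""
    ids = array("q")  # signed 64-bit integers
    xs = array("d")   # 64-bit floats
    for row_id, x in rows:  # one loop iteration per row
        ids.append(row_id)
        xs.append(x)
    return ids, xs

rows = [(0, 0.25), (1, 0.5), (2, 0.75)]
ids, xs = rows_to_columns(rows)
```

If the data were already stored column by column, this loop would not be needed at all, which is why a columnar engine would help.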

Besides speedups, memory usage could also be improved. The current conversion to Arrow data requires the data to be copied to column form, then also written to a temporary buffer to be collected over a socket. The use of the streaming Arrow format could remove the need for a temporary buffer, as Arrow data could be written right to the socket as it is ready.

Currently, I am working on two additional uses of Arrow for PySpark users. Using vectorized Python UDFs with Arrow to transfer data will also give a performance boost when using high-level SparkSQL with some custom Python code. Keep a watch on SPARK-21404 for this. Creating a Spark DataFrame from a Pandas DataFrame (the opposite direction of toPandas()) actually goes through even more conversions and bottlenecks, if you can believe it. Using Arrow for this is being worked on in SPARK-20791 and should give similar performance improvements and make for a very efficient round-trip with Pandas.
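The difference between a row-at-a-time function and a vectorized one can be shown with Pandas alone. The snippet below is only an illustration of the concept and does not use the Spark UDF API; both function names are hypothetical.

```python
import pandas as pd

def plus_one_scalar(value):
    """Row-at-a-time style: called once per value."""
    return value + 1.0

def plus_one_vectorized(series):
    """Vectorized style: called once per batch, operating on a whole Series."""
    return series + 1.0

x = pd.Series([0.1, 0.2, 0.3])
row_at_a_time = pd.Series([plus_one_scalar(v) for v in x])
vectorized = plus_one_vectorized(x)
```

Both produce the same values, but the vectorized version makes one Python call per batch instead of one per row, and that is the kind of batch that Arrow can deliver efficiently.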

Tuning and Usage Notes

When Arrow is enabled, the batch size of the Arrow data produced is limited to 10,000 Spark rows, which can be very conservative depending on usage. This is also controlled by a SQLConf and can be set to unlimited with “spark.sql.execution.arrow.maxRecordsPerBatch=0”. This should be done only if you know your memory limits well, as you can easily exceed the JVM memory (see the memory usage discussion under Possible Future Improvements).
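The effect of this limit can be pictured with a small pure-Python sketch. The helper iter_batches is hypothetical and only mimics the described behavior, where a limit of 0 means that all rows go into a single batch; it is not Spark's implementation.

```python
def iter_batches(rows, max_records_per_batch):
    """Yield lists of rows; a limit of 0 or less yields one unbounded batch."""
    if max_records_per_batch <= 0:
        yield list(rows)
        return
    batch = []
    for row in rows:
        batch.append(row)
        if len(batch) == max_records_per_batch:
            yield batch
            batch = []
    if batch:  # final partial batch
        yield batch

rows = list(range(25000))
default_sizes = [len(b) for b in iter_batches(rows, 10000)]
unlimited_sizes = [len(b) for b in iter_batches(rows, 0)]
```

With the limit in place, memory use per batch is bounded; with it removed, a single batch has to hold everything at once.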

Currently, not every Spark data type is supported in the toPandas() conversion. Full type support is underway and I hope to help have it completed by the time Spark 2.3 is released. This is being tracked by the umbrella JIRA SPARK-21187. Decimal, Array, Map, and Struct types are not yet supported; Date and Timestamp support has since been added, as noted in the update below.

Updated November 8, 2017 to reflect config change and added support for date/timestamp types

Written on August 11, 2017