Getting Info About Spark Partitions

Getting information about Spark partitions is often essential when tuning performance. All the samples are in Python.

Partition Count

Getting the number of partitions of a DataFrame is easy, but the relevant method is not a member of the DataFrame class itself, so you need to go through .rdd:



In: rbe_s1.rdd.getNumPartitions()
Out: 13

But how do I learn a bit more about partitions, i.e. at least their sizes and how the data was partitioned? One could write the data to disk and inspect the files, but that's tedious and often won't work in high-security environments.

Partition Sizes

Getting partition sizes is also not obvious, and there is no built-in function to do it. Again, one can use the low-level RDD API, for instance .mapPartitions, which is defined as follows:

def mapPartitions(self, f, preservesPartitioning=False):
    """
    Return a new RDD by applying a function to each partition of this RDD.

    >>> rdd = sc.parallelize([1, 2, 3, 4], 2)
    >>> def f(iterator): yield sum(iterator)
    >>> rdd.mapPartitions(f).collect()
    [3, 7]
    """
    def func(s, iterator):
        return f(iterator)
    return self.mapPartitionsWithIndex(func, preservesPartitioning)

The example in the docstring already gives it away. Getting a list of partition sizes:

# get_partition_len yields the row count of a single partition (see the helper below)
lengths = rdd.mapPartitions(get_partition_len, True).collect()
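If the mechanics of mapPartitions feel opaque, a toy model may help. This is plain Python, not Spark: a "partition" is just a list, and model_map_partitions (a hypothetical helper) hands each partition's iterator to f and concatenates whatever f yields. Since get_partition_len yields exactly one value per partition, you get one number per partition back. glom_len is a rough stand-in for what rdd.glom().map(len) does, materializing the whole partition as a list first.

```python
from typing import Callable, Iterable, Iterator, List, TypeVar

T = TypeVar("T")
U = TypeVar("U")


def model_map_partitions(
    partitions: List[List[T]], f: Callable[[Iterator[T]], Iterable[U]]
) -> List[U]:
    """Toy model of RDD.mapPartitions(f).collect(): apply f to each
    partition's iterator and concatenate the results in partition order."""
    out: List[U] = []
    for part in partitions:
        out.extend(f(iter(part)))
    return out


def get_partition_len(iterator: Iterator[T]) -> Iterator[int]:
    # Counts rows one at a time; the partition is never held as a list.
    yield sum(1 for _ in iterator)


def glom_len(iterator: Iterator[T]) -> Iterator[int]:
    # Roughly what glom().map(len) does: build the full partition list first.
    yield len(list(iterator))


partitions = [[1, 2], [3, 4, 5], []]
print(model_map_partitions(partitions, get_partition_len))  # [2, 3, 0]
print(model_map_partitions(partitions, glom_len))           # [2, 3, 0]
```

Both give the same answer; the difference is only in how much of a partition has to sit in memory at once, which is why the helper below sticks with the counting generator. Note that an empty partition still reports 0 rather than disappearing.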

Utility Function

Putting it all together, here is a helper function that displays basic DataFrame statistics:

import statistics

from pyspark import RDD
from pyspark.sql import DataFrame


def print_partition_info(df: DataFrame):
    """Print the partition count and per-partition row counts of a DataFrame."""

    def get_partition_len(iterator):
        # Count rows without materializing the partition in memory.
        yield sum(1 for _ in iterator)

    rdd: RDD = df.rdd

    count = rdd.getNumPartitions()
    # lengths = rdd.glom().map(len).collect()  # much more memory hungry than the next line
    lengths = rdd.mapPartitions(get_partition_len, True).collect()

    print(f"{count} partition(s) total.")

    print("size stats")
    print(f"     min: {min(lengths)}")
    print(f"     max: {max(lengths)}")
    print(f"     avg: {sum(lengths)/len(lengths)}")
    # statistics.stdev raises StatisticsError for fewer than two values.
    stdev = statistics.stdev(lengths) if len(lengths) > 1 else 0.0
    print(f"  stddev: {stdev}")

    print("detailed info")
    for i, pl in enumerate(lengths):
        print(f"  {i}. {pl}")
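If you'd rather log or assert on these numbers than print them, the statistics can be split out into a pure function over the collected lengths, which also makes them testable without a cluster. summarize_partition_lengths is a hypothetical name; the sketch assumes you've already collected lengths as above.

```python
import statistics
from typing import Dict, List


def summarize_partition_lengths(lengths: List[int]) -> Dict[str, float]:
    """The statistics print_partition_info prints, returned as a dict.

    `lengths` is what rdd.mapPartitions(get_partition_len, True).collect()
    returns: one row count per partition.
    """
    if not lengths:
        raise ValueError("need at least one partition length")
    return {
        "count": len(lengths),
        "min": min(lengths),
        "max": max(lengths),
        "avg": sum(lengths) / len(lengths),
        # statistics.stdev needs at least two values; use 0.0 for one partition.
        "stddev": statistics.stdev(lengths) if len(lengths) > 1 else 0.0,
    }


stats = summarize_partition_lengths([4403, 1914, 38, 19, 13])
print(stats["min"], stats["max"], stats["avg"])  # 13 4403 1277.4
```

Fed with the partition sizes from the sample output, it reproduces the same min, max, avg and stddev.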

Sample output:

5 partition(s) total.
size stats
     min: 13
     max: 4403
     avg: 1277.4
  stddev: 1929.5741239973136
detailed info
  0. 4403
  1. 1914
  2. 38
  3. 19
  4. 13

As you can see, partition 0 has most of the data, so it's definitely going to screw things up, if it isn't doing so already.
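To put a number on "most of the data", two simple indicators are the largest partition's share of all rows and its size relative to the mean partition. skew_report is a hypothetical helper and there's no standard threshold for either number; it's just a quick way to compare before and after repartitioning.

```python
from typing import List


def skew_report(lengths: List[int]) -> dict:
    """Two simple skew indicators for per-partition row counts.
    Assumes at least one non-empty partition (total > 0)."""
    total = sum(lengths)
    mean = total / len(lengths)
    largest = max(lengths)
    return {
        "largest_partition": lengths.index(largest),
        # Fraction of all rows sitting in the biggest partition.
        "largest_share": largest / total,
        # How many times bigger it is than an evenly sized partition would be.
        "max_to_mean": largest / mean,
    }


report = skew_report([4403, 1914, 38, 19, 13])
print(report["largest_partition"])        # 0
print(round(report["largest_share"], 3))  # 0.689
print(round(report["max_to_mean"], 2))    # 3.45
```

For the sample above, partition 0 holds about 69% of the rows and is roughly 3.4 times the size of an average partition.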

Repartitioning Data

Now that it's (hopefully) clear which partitions are the bad boys, you might want to repartition the DataFrame. This part is more obvious compared to the previous ones. There are basically two methods on the DataFrame itself: coalesce and repartition. Their documentation is very similar, and it's really confusing which one to use when.

Coalesce Function

Coalesce changes DataFrame partitioning, but it doesn't quite do what it says on the tin. It does not physically redistribute the data; it only reduces the number of partitions by merging existing ones, so some partitions effectively claim ownership of others until the requested number of partitions is reached. For instance, if you coalesce the DataFrame with partitions

  0. 4403
  1. 1914
  2. 38
  3. 19
  4. 13

to 2 partitions (.coalesce(2)) you will get:

  0. 4454
  1. 1933

So yeah, you did get 2 partitions, but that didn't make much of a difference to performance, as one partition is still more than twice the size of the other. Note that coalesce does not shuffle data: no rows are redistributed across the cluster, existing partitions are simply read together. That's really useful when you don't want the data to be moved but still want fewer partitions (and therefore fewer tasks), with each merged group processed sequentially.
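Here's a toy model of a no-shuffle coalesce to show why the skew survives: whole input partitions are assigned to output partitions and never split, so every output size is just the sum of the inputs it absorbed. The grouping rule used here (contiguous index ranges) is an assumption for illustration only; Spark's actual coalescer also takes data locality into account, so real groupings can differ, as they do in the 4454/1933 result above. model_coalesce is a hypothetical name.

```python
from typing import List


def model_coalesce(lengths: List[int], num_partitions: int) -> List[List[int]]:
    """Toy no-shuffle coalesce: for each output partition, return the indices
    of the input partitions it absorbs. Inputs are grouped into contiguous
    index ranges (an illustrative assumption, not Spark's exact rule)."""
    if num_partitions < 1:
        raise ValueError("num_partitions must be at least 1")
    n = len(lengths)
    target = min(num_partitions, n)  # coalesce can only reduce the count
    groups = []
    for i in range(target):
        start = i * n // target
        end = (i + 1) * n // target
        groups.append(list(range(start, end)))
    return groups


lengths = [4403, 1914, 38, 19, 13]
groups = model_coalesce(lengths, 2)
sizes = [sum(lengths[j] for j in g) for g in groups]
print(groups)  # [[0, 1], [2, 3, 4]]
print(sizes)   # [6317, 70]
```

Whatever grouping is chosen, the 4403-row partition can't be split without a shuffle, so it always ends up whole inside one of the output partitions.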

Repartition Function

This function does repartition the data: it performs a full shuffle, physically moving rows between nodes. Calling repartition(2) on the DataFrame above results in the following:

  0. 3198
  1. 3195

As you can see, partitions are almost equal in size!
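Why almost equal? As far as I know, repartition(n) without columns uses round-robin partitioning: each input partition deals its rows out to the n output partitions in turn, starting at a pseudo-random offset. The toy model below (model_round_robin is a hypothetical name, and the offset choice is an assumption) shows that each input partition then contributes at most one extra row to any output, so sizes can differ by a few rows but never by much.

```python
import random
from typing import List


def model_round_robin(lengths: List[int], num_partitions: int, seed: int = 0) -> List[int]:
    """Toy shuffle: each input partition deals its rows to the output
    partitions in turn, starting at a pseudo-random offset.
    Returns the resulting output partition sizes."""
    rng = random.Random(seed)
    sizes = [0] * num_partitions
    for rows in lengths:
        position = rng.randrange(num_partitions)  # per-input start offset
        for _ in range(rows):
            sizes[position] += 1
            position = (position + 1) % num_partitions
    return sizes


lengths = [4403, 1914, 38, 19, 13]
sizes = model_round_robin(lengths, 2)
print(sizes, sum(sizes))
```

In this model the gap between the largest and smallest output partition is bounded by the number of input partitions, regardless of how skewed the inputs were.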

You can also pass columns as arguments, in which case they are used to compute the hash that determines each row's partition. This is useful when your computation works on particular columns and benefits from rows with equal values ending up in the same partition. The function is defined as follows:

    def repartition(self, numPartitions, *cols):

and cols can be either column names or Column objects, meaning you can pass a plain name or any expression that results in a column. This is particularly important, because by supplying an expression you are essentially defining what gets hashed, i.e. your own partitioning key. So it's not limited to just a dumb column name.
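To see what passing a column or an expression buys you, here's a toy hash partitioner: the partition index is hash(key) mod n, so rows with equal keys always land in the same partition. It's a sketch, not Spark's code: Spark uses a Murmur3 hash rather than the crc32 used here, so actual partition indices will differ. key_fn stands in for the column or expression you'd pass to repartition, and the rows are made-up sample data.

```python
import zlib
from collections import defaultdict
from typing import Callable, Dict, List


def model_hash_partition(
    rows: List[dict], num_partitions: int, key_fn: Callable[[dict], str]
) -> Dict[int, List[dict]]:
    """Toy hash partitioning: partition = hash(key) mod num_partitions.
    crc32 is used only because it is deterministic (Python's str hash is salted)."""
    parts: Dict[int, List[dict]] = defaultdict(list)
    for row in rows:
        key = key_fn(row)
        parts[zlib.crc32(key.encode("utf-8")) % num_partitions].append(row)
    return dict(parts)


rows = [
    {"customer": "a", "day": "2020-01-03"},
    {"customer": "b", "day": "2020-01-17"},
    {"customer": "a", "day": "2020-02-01"},
    {"customer": "c", "day": "2020-02-09"},
]

# Like repartition(n, "customer"): equal customers share a partition.
by_customer = model_hash_partition(rows, 4, lambda r: r["customer"])

# Like repartition(n, <expression>): the key is the month prefix of `day`.
by_month = model_hash_partition(rows, 4, lambda r: r["day"][:7])
```

The second call is the interesting one: the key is derived from a column rather than being the column itself, which is exactly what passing an expression gives you.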
