Salta el contingut

Agregacions a Spark

Introducció

Les agregacions són una de les operacions més freqüents en pipelines de dades: calcular totals, mitjanes, recomptes i estadístiques per grups. Spark ofereix una API molt rica per a agregacions, des de les funcions integrades fins a agregadors definits per l'usuari (UDAF).


groupBy i agg: l'API principal

L'operació groupBy seguida de agg és la forma estàndard d'agregar en Spark:

from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    col, sum, count, avg, min, max,
    stddev, variance, collect_list, collect_set,
    first, last, countDistinct, approx_count_distinct
)

spark = SparkSession.builder \
    .appName("Agregacions") \
    .master("local[*]") \
    .getOrCreate()

df = spark.read \
    .option("header", "true") \
    .option("inferSchema", "true") \
    .csv("/dades/vendes_2024.csv")

# Agregació bàsica per categoria
resum_per_categoria = df \
    .groupBy("categoria") \
    .agg(
        count("id_comanda").alias("num_comandes"),
        sum("import_total").alias("facturacio"),
        avg("import_total").alias("ticket_mig"),
        min("import_total").alias("import_minim"),
        max("import_total").alias("import_maxim"),
        stddev("import_total").alias("desviacio_tipica"),
        countDistinct("id_client").alias("clients_unics")
    ) \
    .orderBy(col("facturacio").desc())

resum_per_categoria.show()

Agregació multi-nivell

# Agregar per múltiples columnes simultàniament
vendes_mensuals = df \
    .groupBy("any_comanda", "mes_comanda", "categoria") \
    .agg(
        count("id_comanda").alias("num_comandes"),
        sum("quantitat").alias("unitats_venudes"),
        sum("import_total").alias("facturacio"),
        avg("preu_unitari").alias("preu_mig")
    ) \
    .orderBy("any_comanda", "mes_comanda", col("facturacio").desc())

vendes_mensuals.show(20)

Funcions d'agregació avançades

collect_list i collect_set

Permeten recollir tots els valors d'una columna agrupada en una llista o conjunt:

# collect_list: tots els productes per client (inclou duplicats)
comandes_per_client = df \
    .groupBy("id_client") \
    .agg(
        collect_list("producte").alias("productes_comprats"),
        collect_set("categoria").alias("categories_unicas")
    )

comandes_per_client.show(5, truncate=False)
# +----------+------------------------------------+------------------+
# |id_client |productes_comprats                  |categories_unicas |
# +----------+------------------------------------+------------------+
# |1001      |[Monitor, Teclat, Monitor, Ratolí]  |[Perifèrics]      |
# |1002      |[Portàtil, Auriculars]              |[Electrònica, ...]|

approx_count_distinct

Per a cardinalitats molt grans, approx_count_distinct és molt més eficient que countDistinct:

# Recompte aproximat (molt més ràpid per a datasets grans)
df.groupBy("data_comanda") \
    .agg(
        approx_count_distinct("id_client", rsd=0.05).alias("clients_aprox")
    ).show()

Pivot: de files a columnes

L'operació pivot transforma valors únics d'una columna en noves columnes:

# Vendes per categoria i trimestre (pivot)
df_trimesters = df \
    .withColumn("trimestre",
        (((col("mes_comanda") - 1) / 3).cast("int") + 1).cast("string")
    ) \
    .withColumn("trimestre_str",
        col("trimestre").cast("string").substr(1, 1).alias("Q")
    )

pivot_vendes = df \
    .groupBy("categoria") \
    .pivot("mes_comanda", [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) \
    .agg(sum("import_total")) \
    .fillna(0)

pivot_vendes.show(truncate=False)
# +------------+----+----+----+----+...
# |categoria   | 1  | 2  | 3  | 4  |...
# +------------+----+----+----+----+...
# |Electrònica |2500|3100|...

Especifica els valors del pivot

Sempre especifica els valors del pivot explícitament (com es fa al codi anterior). Si no, Spark fa una passada extra per descobrir-los, cosa que és molt ineficient en datasets grans.


rollup i cube: agregació jeràrquica

rollup i cube generen subtotals automàticament per a anàlisi multidimensional (OLAP):

from pyspark.sql.functions import grouping_id

# rollup: subtotals jeràrquics (any → mes → categoria)
subtotals = df \
    .rollup("any_comanda", "mes_comanda", "categoria") \
    .agg(sum("import_total").alias("total")) \
    .orderBy("any_comanda", "mes_comanda", "categoria")

subtotals.show(30)
# any=2024, mes=1, cat=Electrònica → subtotal d'Electrònica al gener
# any=2024, mes=1, cat=null       → total del mes de gener
# any=2024, mes=null, cat=null     → total de l'any 2024
# any=null, mes=null, cat=null     → total global

# cube: totes les combinacions de subtotals possibles
tots_els_subtotals = df \
    .cube("any_comanda", "categoria") \
    .agg(sum("import_total").alias("total")) \
    .orderBy("any_comanda", "categoria")

tots_els_subtotals.show()

Agregacions amb condicions (sum condicional)

from pyspark.sql.functions import when, sum as spark_sum

# Calcular imports per segment de preu en la mateixa passada
resum = df \
    .groupBy("categoria") \
    .agg(
        spark_sum(when(col("segment_preu") == "alt",  col("import_total")).otherwise(0)).alias("import_alt"),
        spark_sum(when(col("segment_preu") == "mig",  col("import_total")).otherwise(0)).alias("import_mig"),
        spark_sum(when(col("segment_preu") == "baix", col("import_total")).otherwise(0)).alias("import_baix"),
        spark_sum("import_total").alias("import_total")
    )

resum.show()

User-Defined Aggregate Functions (UDAF)

Per a lògica d'agregació personalitzada que no es pot expressar amb les funcions integrades:

from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import DoubleType
import pandas as pd

# UDAF amb Pandas UDF (recomanat a Spark 3.x — molt més eficient que les UDAF clàssiques)
@pandas_udf(DoubleType())
def coeficient_variacio(series: pd.Series) -> float:
    """Calcula el coeficient de variació (stddev / mean * 100)"""
    if series.mean() == 0:
        return 0.0
    return float(series.std() / series.mean() * 100)

# Usar-lo en una agregació
variabilitat = df \
    .groupBy("categoria") \
    .agg(
        coeficient_variacio("import_total").alias("cv_import")
    )

variabilitat.show()

Miniactivitat: Anàlisi de vendes multidimensional

Carrega un dataset de vendes (pots crear-lo amb spark.createDataFrame) amb les columnes: data, regio, producte, import.

  1. Calcula la facturació total per regio i producte.
  2. Fes un pivot per mostrar la facturació mensual per regio (les regions com a files, els mesos com a columnes).
  3. Usa rollup per obtenir subtotals per regio i totals globals.
  4. Identifica quina regio té el coeficient de variació més alt en les seves vendes.