2
votes

Comment obtenir des lots de lignes de Spark à l'aide de pyspark

J'ai un RDD Spark de plus de 6 milliards de lignes de données que je souhaite utiliser pour entraîner un modèle d'apprentissage en profondeur, à l'aide de train_on_batch. Je ne peux pas mettre toutes les lignes en mémoire, donc je voudrais obtenir environ 10K à la fois pour les regrouper en morceaux de 64 ou 128 (selon la taille du modèle). J'utilise actuellement rdd.sample () mais je ne pense pas que cela garantit que j'obtiendrai toutes les lignes. Existe-t-il une meilleure méthode pour partitionner les données pour les rendre plus gérables afin que je puisse écrire une fonction de générateur pour obtenir des lots? Mon code est ci-dessous:

data_df = spark.read.parquet(PARQUET_FILE)
print(f'RDD Count: {data_df.count()}') # 6B+
data_sample = data_df.sample(True, 0.0000015).take(6400) 
sample_df = data_sample.toPandas()

def get_batch():
  for row in sample_df.itertuples():
    # TODO: put together a batch size of BATCH_SIZE
    yield row

for i in range(10):
    print(next(get_batch()))


4 commentaires

Juste un avertissement, votre dénomination est erronée, ce n'est pas un rdd qui est un dataframe.


De plus, je ne pense pas que vous n'ayez rien à gagner en itérant sur Spark en utilisant des pandas, ce serait mieux si vous lisiez ceci en python en morceaux.


J'ai changé la dénomination pour refléter qu'il s'agit d'un dataframe. J'ai besoin des données en tant que dataframe Pandas pour alimenter un modèle existant.


Je voudrais demander pourquoi la question a été rejetée. J'ai passé beaucoup de temps à chercher sur Google comment faire cela. Si la réponse est évidente ou s'il existe un mécanisme documenté / bien compris, j'apprécierais un lien et je m'excuserai pour la perte de temps.


3 Réponses :


1
votes

Je ne crois pas que Spark vous permette de compenser ou de paginer vos données.

Mais vous pouvez ajouter un index, puis paginer dessus, d'abord:

    from pyspark.sql.functions import lit
    data_df = spark.read.parquet(PARQUET_FILE)
    count = data_df.count()
    chunk_size = 10000

    # Just adding a column for the ids
    df_new_schema = data_df.withColumn('pres_id', lit(1))

    # Adding the ids to the rdd 
    rdd_with_index = data_df.rdd.zipWithIndex().map(lambda (row,rowId): (list(row) + [rowId+1]))

    # Creating a dataframe with index
    df_with_index = spark.createDataFrame(chunk_rdd,schema=df_new_schema.schema)

    # Iterating into the chunks
    for chunk_size in range(0,count+1 ,chunk_size):
        initial_page = page_num*chunk_size
        final_page = initial_page + chunk_size 
        where_query = ('pres_id > {0} and pres_id <= {1}').format(initial_page,final_page)
        chunk_df = df_with_index.where(where_query).toPandas()
        train_on_batch(chunk_df) # <== Your function here        

Ce n'est pas optimal, cela exploitera mal l'étincelle en raison de l'utilisation d'un dataframe pandas mais résoudra votre problème.

N'oubliez pas de supprimer l'identifiant si cela affecte votre fonction.


0 commentaires

0
votes

Essaye ça:

 from pyspark.sql import functions as F
 sample_dict = {}

 # Read the parquet file
 df = spark.read.parquet("parquet file")

 # add the partition_number as a column
 df = df.withColumn('partition_num', F.spark_partition_id())
 df.persist()

 total_partition = [int(row.partition_num) for row in 
 df.select('partition_num').distinct().collect()]

 for each_df in total_partition:
     sample_dict[each_df] = df.where(df.partition_num == each_df) 


0 commentaires

0
votes

Je vois que vous prévoyez de former un modèle de Deep Learning. Jetez un œil à la bibliothèque open-source Petastorm qui a été créée exactement pour ce cas d'utilisation.

https://docs.databricks.com/applications/machine-learning/load-data/petastorm.html

Petastorm est une bibliothèque d'accès aux données open source. Cette bibliothèque permet un entraînement et une évaluation à nœud unique ou distribué de modèles d'apprentissage en profondeur directement à partir d'ensembles de données au format Apache Parquet et d'ensembles de données déjà chargés en tant que DataFrames Apache Spark. Petastorm prend en charge les frameworks d'apprentissage automatique (ML) basés sur Python tels que Tensorflow, PyTorch et PySpark. Pour plus d'informations sur Petastorm, reportez-vous à la page Petastorm GitHub et à la documentation de l'API Petastorm .


0 commentaires