I was testing compute_metrics to better understand its behavior. My initial understanding was that each GPU processes a subset of the validation dataset and calls the function on its own shard. However, after testing, I noticed that each GPU receives the full set of predictions and labels for the entire dataset. Could you please explain why this happens?
GPU Count: 4
Sample Count: 64
Batch Size: 8
Epochs: 2
Launch command: python -m torch.distributed.launch --nproc_per_node=4 file.py
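To make the question concrete, here is a minimal sketch of the gather pattern that would produce what I am seeing (my own illustration using torch.distributed.all_gather, not the Trainer's actual internals): each rank evaluates only its 16-sample shard, but the shards are all-gathered afterwards, so every rank ends up holding all 64 predictions before compute_metrics is called.

# Minimal sketch of the all-gather pattern (my illustration, not Trainer code).
# Assumes an initialized process group and equal-sized shards per rank.
import torch
import torch.distributed as dist

def gather_shards(local_preds: torch.Tensor) -> torch.Tensor:
    world_size = dist.get_world_size()
    # One receive buffer per rank.
    buffers = [torch.empty_like(local_preds) for _ in range(world_size)]
    dist.all_gather(buffers, local_preds)  # every rank receives every shard
    # 16 local samples x 4 ranks -> 64 samples on each rank.
    return torch.cat(buffers, dim=0)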
Code:
train_dataset = train_dataset.select(range(64))
val_dataset = val_dataset.select(range(64))

def compute_metrics(eval_pred):
    predictions, label_ids = eval_pred
    pred_seq = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace the -100 padding used for the loss with the tokenizer's pad token id.
    labels = np.where(label_ids != -100, label_ids, tokenizer.pad_token_id)
    label_seq = tokenizer.batch_decode(labels, skip_special_tokens=True)
    print(len(label_seq), " Rank ", args.local_rank)
    scores = [
        get_global_alignment_score(t, p, aligner) for t, p in zip(label_seq, pred_seq)
    ]
    print("score length", len(scores))
    avg_score = sum(scores) / len(scores)
    print(avg_score, " avg score")
    return {"GAS": avg_score}
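# Side note (my own addition, not in the original script): every rank runs
# compute_metrics on the full gathered set, so the same metrics are printed
# once per GPU. A small guard keeps the logging to the main process only
# (local_rank is 0 on the main process, -1 when not running distributed):
def is_main_process(local_rank: int) -> bool:
    return local_rank in (-1, 0)
# e.g. inside compute_metrics:
#     if is_main_process(args.local_rank):
#         print(len(label_seq), " Rank ", args.local_rank)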
training_args = Seq2SeqTrainingArguments(
    output_dir=f"./finetuning/model/{args.lr}{args.mode}_{args.ext}",
    predict_with_generate=True,
    num_train_epochs=args.e,              # Number of epochs (epoch-based evaluation)
    per_device_train_batch_size=args.batch_size,
    per_device_eval_batch_size=args.batch_size,
    weight_decay=0.01,
    eval_strategy="epoch",                # Evaluate after each epoch
    save_strategy="epoch",                # Save a checkpoint after each epoch
    load_best_model_at_end=True,          # Load the best model after training
    metric_for_best_model="GAS",          # Metric that decides the best model
    save_total_limit=2,                   # Keep only 2 checkpoints
    report_to=[],
    gradient_checkpointing=True,
    bf16=True,
    logging_strategy="epoch",
    gradient_accumulation_steps=2,
    greater_is_better=True,
)
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    data_collator=collator,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=10)],
)
Observations:
print(len(label_seq), " Rank ", args.local_rank) prints "64 Rank 0", i.e. a single rank sees the full 64-sample validation set.
get_global_alignment_score executes 512 times in total over the 2 epochs, which matches 64 samples × 4 GPUs × 2 epochs.
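One sanity check I find useful (my own addition, assuming the standard Trainer API and DistributedSampler-style sharding): the eval DataLoader itself is still sharded per rank; it appears to be only the generated outputs that get gathered before compute_metrics runs.

# Sanity check (my addition): inspect the per-rank eval dataloader size.
eval_loader = trainer.get_eval_dataloader()
print("Rank", args.local_rank, "eval batches:", len(eval_loader))
# Expected with 64 samples, batch size 8, 4 GPUs: 2 batches (16 samples) per rank,
# even though compute_metrics later receives all 64 gathered predictions.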