# BaseTrainer

## 🚀 Trained With [EasyDeL](https://github.com/erfanzar/EasyDeL)

EasyDeL is an open-source framework for training and serving machine learning models. Built primarily on JAX, it targets Flax/JAX models running on TPUs and GPUs.

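## 📦 Installation & Usage

```python
# Install EasyDeL first, e.g. `pip install easydel`.
from easydel import AutoEasyDeLModelForCausalLM
from jax import numpy as jnp, lax

# REPO_ID is a placeholder for the Hugging Face namespace that hosts this model.
model = AutoEasyDeLModelForCausalLM.from_pretrained(
    "REPO_ID/BaseTrainer",
    dtype=jnp.bfloat16,  # computation dtype (the model was trained in bfloat16)
    param_dtype=jnp.bfloat16,  # parameter storage dtype
    precision=lax.Precision("fastest"),  # string alias for lax.Precision.DEFAULT
    auto_shard_model=True,  # shard the parameters across the available devices
)
```
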
## 🔧 Training Configuration

### Model Details
- **Architecture**: gemma3_text
- **Platform**: TPU
- **Number of Devices**: 16

### Training Parameters
- **Learning Rate**: 4e-05 → 4e-06 (start → end)
- **Optimizer**: adamw
- **Scheduler**: cosine
- **Warmup Steps**: 50
- **Weight Decay**: 0.02
- **Loss Config** (`LossConfig`):
  - `ignore_index`: -100
  - `label_smoothing`: 0.0
  - `z_loss`: 0.0
  - `loss_normalizing_factor`: NUM_REAL_TARGET_TOKENS
  - `num_labels`: None
  - `problem_type`: None
  - `divide_weight_sum`: False
  - `shift_tokens`: True
  - `break_on_nan`: True
  - `reduction`: None
  - `num_classification_labels`: None
  - `classification_problem_type`: None

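
The loss settings above suggest a standard next-token objective. With `shift_tokens` enabled, position t is scored against token t + 1. Targets equal to `ignore_index` (-100) are masked out, and `NUM_REAL_TARGET_TOKENS` averages the summed loss over the remaining tokens. The NumPy sketch below illustrates that reading. It is not EasyDeL's implementation, and the function name `shifted_token_loss` is hypothetical. Label smoothing and z-loss are omitted because both are 0.0 in this configuration.

```python
import numpy as np

IGNORE_INDEX = -100  # ignore_index from the LossConfig above


def shifted_token_loss(logits: np.ndarray, labels: np.ndarray) -> float:
    """Mean next-token cross-entropy over non-ignored target tokens.

    logits: [batch, seq_len, vocab]; labels: [batch, seq_len].
    Hypothetical helper for illustration, not an EasyDeL API.
    """
    # shift_tokens=True: the prediction at position t is scored against token t + 1.
    logits = logits[:, :-1, :]
    targets = labels[:, 1:]
    mask = targets != IGNORE_INDEX

    # Numerically stable log-softmax over the vocabulary axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

    # Look up each target's log-probability; index 0 is a placeholder at masked slots.
    safe_targets = np.where(mask, targets, 0)
    target_log_probs = np.take_along_axis(log_probs, safe_targets[..., None], axis=-1)[..., 0]
    nll = -target_log_probs * mask

    # NUM_REAL_TARGET_TOKENS: divide by the number of non-ignored targets.
    return float(nll.sum() / max(int(mask.sum()), 1))


# Uniform logits over a 5-token vocabulary give a loss of log(5) per real token.
logits = np.zeros((1, 4, 5))
labels = np.array([[1, 2, 3, IGNORE_INDEX]])
print(shifted_token_loss(logits, labels))  # ≈ 1.6094 (log 5)
```

In the example, the last label is ignored. Only two of the three shifted targets therefore count toward the average.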
### Training Setup
- **Epochs**: 3
- **Batch Size**: 8
- **Sequence Length**: 8192
- **Dtype**: bfloat16 (`jax.numpy.bfloat16`)
- **Params Dtype**: bfloat16 (`jax.numpy.bfloat16`)

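
This calculation assumes the batch size of 8 is the global batch; the card does not say whether it is per device. Using the single gradient-accumulation step listed under Advanced Configuration, each optimizer step then covers:

$$
8 \;(\text{batch size}) \times 8192 \;(\text{sequence length}) \times 1 \;(\text{gradient accumulation}) = 65{,}536 \ \text{token positions per step}
$$

If the batch size were instead per device, the 16 devices would give $16 \times 65{,}536 = 1{,}048{,}576$ positions per step. Both figures include padded positions, so the number of real target tokens may be lower.
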
### Advanced Configuration
- **Gradient Checkpointing**:
- **Gradient Accumulation Steps**: 1
- **Max Training Steps**: None
- **Max Evaluation Steps**: None
- **Training Duration**: 7H (7 hours)

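
The learning-rate settings under Training Parameters (4e-05 → 4e-06, cosine scheduler, 50 warmup steps) describe a warmup-then-cosine-decay curve. The sketch below assumes a linear warmup from 0 to the peak rate, followed by cosine decay to the end rate. This is the same shape that `optax.warmup_cosine_decay_schedule` produces with `init_value=0.0`. Whether EasyDeL starts its warmup at exactly 0 is an assumption. Because Max Training Steps is None, the total step count presumably comes from the dataset size and epoch count, and it is not recorded here. The `total_steps` input and the helper `lr_at_step` are therefore hypothetical.

```python
import math


def lr_at_step(step: int, total_steps: int, peak_lr: float = 4e-5,
               end_lr: float = 4e-6, warmup_steps: int = 50) -> float:
    """Hypothetical helper: linear warmup to peak_lr, then cosine decay to end_lr."""
    if step < warmup_steps:
        # Linear ramp from an assumed 0 at step 0 to peak_lr at step == warmup_steps.
        return peak_lr * step / warmup_steps
    # Fraction of the decay phase completed, clipped so the rate stays at end_lr afterwards.
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return end_lr + (peak_lr - end_lr) * cosine


# total_steps=1000 is a made-up value for illustration.
for step in (0, 25, 50, 525, 1000):
    print(step, f"{lr_at_step(step, total_steps=1000):.2e}")
```

With the made-up `total_steps=1000`, this prints 0.00e+00, 2.00e-05, 4.00e-05, 2.20e-05 and 4.00e-06 for steps 0, 25, 50, 525 and 1000. Step 525 is halfway through the decay phase, so it lands midway between the peak and end rates.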
### Sharding Configuration
```python
from jax.sharding import PartitionSpec

# Partition rules: (regex over the parameter path, PartitionSpec) pairs.
# ('fsdp', 'sp') shards one dimension over both the 'fsdp' and 'sp' mesh axes;
# PartitionSpec(None,) leaves the parameter replicated.
partition_rules = (
    ('model/embed_tokens/embedding', PartitionSpec(('fsdp', 'sp'), 'tp')),
    ('self_attn/q_proj/kernel', PartitionSpec('tp', ('fsdp', 'sp'))),
    ('self_attn/k_proj/kernel', PartitionSpec('tp', ('fsdp', 'sp'))),
    ('self_attn/v_proj/kernel', PartitionSpec('tp', ('fsdp', 'sp'))),
    ('self_attn/o_proj/kernel', PartitionSpec(('fsdp', 'sp'), 'tp')),
    ('mlp/gate_proj/kernel', PartitionSpec(('fsdp', 'sp'), 'tp')),
    ('mlp/up_proj/kernel', PartitionSpec(('fsdp', 'sp'), 'tp')),
    ('mlp/down_proj/kernel', PartitionSpec('tp', ('fsdp', 'sp'))),
    ('input_layernorm/kernel', PartitionSpec(None,)),
    ('post_attention_layernorm/kernel', PartitionSpec(None,)),
    ('pre_feedforward_layernorm/kernel', PartitionSpec(None,)),
    ('post_feedforward_layernorm/kernel', PartitionSpec(None,)),
    ('model/norm/kernel', PartitionSpec(None,)),
    ('lm_head/kernel', PartitionSpec(('fsdp', 'sp'), 'tp')),
    ('.*', PartitionSpec(None,)),
)
```

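The rule list reads as ordered (regex, PartitionSpec) pairs that end in a `'.*'` catch-all, which suggests first-match-wins lookup against each parameter's path. The sketch below assumes those semantics. It uses plain tuples in place of `PartitionSpec` so it runs without JAX. The parameter paths (such as `model/layers/0/self_attn/q_proj/kernel`), the tensor sizes and the 16-device mesh (`fsdp=4`, `sp=1`, `tp=4`) are all hypothetical, because the card does not record them. `match_rule` and `shard_shape` are illustrative helpers, not EasyDeL functions. Under these assumptions, a `(2048, 4096)` q_proj kernel would be stored as `(512, 1024)` blocks on each device.

```python
import re
from math import prod
from typing import Dict, Sequence, Tuple

# Plain-tuple stand-ins for a few of the PartitionSpec rules above.
RULES: Sequence[Tuple[str, tuple]] = (
    ("model/embed_tokens/embedding", (("fsdp", "sp"), "tp")),
    ("self_attn/q_proj/kernel", ("tp", ("fsdp", "sp"))),
    ("mlp/down_proj/kernel", ("tp", ("fsdp", "sp"))),
    ("input_layernorm/kernel", (None,)),
    (".*", (None,)),
)

# Hypothetical 16-device mesh; the card does not record the real mesh shape.
MESH: Dict[str, int] = {"fsdp": 4, "sp": 1, "tp": 4}


def match_rule(param_path: str, rules: Sequence[Tuple[str, tuple]] = RULES) -> tuple:
    """Return the spec of the first rule whose regex is found in param_path."""
    for pattern, spec in rules:
        if re.search(pattern, param_path):
            return spec
    raise ValueError(f"no partition rule matches {param_path!r}")


def shard_shape(shape: Tuple[int, ...], spec: tuple,
                mesh: Dict[str, int] = MESH) -> Tuple[int, ...]:
    """Per-device block shape: each dimension is divided by the product of its mesh axes."""
    out = []
    for i, dim in enumerate(shape):
        axes = spec[i] if i < len(spec) else None  # missing entries mean "not sharded"
        if axes is None:
            out.append(dim)
            continue
        names = (axes,) if isinstance(axes, str) else axes
        divisor = prod(mesh[name] for name in names)
        if dim % divisor:
            raise ValueError(f"dimension {dim} is not divisible by {divisor}")
        out.append(dim // divisor)
    return tuple(out)


spec = match_rule("model/layers/0/self_attn/q_proj/kernel")
print(spec)                             # ('tp', ('fsdp', 'sp'))
print(shard_shape((2048, 4096), spec))  # (512, 1024)
print(match_rule("model/layers/0/post_feedforward_layernorm/kernel"))  # (None,)
```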

---
*Generated with EasyDeL v0.1.3*