| """
|
| Training Data Loader for Language Model Training
|
|
|
| This module provides efficient data loading and batching for training GPT-style
|
| language models. It handles text preprocessing, tokenization, and creates
|
| batches suitable for autoregressive language modeling.
|
|
|
| FEATURES:
|
| - Memory-efficient text loading with sliding window
|
| - Automatic tokenization using trained SentencePiece model
|
| - Configurable sequence length and batch size
|
| - CPU-optimized data loading for limited hardware
|
| - Support for training data validation and statistics
|
|
|
| MEMORY OPTIMIZATION:
|
| - Streaming data loading (doesn't load entire dataset to memory)
|
| - Configurable chunk sizes for large files
|
| - Efficient tensor creation and batching
|
| - Garbage collection hints for memory management
|
|
|
| Usage:
|
| from data_loader import TextDataLoader
|
|
|
| loader = TextDataLoader(
|
| data_file="data/clean/training_data.txt",
|
| tokenizer_path="data/tokenizer/tokenizer.model",
|
| seq_len=512,
|
| batch_size=4
|
| )
|
|
|
| for batch in loader:
|
| input_ids, targets = batch
|
| # input_ids: (batch_size, seq_len)
|
| # targets: (batch_size, seq_len) - shifted by 1 for next token prediction
|
|
|
| Author: Louis Chua Bean Chong
|
| License: GPLv3
|
| """

import os
import gc
import random
import time
from typing import Iterator, List, Tuple

import torch

try:
    import sentencepiece as spm
except ImportError:
    print("ERROR: SentencePiece not installed. Run: pip install sentencepiece")
    raise SystemExit(1)


class TextDataLoader:
    """
    Efficient data loader for autoregressive language model training.

    This class handles loading text data, tokenizing it using SentencePiece,
    and creating batches suitable for next-token prediction training.
    """

    def __init__(
        self,
        data_file: str,
        tokenizer_path: str,
        seq_len: int = 512,
        batch_size: int = 4,
        chunk_size: int = 1000000,
        shuffle: bool = True,
        seed: int = 42
    ):
        """
        Initialize the data loader.

        Args:
            data_file: Path to training text file (one passage per line)
            tokenizer_path: Path to trained SentencePiece model
            seq_len: Maximum sequence length for training
            batch_size: Batch size for training
            chunk_size: Number of passages to hold in memory at once
            shuffle: Whether to shuffle training examples
            seed: Random seed for reproducibility
        """
        self.data_file = data_file
        self.tokenizer_path = tokenizer_path
        self.seq_len = seq_len
        self.batch_size = batch_size
        self.chunk_size = chunk_size
        self.shuffle = shuffle
        self.seed = seed

        self._validate_inputs()

        self.tokenizer = self._load_tokenizer()

        self.total_lines = self._count_lines()
        self.current_line = 0

        random.seed(seed)

        print("TextDataLoader initialized")
        print(f"   Data file: {data_file}")
        print(f"   Total passages: {self.total_lines:,}")
        print(f"   Sequence length: {seq_len}")
        print(f"   Batch size: {batch_size}")
        print(f"   Vocabulary size: {self.tokenizer.vocab_size():,}")

    def _validate_inputs(self) -> None:
        """Validate input parameters and file paths."""
        if not os.path.exists(self.data_file):
            raise FileNotFoundError(f"Training data file not found: {self.data_file}")

        if not os.path.exists(self.tokenizer_path):
            raise FileNotFoundError(f"Tokenizer model not found: {self.tokenizer_path}")

        if self.seq_len <= 0:
            raise ValueError(f"Sequence length must be positive, got {self.seq_len}")

        if self.batch_size <= 0:
            raise ValueError(f"Batch size must be positive, got {self.batch_size}")

        if self.chunk_size <= 0:
            raise ValueError(f"Chunk size must be positive, got {self.chunk_size}")

    def _load_tokenizer(self) -> spm.SentencePieceProcessor:
        """Load the trained SentencePiece tokenizer."""
        try:
            tokenizer = spm.SentencePieceProcessor()
            tokenizer.load(self.tokenizer_path)
            return tokenizer
        except Exception as e:
            raise RuntimeError(f"Failed to load tokenizer: {e}")

    def _count_lines(self) -> int:
        """Count the number of non-empty lines (passages) in the data file."""
        print("Counting training passages...")
        start_time = time.time()

        line_count = 0
        with open(self.data_file, 'r', encoding='utf-8') as f:
            for line in f:
                if line.strip():
                    line_count += 1

        count_time = time.time() - start_time
        print(f"Found {line_count:,} passages in {count_time:.1f}s")

        return line_count

    def _read_chunk(self, start_line: int = 0) -> List[str]:
        """
        Read a chunk of passages from the data file.

        Args:
            start_line: Passage index (non-empty line) to start reading from

        Returns:
            List of text passages
        """
        chunk = []
        passages_seen = 0

        with open(self.data_file, 'r', encoding='utf-8') as f:
            for line in f:
                text = line.strip()
                if not text:
                    # Blank lines are not counted as passages, matching _count_lines().
                    continue

                if passages_seen < start_line:
                    passages_seen += 1
                    continue

                chunk.append(text)
                passages_seen += 1

                if len(chunk) >= self.chunk_size:
                    break

        return chunk

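    # For example, with chunk_size=1000, _read_chunk(0) returns passages 0-999 and
    # _read_chunk(1000) returns the next 1000, so __iter__ can stream the file
    # chunk by chunk without holding the whole dataset in memory.
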
    def _tokenize_texts(self, texts: List[str]) -> List[List[int]]:
        """
        Tokenize a list of text passages using the SentencePiece tokenizer.

        This method converts raw text into token ID sequences suitable for
        language model training. It handles special tokens (BOS/EOS) and length
        constraints for efficient training.

        Text processing pipeline:
        1. Add BOS (Beginning of Sequence) token to mark the sequence start
        2. Tokenize text using the trained SentencePiece model (subword tokenization)
        3. Truncate sequences that exceed the maximum length
        4. Add EOS (End of Sequence) token to mark the sequence end

        Special token handling:
        - The BOS token helps the model learn to generate text from scratch
        - The EOS token signals natural sequence endings
        - These tokens are crucial for proper autoregressive generation

        Args:
            texts: List of text passages (typically Wikipedia passages from SQuAD).
                Each passage should be a complete, coherent text segment.

        Returns:
            List of token ID sequences, where each sequence is a list of integers
            representing subword tokens from the SentencePiece vocabulary
        """
        tokenized = []

        for text in texts:
            try:
                # Prepend BOS, then encode the passage into subword token ids.
                tokens = [self.tokenizer.bos_id()] + self.tokenizer.encode(text)

                # Truncate, reserving one position for the EOS token appended below.
                if len(tokens) > self.seq_len - 1:
                    tokens = tokens[:self.seq_len - 1]

                tokens.append(self.tokenizer.eos_id())

                # Skip passages that are effectively empty after tokenization.
                if len(tokens) <= 2:
                    print(f"Skipping very short text: {text[:50]}...")
                    continue

                tokenized.append(tokens)

            except Exception as e:
                print(f"Failed to tokenize passage: {text[:50]}... Error: {e}")
                continue

        if tokenized:
            avg_length = sum(len(tokens) for tokens in tokenized) / len(tokenized)
            print(f"Tokenized {len(tokenized)} passages, avg length: {avg_length:.1f} tokens")

        return tokenized

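    # Illustrative sketch (hypothetical token ids, not from a real tokenizer run):
    # with seq_len=8, a passage encoding to [15, 87, 203] becomes
    #   [bos_id, 15, 87, 203, eos_id]
    # i.e. BOS + subword ids (truncated to seq_len - 1 if needed) + EOS, so every
    # sequence handed to _create_training_examples() is at most seq_len tokens long.
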
    def _create_training_examples(self, token_sequences: List[List[int]]) -> List[Tuple[List[int], List[int]]]:
        """
        Create training examples with input and target sequences.

        For autoregressive training, targets are the inputs shifted by one position.

        Args:
            token_sequences: List of tokenized sequences

        Returns:
            List of (input_ids, target_ids) tuples
        """
        examples = []

        for tokens in token_sequences:
            if len(tokens) < 2:
                continue

            if len(tokens) > self.seq_len:
                # Split long sequences into overlapping windows.
                # (With the truncation in _tokenize_texts, sequences are at most
                # seq_len tokens, so this branch is a defensive fallback.)
                stride = self.seq_len // 2
                for i in range(0, len(tokens) - self.seq_len, stride):
                    input_ids = tokens[i:i + self.seq_len]
                    target_ids = tokens[i + 1:i + self.seq_len + 1]
                    examples.append((input_ids, target_ids))
            else:
                # Shift by one token: the model predicts tokens[t + 1] from tokens[:t + 1].
                input_ids = tokens[:-1]
                target_ids = tokens[1:]

                # Pad to the full sequence length. Inputs use the tokenizer's pad id
                # (this assumes the SentencePiece model defines a pad token); targets
                # use -1 so padded positions can be ignored by the loss.
                while len(input_ids) < self.seq_len:
                    input_ids.append(self.tokenizer.pad_id())
                    target_ids.append(-1)

                input_ids = input_ids[:self.seq_len]
                target_ids = target_ids[:self.seq_len]

                examples.append((input_ids, target_ids))

        return examples

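    # Worked example of the one-token shift (hypothetical ids, seq_len=6):
    #   tokens     = [2, 15, 87, 203, 3]          # BOS, three subwords, EOS
    #   input_ids  = [2, 15, 87, 203, pad, pad]   # tokens[:-1] padded to seq_len
    #   target_ids = [15, 87, 203, 3, -1, -1]     # tokens[1:] padded with -1
    # Each position t is trained to predict target_ids[t] given input_ids[:t + 1].
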
    def _create_batch(self, examples: List[Tuple[List[int], List[int]]]) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Create batch tensors from training examples.

        Args:
            examples: List of (input_ids, target_ids) tuples

        Returns:
            Tuple of (input_tensor, target_tensor), each of shape (batch_size, seq_len)
        """
        if not examples:
            raise ValueError("Cannot create batch from empty examples")

        batch_size = len(examples)

        # Pre-allocate tensors: inputs default to 0, targets to -1 (ignored by the loss).
        input_ids = torch.zeros((batch_size, self.seq_len), dtype=torch.long)
        target_ids = torch.full((batch_size, self.seq_len), -1, dtype=torch.long)

        for i, (inp, tgt) in enumerate(examples):
            input_ids[i, :len(inp)] = torch.tensor(inp, dtype=torch.long)
            target_ids[i, :len(tgt)] = torch.tensor(tgt, dtype=torch.long)

        return input_ids, target_ids

    def __iter__(self) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Iterate over training batches.

        Yields:
            Tuple of (input_ids, target_ids) tensors
        """
        self.current_line = 0

        while self.current_line < self.total_lines:
            # Read the next chunk of passages from disk.
            texts = self._read_chunk(self.current_line)
            if not texts:
                break

            token_sequences = self._tokenize_texts(texts)

            examples = self._create_training_examples(token_sequences)

            if self.shuffle:
                random.shuffle(examples)

            # Emit full batches only; a trailing partial batch is dropped.
            for i in range(0, len(examples), self.batch_size):
                batch_examples = examples[i:i + self.batch_size]

                if len(batch_examples) == self.batch_size:
                    try:
                        input_ids, target_ids = self._create_batch(batch_examples)
                        yield input_ids, target_ids
                    except Exception as e:
                        print(f"Failed to create batch: {e}")
                        continue

            self.current_line += len(texts)

            # Free the chunk before loading the next one.
            del texts, token_sequences, examples
            gc.collect()

    def get_data_stats(self) -> dict:
        """
        Get statistics about the training data.

        Returns:
            Dictionary with data statistics
        """
        print("Analyzing training data...")

        # Estimate statistics from a sample of up to 100 passages.
        sample_texts = self._read_chunk(0)[:100]
        token_sequences = self._tokenize_texts(sample_texts)

        if token_sequences:
            sequence_lengths = [len(seq) for seq in token_sequences]
            avg_length = sum(sequence_lengths) / len(sequence_lengths)
            max_length = max(sequence_lengths)
            min_length = min(sequence_lengths)
        else:
            avg_length = max_length = min_length = 0

        estimated_total_tokens = int(avg_length * self.total_lines)

        # Rough estimate: each passage yields at least one training example.
        examples_per_passage = max(1, avg_length // self.seq_len)
        total_examples = int(self.total_lines * examples_per_passage)
        batches_per_epoch = total_examples // self.batch_size

        stats = {
            "total_passages": self.total_lines,
            "avg_tokens_per_passage": avg_length,
            "min_tokens_per_passage": min_length,
            "max_tokens_per_passage": max_length,
            "estimated_total_tokens": estimated_total_tokens,
            "estimated_examples_per_epoch": total_examples,
            "estimated_batches_per_epoch": batches_per_epoch,
            "sequence_length": self.seq_len,
            "batch_size": self.batch_size,
            "vocabulary_size": self.tokenizer.vocab_size()
        }

        print("Data analysis complete:")
        print(f"   Total passages: {stats['total_passages']:,}")
        print(f"   Avg tokens per passage: {stats['avg_tokens_per_passage']:.1f}")
        print(f"   Estimated total tokens: {stats['estimated_total_tokens']:,}")
        print(f"   Estimated batches per epoch: {stats['estimated_batches_per_epoch']:,}")

        return stats

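
# Hypothetical helper (illustrative only, not used anywhere in this module): a minimal
# sketch of how batches from TextDataLoader are typically consumed. It assumes a `model`
# that maps (batch_size, seq_len) token ids to (batch_size, seq_len, vocab_size) logits.
# The -1 padding written into target_ids above pairs with cross_entropy's ignore_index
# so padded positions do not contribute to the loss.
def example_training_step(model: torch.nn.Module,
                          input_ids: torch.Tensor,
                          target_ids: torch.Tensor) -> torch.Tensor:
    """Compute a next-token prediction loss for one batch from TextDataLoader."""
    logits = model(input_ids)  # (batch_size, seq_len, vocab_size)
    loss = torch.nn.functional.cross_entropy(
        logits.reshape(-1, logits.size(-1)),  # (batch_size * seq_len, vocab_size)
        target_ids.reshape(-1),               # (batch_size * seq_len,)
        ignore_index=-1,                      # skip padded target positions
    )
    return loss
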

def test_data_loader():
    """Test function for the data loader."""
    print("Testing TextDataLoader...")

    try:
        loader = TextDataLoader(
            data_file="data/clean/training_data.txt",
            tokenizer_path="data/tokenizer/tokenizer.model",
            seq_len=128,
            batch_size=2,
            chunk_size=10
        )

        stats = loader.get_data_stats()

        print("\nTesting batch iteration...")
        start_time = time.time()
        batch_count = 0

        for batch_idx, (input_ids, target_ids) in enumerate(loader):
            batch_count += 1

            print(f"Batch {batch_idx + 1}:")
            print(f"   Input shape: {input_ids.shape}")
            print(f"   Target shape: {target_ids.shape}")
            print(f"   Sample input tokens: {input_ids[0][:10].tolist()}")
            print(f"   Sample target tokens: {target_ids[0][:10].tolist()}")

            if batch_idx >= 2:
                break

        test_time = time.time() - start_time
        print("\nData loader test completed successfully!")
        print(f"   Processed {batch_count} batches in {test_time:.2f}s")
        print(f"   Average time per batch: {test_time / max(1, batch_count):.2f}s")

        return True

    except Exception as e:
        print(f"Data loader test failed: {e}")
        import traceback
        traceback.print_exc()
        return False


if __name__ == "__main__":
    test_data_loader()