Instructions to use Ex0bit/jit-lora with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- MLX
How to use Ex0bit/jit-lora with MLX:
# Download the model from the Hub
pip install huggingface_hub[hf_xet]
huggingface-cli download --local-dir jit-lora Ex0bit/jit-lora
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- LM Studio
| """ | |
| neural_data.py — Training data manager for MLX LoRA fine-tuning. | |
| Manages a rolling buffer of recent conversation turns and a persistent | |
| replay buffer for anti-catastrophic-forgetting experience replay. | |
| """ | |
| import json | |
| import random | |
| import time | |
| from collections import deque | |
| from pathlib import Path | |
| from typing import Optional | |
class TrainingExample:
    """A single training example (one conversation turn).

    Attributes:
        messages: Chat messages for the turn. TrainingDataManager.add_turn
            builds them as ``{"role": ..., "content": ...}`` dicts.
        timestamp: Creation time in seconds since the epoch.
        token_count: Token count supplied by the caller (TrainingDataManager
            uses a whitespace word count as an approximation).
        session_id: Identifier of the originating session; may be empty.
    """

    __slots__ = ("messages", "timestamp", "token_count", "session_id")

    def __init__(self, messages: list[dict], timestamp: float = 0,
                 token_count: int = 0, session_id: str = ""):
        self.messages = messages
        # Any falsy timestamp (including the default 0) means "now".
        self.timestamp = timestamp or time.time()
        self.token_count = token_count
        self.session_id = session_id

    def __repr__(self) -> str:
        return (f"{type(self).__name__}(session_id={self.session_id!r}, "
                f"token_count={self.token_count}, "
                f"timestamp={self.timestamp}, "
                f"messages=<{len(self.messages)} items>)")

    def to_dict(self) -> dict:
        """Return a JSON-serializable dict; inverse of :meth:`from_dict`."""
        return {
            "messages": self.messages,
            "timestamp": self.timestamp,
            "token_count": self.token_count,
            "session_id": self.session_id,
        }

    @classmethod
    def from_dict(cls, d: dict) -> "TrainingExample":
        """Build an example from a dict such as one produced by :meth:`to_dict`.

        Only ``"messages"`` is required; missing optional keys fall back to
        the constructor defaults (so a missing timestamp becomes "now").

        Raises:
            KeyError: if ``d`` has no ``"messages"`` key.
        """
        return cls(
            messages=d["messages"],
            timestamp=d.get("timestamp", 0),
            token_count=d.get("token_count", 0),
            session_id=d.get("session_id", ""),
        )
class TrainingDataManager:
    """Manage a rolling buffer of recent turns plus a persistent replay buffer.

    * The rolling buffer is a ``deque`` bounded at ``rolling_size`` that keeps
      the most recently accepted turns.
    * The replay buffer holds at most ``replay_size`` accepted turns chosen by
      reservoir sampling. It can be saved to / loaded from ``replay_path`` as
      JSON Lines (one ``TrainingExample.to_dict()`` object per line).

    Training batches mix the two so recent turns are learned while older
    ones are rehearsed (experience replay against catastrophic forgetting).
    """

    def __init__(self, rolling_size: int = 100, replay_size: int = 500,
                 replay_path: str = "", min_response_tokens: int = 10):
        """
        Args:
            rolling_size: Capacity of the rolling (recent) buffer.
            replay_size: Capacity of the replay buffer.
            replay_path: JSONL file backing the replay buffer. If non-empty
                and the file exists, it is loaded immediately. Its directory
                is also the default location of ``buffer.jsonl``.
            min_response_tokens: Minimum whitespace-separated word count an
                assistant response needs to be accepted by :meth:`add_turn`.
        """
        self.rolling_size = rolling_size
        self.replay_size = replay_size
        self.min_response_tokens = min_response_tokens
        self.replay_path = replay_path
        self._rolling: deque[TrainingExample] = deque(maxlen=rolling_size)
        self._replay: list[TrainingExample] = []
        self._total_added = 0
        # Number of examples the replay reservoir has seen, including ones
        # loaded from disk. Kept separate from _total_added (which counts
        # only this session's accepted turns) so that, after a restart, new
        # turns do not deterministically overwrite the persisted buffer.
        self._replay_seen = 0
        if replay_path:
            self._load_replay()

    @property
    def rolling_count(self) -> int:
        """Number of examples currently in the rolling buffer."""
        return len(self._rolling)

    @property
    def replay_count(self) -> int:
        """Number of examples currently in the replay buffer."""
        return len(self._replay)

    @property
    def total_added(self) -> int:
        """Number of turns accepted by :meth:`add_turn` since the last clear."""
        return self._total_added

    def add_turn(self, user_text: str, assistant_text: str,
                 system_prompt: str = "", session_id: str = "") -> bool:
        """Add a conversation turn to the training buffers.

        Args:
            user_text: The user message.
            assistant_text: The assistant response to learn from.
            system_prompt: Optional system message, prepended when non-empty.
            session_id: Optional session identifier stored on the example.

        Returns:
            True if the example was accepted; False if it was filtered out
            (empty/whitespace-only, or fewer than ``min_response_tokens``
            whitespace-separated words).
        """
        approx_tokens = len(assistant_text.split())
        if not assistant_text.strip() or approx_tokens < self.min_response_tokens:
            return False

        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": user_text})
        messages.append({"role": "assistant", "content": assistant_text})

        example = TrainingExample(
            messages=messages,
            token_count=approx_tokens,
            session_id=session_id,
        )
        self._rolling.append(example)
        self._total_added += 1
        self._add_to_replay(example)
        return True

    def _add_to_replay(self, example) -> None:
        """Offer ``example`` to the replay buffer via reservoir sampling."""
        self._replay_seen += 1
        if len(self._replay) < self.replay_size:
            self._replay.append(example)
            return
        # Keep the new item with probability replay_size / seen, replacing
        # a uniformly random slot (Algorithm R).
        idx = random.randint(0, self._replay_seen - 1)
        if idx < self.replay_size:
            self._replay[idx] = example

    def get_training_batch(self, batch_size: int = 1,
                           replay_ratio: float = 0.3) -> "list[TrainingExample]":
        """Get a shuffled training batch mixing recent and replay examples.

        Args:
            batch_size: Total examples in batch. 0 (or negative) = all
                available data: every rolling example plus every replay
                example that is not the same object as a rolling one.
            replay_ratio: Fraction of the batch drawn from the replay buffer;
                clamped to the range 0.0-1.0.

        Returns:
            Shuffled list of TrainingExample. Empty if the rolling buffer is
            empty. May be shorter than ``batch_size`` when the buffers hold
            too few examples. Replay picks never repeat an example object
            already taken from the rolling buffer.
        """
        if not self._rolling:
            return []

        if batch_size <= 0:
            batch = list(self._rolling)
            in_batch = {id(ex) for ex in batch}
            batch.extend(ex for ex in self._replay if id(ex) not in in_batch)
            random.shuffle(batch)
            return batch

        replay_ratio = min(max(replay_ratio, 0.0), 1.0)
        n_replay = int(batch_size * replay_ratio)
        n_recent = batch_size - n_replay

        # The n_recent most recent rolling examples (all of them if fewer).
        batch = list(self._rolling)[-n_recent:] if n_recent > 0 else []

        # add_turn puts each example in both buffers, so exclude objects
        # already chosen from the rolling buffer. Otherwise replay slots
        # could duplicate recent turns instead of rehearsing older data.
        if n_replay > 0 and self._replay:
            in_batch = {id(ex) for ex in batch}
            pool = [ex for ex in self._replay if id(ex) not in in_batch]
            batch.extend(random.sample(pool, min(n_replay, len(pool))))

        random.shuffle(batch)
        return batch

    def get_recent(self, n: int = 5) -> "list[TrainingExample]":
        """Return up to ``n`` most recent examples, oldest first; [] if n <= 0."""
        if n <= 0:
            # Guard needed: lst[-0:] would return the whole list.
            return []
        return list(self._rolling)[-n:]

    def _default_buffer_path(self) -> str:
        """Default rolling-buffer file: ``buffer.jsonl`` next to replay_path."""
        return str(Path(self.replay_path).parent / "buffer.jsonl")

    @staticmethod
    def _write_jsonl(path: str, examples) -> None:
        """Write examples to ``path`` as JSON Lines, replacing it atomically.

        Data goes to a sibling ``.tmp`` file that is then renamed over the
        target, so an interrupted save leaves the previous file intact.
        """
        target = Path(path)
        target.parent.mkdir(parents=True, exist_ok=True)
        tmp = target.with_name(target.name + ".tmp")
        with open(tmp, "w", encoding="utf-8") as f:
            f.writelines(json.dumps(ex.to_dict()) + "\n" for ex in examples)
        tmp.replace(target)

    @staticmethod
    def _read_jsonl(path: str) -> "list[TrainingExample]":
        """Read all non-blank JSON Lines records in ``path`` as examples."""
        with open(path, encoding="utf-8") as f:
            return [TrainingExample.from_dict(json.loads(line))
                    for line in f if line.strip()]

    def save_rolling(self, path: str = ""):
        """Save the rolling buffer to ``path`` (default: buffer.jsonl)."""
        self._write_jsonl(path or self._default_buffer_path(), self._rolling)

    def load_rolling(self, path: str = ""):
        """Replace the rolling buffer with the contents of ``path``.

        Does nothing if the file does not exist. The file is parsed fully
        before the buffer is touched. If it holds more than ``rolling_size``
        records, only the last ones are kept (deque maxlen).
        """
        path = path or self._default_buffer_path()
        if not Path(path).exists():
            return
        examples = self._read_jsonl(path)
        self._rolling.clear()
        self._rolling.extend(examples)

    def save_replay(self):
        """Persist the replay buffer to ``replay_path``; no-op if it is unset."""
        if not self.replay_path:
            return
        self._write_jsonl(self.replay_path, self._replay)

    def _load_replay(self):
        """Load the replay buffer from ``replay_path`` if the file exists.

        If the file holds more than ``replay_size`` records, a random subset
        of ``replay_size`` is kept. The reservoir counter is seeded with the
        number of records read, so later :meth:`add_turn` calls sample
        fairly against them.
        """
        if not self.replay_path or not Path(self.replay_path).exists():
            return
        examples = self._read_jsonl(self.replay_path)
        self._replay_seen = len(examples)
        if len(examples) > self.replay_size:
            examples = random.sample(examples, self.replay_size)
        self._replay = examples

    def clear(self):
        """Clear all buffers and counters (for reset)."""
        self._rolling.clear()
        self._replay.clear()
        self._total_added = 0
        self._replay_seen = 0

    def stats(self) -> dict:
        """Return buffer statistics as plain ints."""
        return {
            "rolling_count": self.rolling_count,
            "rolling_capacity": self.rolling_size,
            "replay_count": self.replay_count,
            "replay_capacity": self.replay_size,
            "total_added": self._total_added,
        }