Spaces:
Running
Running
bp-level: rewrite snippets for fns revision (single checkpoint, batched score_sequence)
Browse files
demo.html
CHANGED
|
@@ -2046,15 +2046,13 @@ for name, ids in zip(species_prefixes, new_ids):
|
|
| 2046 |
6-mer axis. Reach for bp-level <em>scoring</em> whenever the task is about
|
| 2047 |
a specific base: variant-effect prediction, single-nucleotide mutational
|
| 2048 |
scans, comparing the likelihood of a reference and an alternate allele at
|
| 2049 |
-
one position.
|
| 2050 |
-
|
| 2051 |
-
<code>
|
| 2052 |
-
<code>
|
| 2053 |
-
|
| 2054 |
-
|
| 2055 |
-
|
| 2056 |
-
returns per-base distributions and the probability of the observed base
|
| 2057 |
-
at every position.
|
| 2058 |
</div>
|
| 2059 |
|
| 2060 |
<details class="code-snippet">
|
|
@@ -2065,66 +2063,68 @@ for name, ids in zip(species_prefixes, new_ids):
|
|
| 2065 |
<button class="code-snippet__tab" data-tab="score" type="button">score</button>
|
| 2066 |
</div>
|
| 2067 |
<button class="code-snippet__copy" type="button">Copy</button>
|
| 2068 |
-
<div class="code-snippet__panel active" data-tab="generate"><pre><code>
|
| 2069 |
import torch
|
|
|
|
| 2070 |
|
| 2071 |
-
|
| 2072 |
-
|
| 2073 |
-
|
| 2074 |
-
model = AutoModelForCausalLM.from_pretrained(
|
| 2075 |
-
"HuggingFaceBio/Carbon-3B",
|
| 2076 |
-
dtype=torch.bfloat16, device_map="auto",
|
| 2077 |
-
)
|
| 2078 |
|
| 2079 |
-
|
| 2080 |
-
|
| 2081 |
-
|
| 2082 |
-
|
| 2083 |
-
# 6-mer logits to per-base distributions and samples each of the 6
|
| 2084 |
-
# positions independently, then forces the matching 6-mer token. All
|
| 2085 |
-
# standard generation knobs (temperature, top_p, top_k, repetition_penalty)
|
| 2086 |
-
# still apply, they just act on the per-base marginals.
|
| 2087 |
-
out = model.generate(
|
| 2088 |
-
**inputs,
|
| 2089 |
-
max_new_tokens=128, # 128 6-mer tokens ~= 768 bp of continuation
|
| 2090 |
-
custom_generate="HuggingFaceBio/carbon-generate",
|
| 2091 |
trust_remote_code=True,
|
| 2092 |
-
|
| 2093 |
-
|
| 2094 |
-
)
|
| 2095 |
|
| 2096 |
-
|
| 2097 |
-
|
| 2098 |
-
print(tok.decode(new_ids, skip_special_tokens=True))</code></pre></div>
|
| 2099 |
-
<div class="code-snippet__panel" data-tab="score"><pre><code>from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 2100 |
-
import torch, math
|
| 2101 |
|
| 2102 |
-
|
| 2103 |
-
|
| 2104 |
-
|
| 2105 |
-
|
| 2106 |
-
|
| 2107 |
-
|
| 2108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2109 |
model = AutoModelForCausalLM.from_pretrained(
|
| 2110 |
-
|
|
|
|
| 2111 |
trust_remote_code=True,
|
| 2112 |
-
dtype=torch.bfloat16,
|
| 2113 |
-
)
|
|
|
|
|
|
|
|
|
|
| 2114 |
|
| 2115 |
-
|
| 2116 |
-
|
|
|
|
|
|
|
|
|
|
| 2117 |
|
| 2118 |
-
|
| 2119 |
-
# actual: [seq_len] P(observed base | context) at each position
|
| 2120 |
-
bp_probs_ref, actual_ref = model.score_sequence(ref)
|
| 2121 |
-
bp_probs_alt, actual_alt = model.score_sequence(alt)
|
| 2122 |
|
| 2123 |
-
|
| 2124 |
-
|
| 2125 |
-
|
| 2126 |
-
- math.log(actual_ref[20].item() + 1e-12)
|
| 2127 |
-
print(f"log P(alt) - log P(ref) at pos 20: {delta:+.3f}")</code></pre></div>
|
| 2128 |
</div>
|
| 2129 |
</details>
|
| 2130 |
</div>
|
|
|
|
| 2046 |
6-mer axis. Reach for bp-level <em>scoring</em> whenever the task is about
|
| 2047 |
a specific base: variant-effect prediction, single-nucleotide mutational
|
| 2048 |
scans, comparing the likelihood of a reference and an alternate allele at
|
| 2049 |
+
one position. Both paths ship together on the <code>fns</code> revision of
|
| 2050 |
+
the <code>Carbon-3B</code>/<code>8B</code>/<code>500M</code> checkpoints:
|
| 2051 |
+
plain <code>.generate()</code> already produces bp-resolution output (the
|
| 2052 |
+
tokenizer exposes the k-mer width as <code>tokenizer.k</code>), and the
|
| 2053 |
+
model gains a <code>score_sequence(seqs)</code> method that batches a list
|
| 2054 |
+
of sequences and returns per-base distributions plus the probability of
|
| 2055 |
+
the observed base at every position.
|
|
|
|
|
|
|
| 2056 |
</div>
|
| 2057 |
|
| 2058 |
<details class="code-snippet">
|
|
|
|
| 2063 |
<button class="code-snippet__tab" data-tab="score" type="button">score</button>
|
| 2064 |
</div>
|
| 2065 |
<button class="code-snippet__copy" type="button">Copy</button>
|
| 2066 |
+
<div class="code-snippet__panel active" data-tab="generate"><pre><code>import math
|
| 2067 |
import torch
|
| 2068 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 2069 |
|
| 2070 |
+
model_id = "HuggingFaceBio/Carbon-3B"
|
| 2071 |
+
revision = "fns"
|
| 2072 |
+
device = "cuda"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2073 |
|
| 2074 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision, trust_remote_code=True)
|
| 2075 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 2076 |
+
model_id,
|
| 2077 |
+
revision=revision,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2078 |
trust_remote_code=True,
|
| 2079 |
+
dtype=torch.bfloat16,
|
| 2080 |
+
).to(device).eval()
|
|
|
|
| 2081 |
|
| 2082 |
+
context = "ATGCGCTAGCTACGATCGATCGTAGCTAGCTAGCTAGCTACG"
|
| 2083 |
+
n_bp = 60
|
|
|
|
|
|
|
|
|
|
| 2084 |
|
| 2085 |
+
inputs = tokenizer(f"<dna>{context}", return_tensors="pt", add_special_tokens=False).to(device)
|
| 2086 |
+
|
| 2087 |
+
with torch.no_grad():
|
| 2088 |
+
output_ids = model.generate(
|
| 2089 |
+
**inputs,
|
| 2090 |
+
max_new_tokens=math.ceil(n_bp / tokenizer.k),
|
| 2091 |
+
do_sample=False,
|
| 2092 |
+
pad_token_id=tokenizer.eos_token_id,
|
| 2093 |
+
)
|
| 2094 |
+
|
| 2095 |
+
generated_ids = output_ids[0, inputs.input_ids.shape[1]:]
|
| 2096 |
+
generated_dna = tokenizer.decode(generated_ids, skip_special_tokens=True)[:n_bp]
|
| 2097 |
+
|
| 2098 |
+
print(generated_dna)</code></pre></div>
|
| 2099 |
+
<div class="code-snippet__panel" data-tab="score"><pre><code>import torch
|
| 2100 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 2101 |
+
|
| 2102 |
+
model_id = "HuggingFaceBio/Carbon-3B"
|
| 2103 |
+
revision = "fns"
|
| 2104 |
+
device = "cuda"
|
| 2105 |
+
|
| 2106 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision, trust_remote_code=True)
|
| 2107 |
model = AutoModelForCausalLM.from_pretrained(
|
| 2108 |
+
model_id,
|
| 2109 |
+
revision=revision,
|
| 2110 |
trust_remote_code=True,
|
| 2111 |
+
dtype=torch.bfloat16,
|
| 2112 |
+
).to(device).eval()
|
| 2113 |
+
|
| 2114 |
+
reference = "GGGCTATAAAGGCCATCGATCGATCGATCGATCGATCGATCG"
|
| 2115 |
+
perturbed = "GGGCGCGCGCGGCCATCGATCGATCGATCGATCGATCGATCG"
|
| 2116 |
|
| 2117 |
+
# score_sequence accepts a list of sequences and returns, for each one,
|
| 2118 |
+
# the [seq_len, 4] marginal P(A/T/C/G | context) and the [seq_len]
|
| 2119 |
+
# probability of the observed base.
|
| 2120 |
+
with torch.no_grad():
|
| 2121 |
+
bp_probs, actual_probs = model.score_sequence([reference, perturbed])
|
| 2122 |
|
| 2123 |
+
scores = [torch.log(p.clamp_min(1e-12)).mean().item() for p in actual_probs]
|
|
|
|
|
|
|
|
|
|
| 2124 |
|
| 2125 |
+
print(f"reference mean bp logp: {scores[0]:.4f}")
|
| 2126 |
+
print(f"perturbed mean bp logp: {scores[1]:.4f}")
|
| 2127 |
+
print(f"reference preferred: {scores[0] > scores[1]}")</code></pre></div>
|
|
|
|
|
|
|
| 2128 |
</div>
|
| 2129 |
</details>
|
| 2130 |
</div>
|