esunAI committed on
Commit
bc21134
·
verified ·
1 Parent(s): 164b12c

Add decode_and_test_sequences.py

Browse files
Files changed (1) hide show
  1. src/decode_and_test_sequences.py +202 -0
src/decode_and_test_sequences.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Decode all 80 generated sequences and test them with HMD-AMP.
4
+ """
5
+
6
+ import torch
7
+ import numpy as np
8
+ import pandas as pd
9
+ from Bio import SeqIO
10
+ from Bio.SeqRecord import SeqRecord
11
+ from Bio.Seq import Seq
12
+ import os
13
+ from datetime import datetime
14
+ from tqdm import tqdm
15
+ import sys
16
+
17
+ # Import the decoder
18
+ from final_sequence_decoder import EmbeddingToSequenceConverter
19
+
20
+ # Import HMD-AMP components
21
+ sys.path.append('/home/edwardsun/flow/HMD-AMP')
22
+ from sklearn.utils import shuffle
23
+ import esm
24
+ from deepforest import CascadeForestClassifier
25
+ from src.utils import *
26
+
27
def load_generated_embeddings(base_path='/data2/edwardsun/generated_samples',
                              date_tag='20250829'):
    """Load the generated embeddings for every CFG configuration.

    For each configuration (``no_cfg``, ``weak_cfg``, ``strong_cfg``,
    ``very_strong_cfg``) the file
    ``generated_amps_best_model_<cfg>_<date_tag>.pt`` is looked up in
    ``base_path``. Files that do not exist are reported and skipped.

    Args:
        base_path: Directory containing the ``.pt`` files.
        date_tag: Date stamp embedded in the file names.

    Returns:
        A tuple ``(all_embeddings, all_labels)``: a list with one entry per
        row of every loaded file, and a parallel list of labels of the form
        ``"<cfg>_<1-based index within its file>"``.
    """
    # The label is derived from the same cfg name that builds the file name,
    # so no substring matching on the file name is needed.
    cfg_types = ('no_cfg', 'weak_cfg', 'strong_cfg', 'very_strong_cfg')

    all_embeddings = []
    all_labels = []

    for cfg_type in cfg_types:
        file = f'generated_amps_best_model_{cfg_type}_{date_tag}.pt'
        file_path = os.path.join(base_path, file)
        if not os.path.exists(file_path):
            print(f"Warning: {file_path} not found, skipping")
            continue

        print(f"Loading {file}...")
        # NOTE(review): torch.load relies on pickle; only load files that
        # come from a trusted source.
        embeddings = torch.load(file_path, map_location='cpu')

        # assumes the loaded object is indexable along its first axis with
        # one embedding per row -- TODO confirm against the generator script
        for i in range(embeddings.shape[0]):
            all_embeddings.append(embeddings[i])
            all_labels.append(f"{cfg_type}_{i + 1}")

    print(f"✓ Loaded {len(all_embeddings)} embeddings total")
    return all_embeddings, all_labels
65
+
66
def decode_embeddings_to_sequences(embeddings, labels):
    """Decode each embedding into a sequence with the project decoder.

    Args:
        embeddings: Iterable of embeddings, one per sequence to decode.
        labels: Labels parallel to ``embeddings``; each label is embedded in
            the ID of the corresponding sequence.

    Returns:
        A tuple ``(sequences, sequence_ids)`` where the IDs have the form
        ``"generated_seq_<1-based position>_<label>"``.
    """
    print("Initializing sequence decoder...")
    # Fall back to CPU when no GPU is present, matching test_with_hmd_amp.
    # NOTE(review): assumes EmbeddingToSequenceConverter accepts 'cpu' -- confirm.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    decoder = EmbeddingToSequenceConverter(device=device)

    sequences = []
    sequence_ids = []

    print("Decoding embeddings to sequences...")
    pairs = tqdm(zip(embeddings, labels), total=len(embeddings))
    for position, (embedding, label) in enumerate(pairs, start=1):
        # Decode using diverse method for better results
        sequence = decoder.embedding_to_sequence(
            embedding,
            method='diverse',
            temperature=0.8,
        )
        sequences.append(sequence)
        sequence_ids.append(f"generated_seq_{position}_{label}")

    return sequences, sequence_ids
86
+
87
def save_sequences_as_fasta(sequences, sequence_ids, filename):
    """Write the sequences to ``filename`` in FASTA format.

    Args:
        sequences: Sequence strings to write.
        sequence_ids: Record IDs parallel to ``sequences``.
        filename: Path of the FASTA file to create.
    """
    records = [
        SeqRecord(Seq(seq), id=seq_id, description="")
        for seq_id, seq in zip(sequence_ids, sequences)
    ]

    # SeqIO.write returns the number of records written; report that rather
    # than len(sequences), since zip() truncates to the shorter input.
    written = SeqIO.write(records, filename, "fasta")
    print(f"✓ Saved {written} sequences to {filename}")
96
+
97
def _cfg_type_from_id(seq_id):
    """Return the CFG configuration name embedded in a generated sequence ID.

    IDs are expected to look like ``generated_seq_<n>_<cfg_type>_<m>`` where
    ``<cfg_type>`` may itself contain underscores (e.g. ``very_strong_cfg``).
    IDs that do not carry a cfg name in that position map to ``'unknown'``.
    """
    parts = seq_id.split('_')
    # Drop the 'generated', 'seq', '<n>' prefix and the trailing '<m>'.
    return '_'.join(parts[3:-1]) or 'unknown'


def test_with_hmd_amp(sequences, sequence_ids):
    """Classify the sequences with HMD-AMP and summarise the predictions.

    The sequences are written to a temporary FASTA file, turned into features
    by ``amp_feature_extraction``, and classified by a
    ``CascadeForestClassifier`` loaded from disk. Per-sequence results and a
    per-CFG-type summary are printed and saved as
    ``hmd_amp_detailed_results.csv`` and ``hmd_amp_cfg_analysis.csv`` in the
    current working directory.

    Args:
        sequences: Sequence strings to classify.
        sequence_ids: IDs parallel to ``sequences``, in the
            ``generated_seq_<n>_<cfg_type>_<m>`` format.

    Returns:
        A tuple ``(results_df, cfg_analysis)`` of pandas DataFrames.

    Raises:
        ValueError: If ``sequences`` is empty.
    """
    total = len(sequences)
    if total == 0:
        # Fail early and clearly instead of dividing by zero further down.
        raise ValueError("No sequences provided for HMD-AMP testing")

    print("\n🧬 Testing sequences with HMD-AMP classifier...")

    # Set device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load models
    ftmodel_save_path = '/home/edwardsun/flow/HMD-AMP/AMP/ft_parts.pth'
    clsmodel_save_path = '/home/edwardsun/flow/HMD-AMP/AMP/clsmodel'

    # Create temporary FASTA file for HMD-AMP
    temp_fasta = 'temp_sequences.fasta'
    save_sequences_as_fasta(sequences, sequence_ids, temp_fasta)

    try:
        # Generate sequence features using HMD-AMP's feature extraction.
        # NOTE(review): assumes the features come back in FASTA record order,
        # since predictions are paired with sequence_ids by position -- confirm.
        seq_embeddings, _, _ = amp_feature_extraction(ftmodel_save_path, device, temp_fasta)

        # Load classifier
        cls_model = CascadeForestClassifier()
        cls_model.load(clsmodel_save_path)

        # Make predictions
        binary_pred = cls_model.predict(seq_embeddings)
        n_amps = np.sum(binary_pred)
        n_non_amps = total - n_amps

        print("📊 HMD-AMP Results:")
        print(f"Total sequences: {total}")
        print(f"Predicted AMPs: {n_amps} ({n_amps/total*100:.1f}%)")
        print(f"Predicted non-AMPs: {n_non_amps} ({n_non_amps/total*100:.1f}%)")

        # Analyze results by CFG type
        results_df = pd.DataFrame({
            'ID': sequence_ids,
            'Sequence': sequences,
            'AMP_Prediction': binary_pred,
            'CFG_Type': [_cfg_type_from_id(seq_id) for seq_id in sequence_ids],
        })

        # Group by CFG type
        cfg_analysis = results_df.groupby('CFG_Type')['AMP_Prediction'].agg(['count', 'sum', 'mean']).round(3)
        cfg_analysis.columns = ['Total', 'Predicted_AMPs', 'AMP_Rate']

        print("\n📋 Results by CFG Configuration:")
        print(cfg_analysis)

        # Show predicted AMPs
        amp_results = results_df[results_df['AMP_Prediction'] == 1]
        if len(amp_results) > 0:
            print(f"\n🏆 Sequences predicted as AMPs ({len(amp_results)}):")
            for _, row in amp_results.iterrows():
                seq = row['Sequence']
                cationic = seq.count('K') + seq.count('R')
                # Residue-count estimate: K, R and H count +1; D and E count -1.
                net_charge = cationic + seq.count('H') - seq.count('D') - seq.count('E')
                print(f" {row['ID']}: {seq}")
                print(f" Length: {len(seq)}, Cationic (K+R): {cationic}, Net charge: {net_charge:+d}")
        else:
            print("\n❌ No sequences predicted as AMPs")

        # Save detailed results
        results_df.to_csv('hmd_amp_detailed_results.csv', index=False)
        cfg_analysis.to_csv('hmd_amp_cfg_analysis.csv')

        print("\n💾 Results saved:")
        print(" - hmd_amp_detailed_results.csv (detailed per-sequence results)")
        print(" - hmd_amp_cfg_analysis.csv (summary by CFG type)")

        return results_df, cfg_analysis

    finally:
        # Clean up temporary file
        if os.path.exists(temp_fasta):
            os.remove(temp_fasta)
170
+
171
def main():
    """Run the pipeline: load embeddings, decode, save FASTA, classify.

    Exits early with a message when no embeddings could be loaded, so the
    decoder is not initialised and no percentage is computed over zero
    sequences.
    """
    print("🚀 Starting sequence decoding and HMD-AMP testing...")

    # Load embeddings
    embeddings, labels = load_generated_embeddings()
    if not embeddings:
        print("No generated embeddings were found - nothing to decode or test.")
        return

    # Decode to sequences
    sequences, sequence_ids = decode_embeddings_to_sequences(embeddings, labels)

    # Save sequences as FASTA
    fasta_filename = f'generated_sequences_{datetime.now().strftime("%Y%m%d_%H%M%S")}.fasta'
    save_sequences_as_fasta(sequences, sequence_ids, fasta_filename)

    # Test with HMD-AMP
    results_df, cfg_analysis = test_with_hmd_amp(sequences, sequence_ids)

    print(f"\n✅ Complete! Generated and tested {len(sequences)} sequences")
    print(f"📁 Sequences saved as: {fasta_filename}")

    # Final summary
    total_amps = results_df['AMP_Prediction'].sum()
    print("\n📊 FINAL SUMMARY:")
    print(f"Generated sequences: {len(sequences)}")
    print(f"HMD-AMP predicted AMPs: {total_amps}/{len(sequences)} ({total_amps/len(sequences)*100:.1f}%)")

    if total_amps > 0:
        print(f"✨ Success! Your flow model generated {total_amps} sequences that HMD-AMP classifies as AMPs!")
    else:
        print("🔍 No sequences classified as AMPs - this may indicate the need for stronger AMP conditioning.")
201
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()