Spaces:
Sleeping
Sleeping
Àlex Solé
fixed output of wrong predictions, limiting the max volume possible for a prediction
58deabf
| import streamlit as st | |
| import torch | |
| from torch_geometric.data import Data, Batch | |
| from ase.io import write | |
| from ase import Atoms | |
| import gc | |
| from io import BytesIO, StringIO | |
| from utils import radius_graph_pbc | |
# Normalization statistics for the temperature input, computed over the
# training set; `process_ase` standardizes the requested temperature with
# these before it is fed to the model.
MEAN_TEMP = torch.tensor(192.1785) #training temp mean
STD_TEMP = torch.tensor(81.2135) #training temp std
def process_ase(atoms, temperature, model):
    """Build a periodic graph from an ASE structure and predict its ADPs.

    Args:
        atoms: ASE ``Atoms`` object describing the crystal structure.
        temperature: experiment temperature (same units as the training
            statistics ``MEAN_TEMP``/``STD_TEMP``).
        model: trained model passed through to ``process_data``.

    Returns:
        ``StringIO`` with the CIF produced by ``process_data``.
    """
    # Assemble the node-level attributes the model expects.
    sample = Data()
    sample.x = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.int32)
    sample.pos = torch.tensor(atoms.positions, dtype=torch.float32)
    sample.temperature_og = torch.tensor([temperature], dtype=torch.float32)
    # Standardize the temperature with the training-set statistics.
    sample.temperature = (sample.temperature_og - MEAN_TEMP) / STD_TEMP
    sample.cell = torch.tensor(atoms.cell.array, dtype=torch.float32).unsqueeze(0)
    sample.pbc = torch.tensor([True, True, True])
    sample.natoms = len(atoms)
    del atoms
    gc.collect()

    # radius_graph_pbc operates on a Batch; build a throwaway one just for
    # the neighbor search (cutoff 5.0 Å, at most 64 neighbors — assumed from
    # the call signature, confirm against utils.radius_graph_pbc).
    neighbor_batch = Batch.from_data_list([sample])
    edge_index, _, _, edge_vec = radius_graph_pbc(neighbor_batch, 5.0, 64)
    del neighbor_batch
    gc.collect()

    # Edge features: Euclidean length and unit direction of each edge vector.
    sample.cart_dist = torch.norm(edge_vec, dim=-1)
    sample.cart_dir = torch.nn.functional.normalize(edge_vec, dim=-1)
    sample.edge_index = edge_index
    # Hydrogens (Z == 1) are masked out of the ADP prediction.
    sample.non_H_mask = sample.x != 1
    # pbc/natoms were only needed by radius_graph_pbc; drop them before the
    # final collation.
    delattr(sample, "pbc")
    delattr(sample, "natoms")

    model_batch = Batch.from_data_list([sample])
    del sample, edge_index, edge_vec
    gc.collect()

    st.success("Graph successfully created.")
    cif_file = process_data(model_batch, model)
    st.success("ADPs successfully predicted.")
    return cif_file
def process_data(batch, model):
    """Predict ADPs for a batched structure and serialize the result as CIF.

    Args:
        batch: torch_geometric ``Batch`` built by ``process_ase`` (attributes
            ``x``, ``pos``, ``cell``, ``temperature_og``, ``non_H_mask``).
        model: callable returning one 3x3 Cartesian ADP tensor per non-H atom
            (assumed from the indexing below — confirm against the model).

    Returns:
        ``StringIO`` containing the CIF text (not rewound; callers should use
        ``getvalue()`` or ``seek(0)``).
    """
    atoms = batch.x.numpy().astype(int)  # Atomic numbers
    positions = batch.pos.numpy()  # Atomic positions
    cell = batch.cell.squeeze(0).numpy()  # Cell parameters
    temperature = batch.temperature_og.numpy()[0]
    adps = model(batch)
    # Convert Ucart to Ucif: U_cif = N^-T M^-T U_cart M^-1 N^-1, where M is the
    # cell matrix and N the diagonal of reciprocal-axis lengths.
    M = batch.cell.squeeze(0)
    N = torch.diag(torch.linalg.norm(torch.linalg.inv(M.transpose(-1, -2)).squeeze(0), dim=-1))
    M = torch.linalg.inv(M)
    N = torch.linalg.inv(N)
    adps = M.transpose(-1, -2) @ adps @ M
    adps = N.transpose(-1, -2) @ adps @ N
    del M, N
    gc.collect()
    # Map global atom index -> row in `adps` (H atoms carry no prediction).
    non_H_mask = batch.non_H_mask.numpy()
    indices = torch.arange(len(atoms))[non_H_mask].numpy()
    indices = {indices[i]: i for i in range(len(indices))}
    # Create ASE Atoms object and get fractional coordinates for the CIF.
    ase_atoms = Atoms(numbers=atoms, positions=positions, cell=cell, pbc=True)
    fractional_positions = ase_atoms.get_scaled_positions()
    # Get CIF content directly from ASE's write function (no temp file).
    cif_content = BytesIO()
    write(cif_content, ase_atoms, format='cif')
    lines = cif_content.getvalue().decode('utf-8').splitlines(True)
    cif_content.close()
    # Keep only the header: drop everything from ASE's first "loop_" onward,
    # since the atom-site loops are rewritten below with ADP columns.
    for i, line in enumerate(lines):
        if line.strip().startswith('loop_'):
            lines = lines[:i]
            break
    # Use StringIO to build the CIF content
    cif_file = StringIO()
    cif_file.writelines(lines)
    # Write temperature
    cif_file.write(f"\n_diffrn_ambient_temperature {temperature}\n")
    # Write atomic positions
    cif_file.write("\nloop_\n")
    cif_file.write("_atom_site_label\n")
    cif_file.write("_atom_site_type_symbol\n")
    cif_file.write("_atom_site_fract_x\n")
    cif_file.write("_atom_site_fract_y\n")
    cif_file.write("_atom_site_fract_z\n")
    cif_file.write("_atom_site_U_iso_or_equiv\n")
    cif_file.write("_atom_site_thermal_displace_type\n")
    element_count = {}
    labels_uiso = []  # labels whose predicted ADP was rejected (trace > 1)
    for i, (atom_number, frac_pos) in enumerate(zip(atoms, fractional_positions)):
        element = ase_atoms[i].symbol
        assert atom_number == ase_atoms[i].number
        element_count[element] = element_count.get(element, 0) + 1
        label = f"{element}{element_count[element]}"
        # NOTE(review): torch.trace(...).mean() is just the trace (sum of the
        # diagonal); the conventional U_eq is trace/3 — confirm intent.
        u_iso = torch.trace(adps[indices[i]]).mean() if element != 'H' else 0.0001
        # Reject physically implausible predictions (limits the max volume a
        # prediction can take); these atoms fall back to a nominal Uiso.
        rejected = element != 'H' and u_iso > 1
        if rejected:
            labels_uiso.append(label)
            u_iso = 0.0001
        # BUG FIX: rejected atoms are skipped in the aniso loop below, so they
        # must be declared Uiso here. Previously the check ran after u_iso was
        # clamped to 0.0001, so `u_iso > 1` was always False and rejected
        # non-H atoms were written as Uani with no matching aniso row.
        adp_type = "Uiso" if (element == 'H' or rejected) else "Uani"
        cif_file.write(f"{label} {element} {frac_pos[0]} {frac_pos[1]} {frac_pos[2]} {u_iso} {adp_type}\n")
    # Write ADPs (anisotropic loop; H atoms and rejected atoms are excluded)
    cif_file.write("\nloop_\n")
    cif_file.write("_atom_site_aniso_label\n")
    cif_file.write("_atom_site_aniso_U_11\n")
    cif_file.write("_atom_site_aniso_U_22\n")
    cif_file.write("_atom_site_aniso_U_33\n")
    cif_file.write("_atom_site_aniso_U_23\n")
    cif_file.write("_atom_site_aniso_U_13\n")
    cif_file.write("_atom_site_aniso_U_12\n")
    element_count = {}
    total_adps = 0
    for i, atom_number in enumerate(atoms):
        if atom_number == 1:  # skip hydrogen
            continue
        total_adps += 1
        element = ase_atoms[i].symbol
        # Rebuilding the per-element counters reproduces the same labels as the
        # first loop (H has its own counter, so non-H numbering is unaffected).
        element_count[element] = element_count.get(element, 0) + 1
        label = f"{element}{element_count[element]}"
        if label in labels_uiso:
            continue
        cif_file.write(f"{label} {adps[indices[i],0,0]} {adps[indices[i],1,1]} {adps[indices[i],2,2]} {adps[indices[i],1,2]} {adps[indices[i],0,2]} {adps[indices[i],0,1]}\n")
    if len(labels_uiso) > 0:
        st.warning(f"Successfully predicted {100*(total_adps-len(labels_uiso))/total_adps:.2f} % of ADPs")
        st.warning(f"CartNet produced unexpected ADPs for the following atoms (will be ignored in the output file): \n {', '.join(labels_uiso)}")
    return cif_file