Spaces:
Sleeping
Sleeping
Àlex Solé
commited on
Commit
·
fa790e2
1
Parent(s):
4e099cc
fixed bug multiple structures csd
Browse files- main.py +88 -69
- main_local.py +76 -78
- process.py +53 -54
- utils.py +1 -1
main.py
CHANGED
|
@@ -8,6 +8,7 @@ from models.master import create_model
|
|
| 8 |
from process import process_data
|
| 9 |
from utils import radius_graph_pbc
|
| 10 |
import gc
|
|
|
|
| 11 |
|
| 12 |
MEAN_TEMP = torch.tensor(192.1785) #training temp mean
|
| 13 |
STD_TEMP = torch.tensor(81.2135) #training temp std
|
|
@@ -16,7 +17,7 @@ STD_TEMP = torch.tensor(81.2135) #training temp std
|
|
| 16 |
@torch.no_grad()
|
| 17 |
def main():
|
| 18 |
model = create_model()
|
| 19 |
-
st.title("CartNet
|
| 20 |
st.image('fig/pipeline.png')
|
| 21 |
|
| 22 |
st.markdown("""
|
|
@@ -24,85 +25,101 @@ def main():
|
|
| 24 |
""")
|
| 25 |
|
| 26 |
uploaded_file = st.file_uploader("Upload a CIF file", type=["cif"], accept_multiple_files=False)
|
| 27 |
-
|
| 28 |
if uploaded_file is not None:
|
| 29 |
try:
|
| 30 |
-
with open(uploaded_file.name, "wb") as f:
|
| 31 |
-
f.write(uploaded_file.getbuffer())
|
| 32 |
filename = str(uploaded_file.name)
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
cif = ReadCif(filename)
|
| 36 |
-
cif_data = cif.first_block()
|
| 37 |
-
|
| 38 |
-
if "_diffrn_ambient_temperature" in cif_data.keys():
|
| 39 |
-
temperature = float(cif_data["_diffrn_ambient_temperature"].split("(")[0])
|
| 40 |
-
elif "_cell_measurement_temperature" in cif_data.keys():
|
| 41 |
-
temperature = float(cif_data["_cell_measurement_temperature"].split("(")[0])
|
| 42 |
-
else:
|
| 43 |
-
raise ValueError("Temperature not found in the CIF file. \
|
| 44 |
-
Please provide a temperature in the field _diffrn_ambient_temperature o in the field _cell_measurement_temperature from the CIF file.")
|
| 45 |
-
st.success("CIF file successfully read.")
|
| 46 |
-
|
| 47 |
-
data = Data()
|
| 48 |
-
data.x = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.int32)
|
| 49 |
-
|
| 50 |
-
if len(atoms.positions) > 1000:
|
| 51 |
-
st.markdown("""
|
| 52 |
-
⚠️ **Warning**: The structure is too large. Please upload a smaller one or use the [local implementation of CartNet Web App](https://github.com/alexsoleg/cartnet-streamlit/).
|
| 53 |
-
""")
|
| 54 |
-
raise ValueError("Please provide a structure with less than 1000 atoms in the unit cell.")
|
| 55 |
-
|
| 56 |
-
data.pos = torch.tensor(atoms.positions, dtype=torch.float32)
|
| 57 |
-
data.temperature_og = torch.tensor([temperature], dtype=torch.float32)
|
| 58 |
-
data.temperature = (data.temperature_og - MEAN_TEMP) / STD_TEMP
|
| 59 |
-
data.cell = torch.tensor(atoms.cell.array, dtype=torch.float32).unsqueeze(0)
|
| 60 |
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
del atoms
|
| 65 |
-
gc.collect()
|
| 66 |
-
batch = Batch.from_data_list([data])
|
| 67 |
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
-
st.download_button(
|
| 92 |
-
label="Download processed CIF file",
|
| 93 |
-
data=cif_contents,
|
| 94 |
-
file_name="output.cif",
|
| 95 |
-
mime="text/plain"
|
| 96 |
-
)
|
| 97 |
-
|
| 98 |
-
os.remove("output.cif")
|
| 99 |
-
os.remove(filename)
|
| 100 |
gc.collect()
|
| 101 |
except Exception as e:
|
| 102 |
st.error(f"An error occurred while reading the CIF file: {e}")
|
| 103 |
-
|
|
|
|
| 104 |
⚠️ **Warning**: This online web application is designed for structures with up to 1000 atoms in the unit cell. For larger structures, please use the [local implementation of CartNet Web App](https://github.com/alexsoleg/cartnet-streamlit/).
|
| 105 |
""")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
|
| 107 |
st.markdown("""
|
| 108 |
📌 The official implementation of the paper with all experiments can be found at [CartNet GitHub Repository](https://huggingface.co/spaces/alexsoleg/cartnet-demo/tree/main).
|
|
@@ -128,3 +145,5 @@ def main():
|
|
| 128 |
|
| 129 |
if __name__ == "__main__":
|
| 130 |
main()
|
|
|
|
|
|
|
|
|
| 8 |
from process import process_data
|
| 9 |
from utils import radius_graph_pbc
|
| 10 |
import gc
|
| 11 |
+
from io import BytesIO, StringIO
|
| 12 |
|
| 13 |
MEAN_TEMP = torch.tensor(192.1785) #training temp mean
|
| 14 |
STD_TEMP = torch.tensor(81.2135) #training temp std
|
|
|
|
| 17 |
@torch.no_grad()
|
| 18 |
def main():
|
| 19 |
model = create_model()
|
| 20 |
+
st.title("CartNet Thermal Ellipsoid Prediction")
|
| 21 |
st.image('fig/pipeline.png')
|
| 22 |
|
| 23 |
st.markdown("""
|
|
|
|
| 25 |
""")
|
| 26 |
|
| 27 |
uploaded_file = st.file_uploader("Upload a CIF file", type=["cif"], accept_multiple_files=False)
|
| 28 |
+
|
| 29 |
if uploaded_file is not None:
|
| 30 |
try:
|
|
|
|
|
|
|
| 31 |
filename = str(uploaded_file.name)
|
| 32 |
+
file = BytesIO(uploaded_file.getbuffer())
|
| 33 |
+
cif = ReadCif(file)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
+
if len(cif.keys())>1:
|
| 36 |
+
st.warning("⚠️ **Warning**: Found " + str(len(cif.keys())) + " blocks in the CIF file. We will process all of them and export as separate CIF files.")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
+
st.markdown(f"### CIF file: {filename}")
|
| 39 |
+
for key in cif.keys():
|
| 40 |
+
st.markdown(f"### Block: {key}")
|
| 41 |
+
try:
|
| 42 |
+
block = "data_"+str(key)+"\n"+ cif[key].printsection()
|
| 43 |
+
atoms = read(StringIO(block), format="cif")
|
| 44 |
+
|
| 45 |
+
if len(atoms.positions) > 1000:
|
| 46 |
+
st.error("""
|
| 47 |
+
⚠️ **Warning**: The structure is too large. Please upload a smaller one or use the [local implementation of CartNet Web App](https://github.com/alexsoleg/cartnet-streamlit/).
|
| 48 |
+
""")
|
| 49 |
+
continue
|
| 50 |
+
|
| 51 |
+
cif_data = cif[key]
|
| 52 |
+
if "_diffrn_ambient_temperature" in cif_data.keys():
|
| 53 |
+
temperature = float(cif_data["_diffrn_ambient_temperature"].split("(")[0])
|
| 54 |
+
elif "_cell_measurement_temperature" in cif_data.keys():
|
| 55 |
+
temperature = float(cif_data["_cell_measurement_temperature"].split("(")[0])
|
| 56 |
+
else:
|
| 57 |
+
st.error("Temperature not found in the CIF file. \
|
| 58 |
+
Please provide a temperature in the field _diffrn_ambient_temperature o in the field _cell_measurement_temperature from the CIF file.")
|
| 59 |
+
continue
|
| 60 |
+
st.success("CIF file successfully read.")
|
| 61 |
+
except Exception as e:
|
| 62 |
+
st.error(f"Error: {e}")
|
| 63 |
+
st.error(f"We couldn't find any structure for the block {key}. Please make sure the cif is compatible with ASE. If the error message is a blank line, it means ASE didn't found any coordinates.")
|
| 64 |
+
|
| 65 |
+
continue
|
| 66 |
+
|
| 67 |
+
data = Data()
|
| 68 |
+
data.x = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.int32)
|
| 69 |
+
|
| 70 |
+
data.pos = torch.tensor(atoms.positions, dtype=torch.float32)
|
| 71 |
+
data.temperature_og = torch.tensor([temperature], dtype=torch.float32)
|
| 72 |
+
data.temperature = (data.temperature_og - MEAN_TEMP) / STD_TEMP
|
| 73 |
+
data.cell = torch.tensor(atoms.cell.array, dtype=torch.float32).unsqueeze(0)
|
| 74 |
+
|
| 75 |
+
data.pbc = torch.tensor([True, True, True])
|
| 76 |
+
data.natoms = len(atoms)
|
| 77 |
+
|
| 78 |
+
del atoms
|
| 79 |
+
gc.collect()
|
| 80 |
+
batch = Batch.from_data_list([data])
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
edge_index, _, _, edge_attr = radius_graph_pbc(batch, 5.0, 64)
|
| 84 |
+
del batch
|
| 85 |
+
gc.collect()
|
| 86 |
+
data.cart_dist = torch.norm(edge_attr, dim=-1)
|
| 87 |
+
data.cart_dir = torch.nn.functional.normalize(edge_attr, dim=-1)
|
| 88 |
+
data.edge_index = edge_index
|
| 89 |
+
data.non_H_mask = data.x != 1
|
| 90 |
+
delattr(data, "pbc")
|
| 91 |
+
delattr(data, "natoms")
|
| 92 |
+
batch = Batch.from_data_list([data])
|
| 93 |
+
del data, edge_index, edge_attr
|
| 94 |
+
gc.collect()
|
| 95 |
+
|
| 96 |
+
st.success("Graph successfully created.")
|
| 97 |
+
|
| 98 |
+
cif_file = process_data(batch, model)
|
| 99 |
+
st.success("ADPs successfully predicted.")
|
| 100 |
+
|
| 101 |
+
cif_file = BytesIO(cif_file.getvalue().encode())
|
| 102 |
+
st.download_button(
|
| 103 |
+
label="Download processed CIF file",
|
| 104 |
+
data=cif_file,
|
| 105 |
+
file_name=f"output_{key}.cif",
|
| 106 |
+
mime="text/plain",
|
| 107 |
+
key=f"download_button_{key}"
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
gc.collect()
|
| 111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
gc.collect()
|
| 113 |
except Exception as e:
|
| 114 |
st.error(f"An error occurred while reading the CIF file: {e}")
|
| 115 |
+
|
| 116 |
+
st.warning("""
|
| 117 |
⚠️ **Warning**: This online web application is designed for structures with up to 1000 atoms in the unit cell. For larger structures, please use the [local implementation of CartNet Web App](https://github.com/alexsoleg/cartnet-streamlit/).
|
| 118 |
""")
|
| 119 |
+
|
| 120 |
+
st.warning("""
|
| 121 |
+
⚠️ **Warning**: We use [ASE library](https://wiki.fysik.dtu.dk/ase/) for reading the cif files, please make sure it is compatible.
|
| 122 |
+
""")
|
| 123 |
|
| 124 |
st.markdown("""
|
| 125 |
📌 The official implementation of the paper with all experiments can be found at [CartNet GitHub Repository](https://huggingface.co/spaces/alexsoleg/cartnet-demo/tree/main).
|
|
|
|
| 145 |
|
| 146 |
if __name__ == "__main__":
|
| 147 |
main()
|
| 148 |
+
|
| 149 |
+
|
main_local.py
CHANGED
|
@@ -17,7 +17,7 @@ STD_TEMP = torch.tensor(81.2135) #training temp std
|
|
| 17 |
@torch.no_grad()
|
| 18 |
def main():
|
| 19 |
model = create_model()
|
| 20 |
-
st.title("CartNet
|
| 21 |
st.image('fig/pipeline.png')
|
| 22 |
|
| 23 |
st.markdown("""
|
|
@@ -25,92 +25,86 @@ def main():
|
|
| 25 |
""")
|
| 26 |
|
| 27 |
uploaded_file = st.file_uploader("Upload a CIF file", type=["cif"], accept_multiple_files=False)
|
|
|
|
| 28 |
|
| 29 |
if uploaded_file is not None:
|
| 30 |
try:
|
| 31 |
filename = str(uploaded_file.name)
|
| 32 |
file = BytesIO(uploaded_file.getbuffer())
|
| 33 |
cif = ReadCif(file)
|
| 34 |
-
print(cif.keys())
|
| 35 |
-
if len(cif.keys())>1:
|
| 36 |
-
st.markdown("Found " + str(len(cif.keys())) + " blocks in the CIF file. We will process all of them and export as separate CIF files.")
|
| 37 |
-
for key in cif.keys():
|
| 38 |
-
print(key)
|
| 39 |
-
# print(cif[key])
|
| 40 |
-
block = "data_"+str(key)+"\n"+ cif[key].printsection()
|
| 41 |
-
atoms = read(StringIO(block), format="cif")
|
| 42 |
-
print("atoms")
|
| 43 |
-
print(atoms)
|
| 44 |
-
# atoms = read(atoms_2, format="cif")
|
| 45 |
-
# with open(uploaded_file.name, "wb") as f:
|
| 46 |
-
# f.write(uploaded_file.getbuffer())
|
| 47 |
-
# filename = str(uploaded_file.name)
|
| 48 |
-
# # Read the CIF file using ASE
|
| 49 |
-
# atoms = read(filename, format="cif")
|
| 50 |
-
# cif = ReadCif(filename)
|
| 51 |
-
# print(cif.keys())
|
| 52 |
-
# print(len(atoms))
|
| 53 |
-
# # st.markdown(cif)
|
| 54 |
-
# cif_data = cif
|
| 55 |
-
# st.markdown(f"### CIF file: {filename}")
|
| 56 |
-
# temperature = 100
|
| 57 |
-
# if "_diffrn_ambient_temperature" in cif_data.keys():
|
| 58 |
-
# temperature = float(cif_data["_diffrn_ambient_temperature"].split("(")[0])
|
| 59 |
-
# elif "_cell_measurement_temperature" in cif_data.keys():
|
| 60 |
-
# temperature = float(cif_data["_cell_measurement_temperature"].split("(")[0])
|
| 61 |
-
# else:
|
| 62 |
-
# raise ValueError("Temperature not found in the CIF file. \
|
| 63 |
-
# Please provide a temperature in the field _diffrn_ambient_temperature o in the field _cell_measurement_temperature from the CIF file.")
|
| 64 |
-
# st.success("CIF file successfully read.")
|
| 65 |
-
|
| 66 |
-
# data = Data()
|
| 67 |
-
# data.x = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.int32)
|
| 68 |
-
|
| 69 |
-
# data.pos = torch.tensor(atoms.positions, dtype=torch.float32)
|
| 70 |
-
# data.temperature_og = torch.tensor([temperature], dtype=torch.float32)
|
| 71 |
-
# data.temperature = (data.temperature_og - MEAN_TEMP) / STD_TEMP
|
| 72 |
-
# data.cell = torch.tensor(atoms.cell.array, dtype=torch.float32).unsqueeze(0)
|
| 73 |
-
|
| 74 |
-
# data.pbc = torch.tensor([True, True, True])
|
| 75 |
-
# data.natoms = len(atoms)
|
| 76 |
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
# batch = Batch.from_data_list([data])
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
# edge_index, _, _, edge_attr = radius_graph_pbc(batch, 5.0, 64)
|
| 83 |
-
# del batch
|
| 84 |
-
# gc.collect()
|
| 85 |
-
# data.cart_dist = torch.norm(edge_attr, dim=-1)
|
| 86 |
-
# data.cart_dir = torch.nn.functional.normalize(edge_attr, dim=-1)
|
| 87 |
-
# data.edge_index = edge_index
|
| 88 |
-
# data.non_H_mask = data.x != 1
|
| 89 |
-
# delattr(data, "pbc")
|
| 90 |
-
# delattr(data, "natoms")
|
| 91 |
-
# batch = Batch.from_data_list([data])
|
| 92 |
-
# del data, edge_index, edge_attr
|
| 93 |
-
# gc.collect()
|
| 94 |
-
|
| 95 |
-
# st.success("Graph successfully created.")
|
| 96 |
-
|
| 97 |
-
# process_data(batch, model)
|
| 98 |
-
# st.success("ADPs successfully predicted.")
|
| 99 |
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
|
| 104 |
-
|
| 105 |
-
# label="Download processed CIF file",
|
| 106 |
-
# data=cif_contents,
|
| 107 |
-
# file_name="output.cif",
|
| 108 |
-
# mime="text/plain"
|
| 109 |
-
# )
|
| 110 |
-
|
| 111 |
-
# os.remove("output.cif")
|
| 112 |
-
# os.remove(filename)
|
| 113 |
-
# gc.collect()
|
| 114 |
except Exception as e:
|
| 115 |
st.error(f"An error occurred while reading the CIF file: {e}")
|
| 116 |
|
|
@@ -119,6 +113,10 @@ def main():
|
|
| 119 |
📌 The official implementation of the paper with all experiments can be found at [CartNet GitHub Repository](https://github.com/imatge-upc/CartNet).
|
| 120 |
""")
|
| 121 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
st.markdown("""
|
| 123 |
### How to cite
|
| 124 |
|
|
|
|
| 17 |
@torch.no_grad()
|
| 18 |
def main():
|
| 19 |
model = create_model()
|
| 20 |
+
st.title("CartNet Thermal Ellipsoid Prediction")
|
| 21 |
st.image('fig/pipeline.png')
|
| 22 |
|
| 23 |
st.markdown("""
|
|
|
|
| 25 |
""")
|
| 26 |
|
| 27 |
uploaded_file = st.file_uploader("Upload a CIF file", type=["cif"], accept_multiple_files=False)
|
| 28 |
+
|
| 29 |
|
| 30 |
if uploaded_file is not None:
|
| 31 |
try:
|
| 32 |
filename = str(uploaded_file.name)
|
| 33 |
file = BytesIO(uploaded_file.getbuffer())
|
| 34 |
cif = ReadCif(file)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
+
if len(cif.keys())>1:
|
| 37 |
+
st.warning("⚠️ **Warning**: Found " + str(len(cif.keys())) + " blocks in the CIF file. We will process all of them and export as separate CIF files.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
+
st.markdown(f"### CIF file: {filename}")
|
| 40 |
+
for key in cif.keys():
|
| 41 |
+
st.markdown(f"### Block: {key}")
|
| 42 |
+
try:
|
| 43 |
+
block = "data_"+str(key)+"\n"+ cif[key].printsection()
|
| 44 |
+
atoms = read(StringIO(block), format="cif")
|
| 45 |
+
|
| 46 |
+
cif_data = cif[key]
|
| 47 |
+
if "_diffrn_ambient_temperature" in cif_data.keys():
|
| 48 |
+
temperature = float(cif_data["_diffrn_ambient_temperature"].split("(")[0])
|
| 49 |
+
elif "_cell_measurement_temperature" in cif_data.keys():
|
| 50 |
+
temperature = float(cif_data["_cell_measurement_temperature"].split("(")[0])
|
| 51 |
+
else:
|
| 52 |
+
st.error("Temperature not found in the CIF file. \
|
| 53 |
+
Please provide a temperature in the field _diffrn_ambient_temperature o in the field _cell_measurement_temperature from the CIF file.")
|
| 54 |
+
continue
|
| 55 |
+
st.success("CIF file successfully read.")
|
| 56 |
+
except Exception as e:
|
| 57 |
+
st.error(f"Error: {e}")
|
| 58 |
+
st.error(f"We couldn't find any structure for the block {key}. Please make sure the cif is compatible with ASE. If the error message is a blank line, it means ASE didn't found any coordinates.")
|
| 59 |
+
|
| 60 |
+
continue
|
| 61 |
+
|
| 62 |
+
data = Data()
|
| 63 |
+
data.x = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.int32)
|
| 64 |
+
|
| 65 |
+
data.pos = torch.tensor(atoms.positions, dtype=torch.float32)
|
| 66 |
+
data.temperature_og = torch.tensor([temperature], dtype=torch.float32)
|
| 67 |
+
data.temperature = (data.temperature_og - MEAN_TEMP) / STD_TEMP
|
| 68 |
+
data.cell = torch.tensor(atoms.cell.array, dtype=torch.float32).unsqueeze(0)
|
| 69 |
+
|
| 70 |
+
data.pbc = torch.tensor([True, True, True])
|
| 71 |
+
data.natoms = len(atoms)
|
| 72 |
+
|
| 73 |
+
del atoms
|
| 74 |
+
gc.collect()
|
| 75 |
+
batch = Batch.from_data_list([data])
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
edge_index, _, _, edge_attr = radius_graph_pbc(batch, 5.0, 64)
|
| 79 |
+
del batch
|
| 80 |
+
gc.collect()
|
| 81 |
+
data.cart_dist = torch.norm(edge_attr, dim=-1)
|
| 82 |
+
data.cart_dir = torch.nn.functional.normalize(edge_attr, dim=-1)
|
| 83 |
+
data.edge_index = edge_index
|
| 84 |
+
data.non_H_mask = data.x != 1
|
| 85 |
+
delattr(data, "pbc")
|
| 86 |
+
delattr(data, "natoms")
|
| 87 |
+
batch = Batch.from_data_list([data])
|
| 88 |
+
del data, edge_index, edge_attr
|
| 89 |
+
gc.collect()
|
| 90 |
+
|
| 91 |
+
st.success("Graph successfully created.")
|
| 92 |
+
|
| 93 |
+
cif_file = process_data(batch, model)
|
| 94 |
+
st.success("ADPs successfully predicted.")
|
| 95 |
+
|
| 96 |
+
cif_file = BytesIO(cif_file.getvalue().encode())
|
| 97 |
+
st.download_button(
|
| 98 |
+
label="Download processed CIF file",
|
| 99 |
+
data=cif_file,
|
| 100 |
+
file_name=f"output_{key}.cif",
|
| 101 |
+
mime="text/plain",
|
| 102 |
+
key=f"download_button_{key}"
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
gc.collect()
|
| 106 |
|
| 107 |
+
gc.collect()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
except Exception as e:
|
| 109 |
st.error(f"An error occurred while reading the CIF file: {e}")
|
| 110 |
|
|
|
|
| 113 |
📌 The official implementation of the paper with all experiments can be found at [CartNet GitHub Repository](https://github.com/imatge-upc/CartNet).
|
| 114 |
""")
|
| 115 |
|
| 116 |
+
st.warning("""
|
| 117 |
+
⚠️ **Warning**: We use [ASE library](https://wiki.fysik.dtu.dk/ase/) for reading the cif files, please make sure it is compatible.
|
| 118 |
+
""")
|
| 119 |
+
|
| 120 |
st.markdown("""
|
| 121 |
### How to cite
|
| 122 |
|
process.py
CHANGED
|
@@ -2,6 +2,7 @@ import torch
|
|
| 2 |
from ase.io import write
|
| 3 |
from ase import Atoms
|
| 4 |
import gc
|
|
|
|
| 5 |
|
| 6 |
@torch.no_grad()
|
| 7 |
def process_data(batch, model, output_file="output.cif"):
|
|
@@ -35,11 +36,12 @@ def process_data(batch, model, output_file="output.cif"):
|
|
| 35 |
# Convert positions to fractional coordinates
|
| 36 |
fractional_positions = ase_atoms.get_scaled_positions()
|
| 37 |
|
| 38 |
-
# Write to CIF file
|
| 39 |
-
write(output_file, ase_atoms)
|
| 40 |
|
| 41 |
-
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
# Find the line where "loop_" appears and remove lines from there to the end
|
| 45 |
for i, line in enumerate(lines):
|
|
@@ -47,54 +49,51 @@ def process_data(batch, model, output_file="output.cif"):
|
|
| 47 |
lines = lines[:i]
|
| 48 |
break
|
| 49 |
|
| 50 |
-
#
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
label = f"{element}{element_count[element]}"
|
| 77 |
-
u_iso = torch.trace(adps[indices[i]]).mean() if element != 'H' else 0.01
|
| 78 |
-
type = "Uani" if element != 'H' else "Uiso"
|
| 79 |
-
cif_file.write(f"{label} {element} {frac_pos[0]} {frac_pos[1]} {frac_pos[2]} {u_iso} {type}\n")
|
| 80 |
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
|
|
|
|
|
| 2 |
from ase.io import write
|
| 3 |
from ase import Atoms
|
| 4 |
import gc
|
| 5 |
+
from io import BytesIO, StringIO
|
| 6 |
|
| 7 |
@torch.no_grad()
|
| 8 |
def process_data(batch, model, output_file="output.cif"):
|
|
|
|
| 36 |
# Convert positions to fractional coordinates
|
| 37 |
fractional_positions = ase_atoms.get_scaled_positions()
|
| 38 |
|
|
|
|
|
|
|
| 39 |
|
| 40 |
+
# Instead of reading from file, get CIF content directly from ASE's write function
|
| 41 |
+
cif_content = BytesIO()
|
| 42 |
+
write(cif_content, ase_atoms, format='cif')
|
| 43 |
+
lines = cif_content.getvalue().decode('utf-8').splitlines(True)
|
| 44 |
+
cif_content.close()
|
| 45 |
|
| 46 |
# Find the line where "loop_" appears and remove lines from there to the end
|
| 47 |
for i, line in enumerate(lines):
|
|
|
|
| 49 |
lines = lines[:i]
|
| 50 |
break
|
| 51 |
|
| 52 |
+
# Use StringIO to build the CIF content
|
| 53 |
+
cif_file = StringIO()
|
| 54 |
+
cif_file.writelines(lines)
|
| 55 |
+
# Write temperature
|
| 56 |
+
cif_file.write(f"\n_diffrn_ambient_temperature {temperature}\n")
|
| 57 |
+
# Write atomic positions
|
| 58 |
+
cif_file.write("\nloop_\n")
|
| 59 |
+
cif_file.write("_atom_site_label\n")
|
| 60 |
+
cif_file.write("_atom_site_type_symbol\n")
|
| 61 |
+
cif_file.write("_atom_site_fract_x\n")
|
| 62 |
+
cif_file.write("_atom_site_fract_y\n")
|
| 63 |
+
cif_file.write("_atom_site_fract_z\n")
|
| 64 |
+
cif_file.write("_atom_site_U_iso_or_equiv\n")
|
| 65 |
+
cif_file.write("_atom_site_thermal_displace_type\n")
|
| 66 |
+
|
| 67 |
+
element_count = {}
|
| 68 |
+
for i, (atom_number, frac_pos) in enumerate(zip(atoms, fractional_positions)):
|
| 69 |
+
element = ase_atoms[i].symbol
|
| 70 |
+
assert atom_number == ase_atoms[i].number
|
| 71 |
+
if element not in element_count:
|
| 72 |
+
element_count[element] = 0
|
| 73 |
+
element_count[element] += 1
|
| 74 |
+
label = f"{element}{element_count[element]}"
|
| 75 |
+
u_iso = torch.trace(adps[indices[i]]).mean() if element != 'H' else 0.01
|
| 76 |
+
type = "Uani" if element != 'H' else "Uiso"
|
| 77 |
+
cif_file.write(f"{label} {element} {frac_pos[0]} {frac_pos[1]} {frac_pos[2]} {u_iso} {type}\n")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
+
# Write ADPs
|
| 80 |
+
cif_file.write("\nloop_\n")
|
| 81 |
+
cif_file.write("_atom_site_aniso_label\n")
|
| 82 |
+
cif_file.write("_atom_site_aniso_U_11\n")
|
| 83 |
+
cif_file.write("_atom_site_aniso_U_22\n")
|
| 84 |
+
cif_file.write("_atom_site_aniso_U_33\n")
|
| 85 |
+
cif_file.write("_atom_site_aniso_U_23\n")
|
| 86 |
+
cif_file.write("_atom_site_aniso_U_13\n")
|
| 87 |
+
cif_file.write("_atom_site_aniso_U_12\n")
|
| 88 |
+
|
| 89 |
+
element_count = {}
|
| 90 |
+
for i, atom_number in enumerate(atoms):
|
| 91 |
+
if atom_number == 1:
|
| 92 |
+
continue
|
| 93 |
+
element = ase_atoms[i].symbol
|
| 94 |
+
if element not in element_count:
|
| 95 |
+
element_count[element] = 0
|
| 96 |
+
element_count[element] += 1
|
| 97 |
+
label = f"{element}{element_count[element]}"
|
| 98 |
+
cif_file.write(f"{label} {adps[indices[i],0,0]} {adps[indices[i],1,1]} {adps[indices[i],2,2]} {adps[indices[i],1,2]} {adps[indices[i],0,2]} {adps[indices[i],0,1]}\n")
|
| 99 |
+
return cif_file
|
utils.py
CHANGED
|
@@ -264,7 +264,7 @@ def get_max_neighbors_mask(
|
|
| 264 |
+ torch.arange(len(index), device=device)
|
| 265 |
- index_neighbor_offset_expand
|
| 266 |
)
|
| 267 |
-
|
| 268 |
distance_sort.index_copy_(0, index_sort_map, atom_distance)
|
| 269 |
distance_sort = distance_sort.view(num_atoms, max_num_neighbors)
|
| 270 |
|
|
|
|
| 264 |
+ torch.arange(len(index), device=device)
|
| 265 |
- index_neighbor_offset_expand
|
| 266 |
)
|
| 267 |
+
|
| 268 |
distance_sort.index_copy_(0, index_sort_map, atom_distance)
|
| 269 |
distance_sort = distance_sort.view(num_atoms, max_num_neighbors)
|
| 270 |
|