Spaces:
Sleeping
Sleeping
| import argparse | |
| import os | |
| from ase.io import read | |
| from CifFile import ReadCif | |
| from torch_geometric.data import Data, Batch | |
| import torch | |
| from models.master import create_model | |
| from process import process_data | |
| from utils import radius_graph_pbc | |
| import gc | |
# Training-set statistics used to standardise the temperature feature:
# normalised = (temperature - MEAN_TEMP) / STD_TEMP.
MEAN_TEMP, STD_TEMP = torch.tensor(192.1785), torch.tensor(81.2135)
def _parse_cif_temperature(raw_value):
    """Convert a raw ``_diffrn_ambient_temperature`` CIF value to a float.

    CIF numeric values may carry a standard uncertainty in parentheses
    (e.g. ``"100(2)"``), which ``float`` cannot parse directly. They may also
    hold the placeholders ``?`` (unknown) or ``.`` (inapplicable).

    Args:
        raw_value: The value read from the CIF block (normally a string).

    Returns:
        The temperature as a float, or ``None`` if the value is empty or a
        placeholder.

    Raises:
        ValueError: If the value is not numeric once the uncertainty is removed.
    """
    text = str(raw_value).strip()
    # Keep only the part before a standard-uncertainty suffix such as "(2)".
    text = text.partition("(")[0].strip()
    if text in ("", "?", "."):
        return None
    return float(text)


def process_cif(input_file, output_file, *, max_atoms=300):
    """Build a periodic graph from a CIF structure and run the model on it.

    Steps:

    1. Read the first structure and the first data block of ``input_file``.
    2. Standardise the ambient temperature with the training-set statistics
       ``MEAN_TEMP`` / ``STD_TEMP``.
    3. Build edges with ``radius_graph_pbc(batch, 5.0, 64)``. The arguments
       are presumably the cutoff radius and the neighbour cap; see
       ``utils.radius_graph_pbc``.
    4. Pass the batched graph and a freshly created model to ``process_data``
       together with ``output_file``.

    Args:
        input_file: Path to the input CIF file.
        output_file: Output path, forwarded unchanged to ``process_data``.
        max_atoms: Largest number of atoms accepted. ``None`` disables the
            limit. Defaults to 300, the limit of this implementation.

    Returns:
        ``True`` on success, ``False`` if any error occurred. Errors are
        printed rather than raised.
    """
    try:
        # ase.io.read returns the *last* structure by default. Take the first
        # one so it matches the data block the temperature is read from
        # (first_block). NOTE(review): ASE skips blocks without structure
        # data, so the two may still differ for unusual multi-block files.
        atoms = read(input_file, format="cif", index=0)
        cif_data = ReadCif(input_file).first_block()

        temperature = None
        if "_diffrn_ambient_temperature" in cif_data.keys():
            temperature = _parse_cif_temperature(
                cif_data["_diffrn_ambient_temperature"]
            )
        if temperature is None:
            raise ValueError(
                "Temperature not found in the CIF file. "
                "Please provide a temperature in the field "
                "_diffrn_ambient_temperature from the CIF file."
            )

        # Reject oversized inputs before allocating any tensors.
        if max_atoms is not None and len(atoms) > max_atoms:
            raise ValueError(
                "This implementation is not optimized for large systems. "
                "For large systems, please use the local version."
            )

        data = Data()
        data.x = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.int32)
        data.pos = torch.tensor(atoms.positions, dtype=torch.float32)
        data.temperature_og = torch.tensor([temperature], dtype=torch.float32)
        data.temperature = (data.temperature_og - MEAN_TEMP) / STD_TEMP
        data.cell = torch.tensor(atoms.cell.array, dtype=torch.float32).unsqueeze(0)
        # pbc / natoms are presumably consumed by radius_graph_pbc; they are
        # removed again before the final batch is built.
        data.pbc = torch.tensor([True, True, True])
        data.natoms = len(atoms)
        del atoms
        gc.collect()

        batch = Batch.from_data_list([data])
        edge_index, _, _, edge_attr = radius_graph_pbc(batch, 5.0, 64)
        del batch
        gc.collect()

        # Assumes edge_attr holds one Cartesian edge vector per edge (TODO
        # confirm): store its length and unit direction as edge features.
        data.cart_dist = torch.norm(edge_attr, dim=-1)
        data.cart_dir = torch.nn.functional.normalize(edge_attr, dim=-1)
        data.edge_index = edge_index
        data.non_H_mask = data.x != 1  # atomic number 1 is hydrogen
        delattr(data, "pbc")
        delattr(data, "natoms")
        batch = Batch.from_data_list([data])
        del data, edge_index, edge_attr
        gc.collect()

        # Create the model only after the input has been validated, so bad
        # files fail fast without paying for model construction.
        model = create_model()
        process_data(batch, model, output_file)
    except Exception as e:
        # Entry-point boundary: report instead of raising (as before), but
        # let callers detect the failure through the return value.
        print(f"An error occurred while processing the CIF file: {e}")
        return False
    finally:
        gc.collect()
    return True
def main(argv=None):
    """Command-line entry point: run ``process_cif`` on the given paths.

    Args:
        argv: Argument list to parse. ``None`` means ``sys.argv[1:]``.

    Exits with status 1 when ``process_cif`` reports failure by returning
    ``False``. Any other return value, including ``None``, is treated as
    success, so this works whether or not ``process_cif`` returns a status.
    """
    import sys

    parser = argparse.ArgumentParser(description="Process a CIF file and output the result.")
    parser.add_argument("input_file", type=str, help="Path to the input CIF file.")
    parser.add_argument("output_file", type=str, help="Path to the output CIF file.")
    args = parser.parse_args(argv)
    if process_cif(args.input_file, args.output_file) is False:
        # process_cif prints the error itself; only the exit status is set here.
        sys.exit(1)


if __name__ == "__main__":
    main()