| import os | |
| import json | |
| import shutil | |
| from optimum.exporters.onnx import main_export | |
| import onnx | |
| from onnxconverter_common import float16 | |
| import onnxruntime as rt | |
| from onnxruntime.tools.onnx_model_utils import * | |
| from onnxruntime.quantization import quantize_dynamic, QuantType | |
| from huggingface_hub import hf_hub_download | |
| import transformers | |
# Load the conversion settings that drive the export steps below.
# Explicit encoding avoids platform-dependent default-codec surprises.
with open('conversion_config.json', encoding="utf-8") as json_file:
    conversion_config = json.load(json_file)

model_id = conversion_config["model_id"]                                            # HF Hub repo to pull files from
number_of_generated_embeddings = conversion_config["number_of_generated_embeddings"]
precision_to_filename_map = conversion_config["precision_to_filename_map"]          # e.g. {"fp32": "model.onnx", ...}
opset = conversion_config["opset"]
IR = conversion_config["IR"]

# Opset descriptor reused when re-serializing every downloaded model, so all
# exported files end up with the same (compatible) opset and IR version.
op = onnx.OperatorSetIdProto()
op.version = opset
# Export the tokenizer files (vocab, merges, tokenizer config) next to the models.
print("Exporting tokenizer...")
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
tokenizer.save_pretrained("./")
print("Done\n\n")

# exist_ok=True replaces the racy os.path.exists() + makedirs() check-then-create.
os.makedirs("onnx", exist_ok=True)
def _download_and_fix(precision):
    """Download the ONNX file mapped to *precision* from the Hub, then
    re-serialize it with the configured IR version and opset so every
    exported file shares a compatible opset/IR pair."""
    filename = precision_to_filename_map[precision]
    hf_hub_download(repo_id=model_id, filename=filename, local_dir="./")
    model = onnx.load(filename)
    # Rebuild the model with explicit ir_version/opset_imports to be sure
    # that we have a compatible opset and IR version.
    model_fixed = onnx.helper.make_model(model.graph, ir_version=IR, opset_imports=[op])
    onnx.save(model_fixed, filename)


# The three precision stanzas were identical except for the key; loop instead.
for precision in ("fp32", "int8", "uint8"):
    if precision in precision_to_filename_map:
        print(f"Exporting the {precision} onnx file...")
        _download_and_fix(precision)
        print("Done\n\n")