maomao88 commited on
Commit
24b8880
·
1 Parent(s): b0223ae

Add support for VLM (vision-language / image-text-to-text models)

Browse files
.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ backend/__pycache__/
3
+ *.pyc
4
+
5
+ # PyCharm settings (can be sensitive or specific to your machine)
6
+ .idea/
7
+
8
+ backend/.idea/
.idea/.gitignore DELETED
@@ -1,3 +0,0 @@
1
- # Default ignored files
2
- /shelf/
3
- /workspace.xml
 
 
 
 
.idea/inspectionProfiles/profiles_settings.xml DELETED
@@ -1,6 +0,0 @@
1
- <component name="InspectionProjectProfileManager">
2
- <settings>
3
- <option name="USE_PROJECT_PROFILE" value="false" />
4
- <version value="1.0" />
5
- </settings>
6
- </component>
 
 
 
 
 
 
 
.idea/misc.xml DELETED
@@ -1,7 +0,0 @@
1
- <?xml version="1.0" encoding="UTF-8"?>
2
- <project version="4">
3
- <component name="Black">
4
- <option name="sdkName" value="Python 3.13" />
5
- </component>
6
- <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.13 (model_structure_viewer)" project-jdk-type="Python SDK" />
7
- </project>
 
 
 
 
 
 
 
 
.idea/model_structure_viewer.iml DELETED
@@ -1,10 +0,0 @@
1
- <?xml version="1.0" encoding="UTF-8"?>
2
- <module type="PYTHON_MODULE" version="4">
3
- <component name="NewModuleRootManager">
4
- <content url="file://$MODULE_DIR$">
5
- <excludeFolder url="file://$MODULE_DIR$/.venv" />
6
- </content>
7
- <orderEntry type="jdk" jdkName="Python 3.13 (model_structure_viewer)" jdkType="Python SDK" />
8
- <orderEntry type="sourceFolder" forTests="false" />
9
- </component>
10
- </module>
 
 
 
 
 
 
 
 
 
 
 
.idea/modules.xml DELETED
@@ -1,8 +0,0 @@
1
- <?xml version="1.0" encoding="UTF-8"?>
2
- <project version="4">
3
- <component name="ProjectModuleManager">
4
- <modules>
5
- <module fileurl="file://$PROJECT_DIR$/.idea/model_structure_viewer.iml" filepath="$PROJECT_DIR$/.idea/model_structure_viewer.iml" />
6
- </modules>
7
- </component>
8
- </project>
 
 
 
 
 
 
 
 
 
.idea/vcs.xml DELETED
@@ -1,6 +0,0 @@
1
- <?xml version="1.0" encoding="UTF-8"?>
2
- <project version="4">
3
- <component name="VcsDirectoryMappings">
4
- <mapping directory="" vcs="Git" />
5
- </component>
6
- </project>
 
 
 
 
 
 
 
backend/__pycache__/app.cpython-313.pyc DELETED
Binary file (2.22 kB)
 
backend/__pycache__/hf_model_utils.cpython-313.pyc DELETED
Binary file (6.87 kB)
 
backend/hf_model_utils.py CHANGED
@@ -3,7 +3,18 @@ import torch.nn as nn
3
  import json
4
  import hashlib
5
  import gc
6
- from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoModelForMaskedLM, AutoModelForSequenceClassification, AutoModelForTokenClassification, AutoModelForQuestionAnswering, AutoModelForSeq2SeqLM, AutoModelForImageClassification
 
 
 
 
 
 
 
 
 
 
 
7
  from accelerate import init_empty_weights
8
 
9
 
@@ -104,10 +115,20 @@ def get_model_structure(model_name: str, model_type: str | None):
104
  config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
105
  with init_empty_weights():
106
  model = AutoModelForImageClassification.from_config(config)
 
 
 
 
107
  else:
108
  config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
109
- with torch.device("meta"):
110
- model = AutoModel.from_config(config, trust_remote_code=True)
 
 
 
 
 
 
111
 
112
  structure = {
113
  "model_type": config.model_type,
 
3
  import json
4
  import hashlib
5
  import gc
6
+ from transformers import (
7
+ AutoConfig,
8
+ AutoModel,
9
+ AutoModelForCausalLM,
10
+ AutoModelForMaskedLM,
11
+ AutoModelForSequenceClassification,
12
+ AutoModelForTokenClassification,
13
+ AutoModelForQuestionAnswering,
14
+ AutoModelForSeq2SeqLM,
15
+ AutoModelForImageClassification,
16
+ AutoModelForImageTextToText
17
+ )
18
  from accelerate import init_empty_weights
19
 
20
 
 
115
  config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
116
  with init_empty_weights():
117
  model = AutoModelForImageClassification.from_config(config)
118
+ elif model_type == "vlm":
119
+ config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
120
+ with init_empty_weights():
121
+ model = AutoModelForImageTextToText.from_config(config, trust_remote_code=True)
122
  else:
123
  config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
124
+ if hasattr(config, "vision_config"):
125
+ # It's a VLM
126
+ with init_empty_weights():
127
+ model = AutoModelForImageTextToText.from_config(config, trust_remote_code=True)
128
+ else:
129
+ # It's a standard model
130
+ with init_empty_weights():
131
+ model = AutoModel.from_config(config, trust_remote_code=True)
132
 
133
  structure = {
134
  "model_type": config.model_type,
backend/requirements.txt CHANGED
@@ -31,7 +31,7 @@ sympy==1.14.0
31
  tokenizers==0.22.0
32
  torch==2.8.0
33
  tqdm==4.67.1
34
- transformers==4.56.1
35
  typing-inspection==0.4.1
36
  typing_extensions==4.15.0
37
  urllib3==2.5.0
@@ -39,4 +39,4 @@ uvicorn==0.35.0
39
  uvloop==0.21.0
40
  watchfiles==1.1.0
41
  websockets==15.0.1
42
- accelerate==1.10.1
 
31
  tokenizers==0.22.0
32
  torch==2.8.0
33
  tqdm==4.67.1
34
+ transformers>=4.57.0
35
  typing-inspection==0.4.1
36
  typing_extensions==4.15.0
37
  urllib3==2.5.0
 
39
  uvloop==0.21.0
40
  watchfiles==1.1.0
41
  websockets==15.0.1
42
+ accelerate==1.12.0
frontend/src/components/ModelInputBar.jsx CHANGED
@@ -11,6 +11,7 @@ export default function ModelInputBar({ loading, fetchModelStructure }) {
11
  { label: "Question Answering Models (e.g. BERT QA, RoBERTa QA)", value: "qa", default: "distilbert-base-uncased-distilled-squad" },
12
  { label: "Seq2Seq (encoder-decoder, e.g. T5, BART, MarianMT)", value: "s2s", default: "t5-base" },
13
  { label: "Vision models (image classification, CLIP vision tower, etc.)", value: "vision", default: "google/vit-base-patch16-224" },
 
14
  ];
15
 
16
  const [modelName, setModelName] = useState("");
 
11
  { label: "Question Answering Models (e.g. BERT QA, RoBERTa QA)", value: "qa", default: "distilbert-base-uncased-distilled-squad" },
12
  { label: "Seq2Seq (encoder-decoder, e.g. T5, BART, MarianMT)", value: "s2s", default: "t5-base" },
13
  { label: "Vision models (image classification, CLIP vision tower, etc.)", value: "vision", default: "google/vit-base-patch16-224" },
14
+ { label: "Vision-Language models (image-text models, Qwen-VL, etc.)", value: "vlm", default: "Qwen/Qwen3-VL-8B-Instruct" },
15
  ];
16
 
17
  const [modelName, setModelName] = useState("");
package-lock.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "name": "model_structure_viewer",
3
+ "lockfileVersion": 3,
4
+ "requires": true,
5
+ "packages": {}
6
+ }