-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathImageClassifier.py
More file actions
191 lines (154 loc) · 7.54 KB
/
Copy pathImageClassifier.py
File metadata and controls
191 lines (154 loc) · 7.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
import os
import json
from transformers import AutoImageProcessor, ConvNextV2ForImageClassification
import torch
import numpy as np
from PIL import Image
from MeshRenderer import MeshRenderer
from LabelProcessor import LabelProcessor
class ImageClassifier:
    """Multi-view mesh classifier built on a pretrained ConvNeXtV2 image model.

    Each mesh is rendered from several viewpoints by ``MeshRenderer``; every
    rendered view is classified independently and the per-view class
    probabilities are averaged to produce a single label for the mesh.
    """

    def __init__(self, model_name=None, device="cuda", num_views=12, resolution=1024):
        """Create an image-based classifier using a pretrained ConvNeXtV2 model.

        Args:
            model_name (str, optional): Path or name of the pretrained model.
                Defaults to ``../models/convnextv2-large-22k-384`` relative to
                the directory containing this file.
            device (str): Device to load the model on (e.g. "cuda" or "cpu").
            num_views (int): Number of rendered views to use per mesh.
            resolution (int): Rendering resolution for each view (pixels).
        """
        self.device = device
        model_name = model_name or os.path.normpath(
            os.path.join(os.path.dirname(__file__), "..", "models", "convnextv2-large-22k-384")
        )
        # The "" key in device_map maps the whole model onto a single device.
        self.model = ConvNextV2ForImageClassification.from_pretrained(model_name, device_map={"": device})
        # Inference only: switch off training-time behaviour such as dropout.
        self.model.eval()
        self.processor = AutoImageProcessor.from_pretrained(model_name, use_fast=True)
        self.mesh_renderer = MeshRenderer(num_views=num_views, resolution=resolution)

    def set_num_views(self, num_views):
        """Set number of rendering views and recreate the renderer.

        The current renderer's ``resolution`` attribute is carried over.

        Args:
            num_views (int): New number of views to render per mesh.
        """
        self.mesh_renderer = MeshRenderer(num_views=num_views, resolution=self.mesh_renderer.resolution)

    def set_resolution(self, resolution):
        """Set rendering resolution and recreate the renderer.

        The current renderer's ``num_views`` attribute is carried over.

        Args:
            resolution (int): New square resolution (pixels) for rendered views.
        """
        self.mesh_renderer = MeshRenderer(num_views=self.mesh_renderer.num_views, resolution=resolution)

    def classify_view(self, image):
        """Classify a single PIL image using the loaded model.

        Args:
            image (PIL.Image.Image): RGB image to classify.

        Returns:
            torch.Tensor: Softmax probabilities over the model's label
            vocabulary, one row per input image (a single row here), on
            ``self.device``.
        """
        inputs = self.processor(images=image, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)
        logits = outputs.logits
        probabilities = torch.softmax(logits, dim=1)
        return probabilities

    def classify_one(self, file_path):
        """Render a mesh and perform multi-view classification.

        Every rendered view is classified and the per-view probabilities are
        averaged; the arg-max of the average is the predicted class.

        Args:
            file_path (str): Path to a mesh file to render and classify.

        Raises:
            FileNotFoundError: If the mesh file does not exist.
            RuntimeError: If the renderer produced no views for the mesh.

        Returns:
            str: Predicted label (lowercased) for the mesh. If the class id has
            no entry in ``id2label``, the id itself is returned as a string.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File {file_path} does not exist.")
        image_paths = self.mesh_renderer.render_views(file_path)
        all_probabilities = []
        for image_path in image_paths:
            with Image.open(image_path) as image:
                probabilities = self.classify_view(image.convert("RGB"))
            # Keep only a CPU copy so device memory does not grow with the view count.
            all_probabilities.append(probabilities.cpu())
            del probabilities
        del image_paths
        # Release cached allocator blocks left over from the per-view forward passes.
        torch.cuda.empty_cache()
        if not all_probabilities:
            # torch.cat([]) would otherwise fail with an opaque message.
            raise RuntimeError(f"No views were rendered for {file_path}; cannot classify.")
        # Average probabilities across views
        avg_probabilities = torch.mean(torch.cat(all_probabilities, dim=0), dim=0)
        label_id = torch.argmax(avg_probabilities).item()
        label = self.model.config.id2label.get(label_id, str(label_id)).lower()
        print("\nMulti-view classification:")
        print(label)
        return label

    @staticmethod
    def _collect_files(folder_path, limit=None, extensions=None):
        """Return the sorted paths of all files under `folder_path`.

        Args:
            folder_path (str): Directory to walk recursively.
            limit (int, optional): Keep only the first `limit` paths after sorting.
            extensions (str | Iterable[str], optional): Case-insensitive filename
                suffixes to keep (e.g. ``(".obj", ".ply")``). ``None`` or empty
                keeps every file.

        Returns:
            list[str]: Full paths, sorted lexicographically, truncated to `limit`.
        """
        if isinstance(extensions, str):
            extensions = (extensions,)
        suffixes = tuple(ext.lower() for ext in extensions) if extensions else None
        files = []
        for root, _, filenames in os.walk(folder_path):
            for name in filenames:
                if suffixes is None or name.lower().endswith(suffixes):
                    files.append(os.path.join(root, name))
        files.sort()
        if limit is not None:
            files = files[:limit]
        return files

    @staticmethod
    def _select_representative(embeddings, label_processor):
        """Return the index of the embedding with the highest mean similarity to the others.

        Args:
            embeddings (Sequence): Per-label embedding vectors (must be non-empty).
            label_processor: Object exposing ``compute_similarity(a, b)``.

        Returns:
            int: Index of the most "central" embedding. A single embedding
            scores 0.0 and is selected.
        """
        avg_similarities = []
        for i, emb1 in enumerate(embeddings):
            similarities = [
                label_processor.compute_similarity(emb1, emb2)
                for j, emb2 in enumerate(embeddings)
                if i != j
            ]
            avg_similarities.append(np.mean(similarities) if similarities else 0.0)
        return int(np.argmax(avg_similarities))

    def classify_batch(self, folder_path, limit=None, save_path=None, extensions=None, label_processor=None):
        """Classify all mesh files in a folder (optionally limited) and collect results.

        A raw model label may hold several comma-separated synonyms; the synonym
        whose embedding is most similar on average to the others is reported as
        ``predicted_label``.

        Args:
            folder_path (str): Directory containing mesh files (walked recursively).
            limit (int, optional): Maximum number of files to process. Defaults to None.
            save_path (str, optional): Filename to save JSON results under `classifications/`.
                Any intermediate directories are created.
            extensions (str | Iterable[str], optional): Only process files with these
                (case-insensitive) suffixes. Defaults to None (all files).
            label_processor (optional): A pre-built ``LabelProcessor`` to reuse. A new
                one is constructed when None.

        Raises:
            ValueError: If the provided `folder_path` does not exist, or if a label
                yields no parts or a different number of embeddings than parts.

        Returns:
            str: Status message (and save path when applicable).
        """
        if not os.path.isdir(folder_path):
            raise ValueError(f"Folder does not exist: {folder_path}")
        files = self._collect_files(folder_path, limit=limit, extensions=extensions)
        classifications = []
        if label_processor is None:
            label_processor = LabelProcessor()
        for mesh_path in files:
            print(f"\nClassifying {mesh_path}...")
            raw_predicted_label = self.classify_one(mesh_path)
            labels = [part.strip() for part in raw_predicted_label.split(",") if part.strip()]
            embeddings = np.asarray(label_processor.compute_embeddings(raw_predicted_label), dtype=np.float32)
            # The chosen index is used for both `labels` and `embeddings`, so they must align.
            if not labels or len(embeddings) != len(labels):
                raise ValueError(
                    f"Label/embedding mismatch for {mesh_path}: {len(labels)} label part(s) "
                    f"but {len(embeddings)} embedding(s) for {raw_predicted_label!r}."
                )
            closest_idx = self._select_representative(embeddings, label_processor)
            embedding = embeddings[closest_idx]
            predicted_label = labels[closest_idx]
            print(f"Predicted label: {predicted_label}")
            classifications.append({
                'file_path': mesh_path,
                'predicted_label': predicted_label,
                'labels': labels,
                'embeddings': embeddings.tolist(),
                'embedding': embedding.tolist(),
            })
        if save_path:
            save_path = os.path.join("classifications", save_path)
            # Create the target directory (normally `classifications/`, plus any subfolders).
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            with open(save_path, 'w', encoding='utf-8') as f:
                json.dump(classifications, f)
        return 'Classification completed. Results saved to ' + save_path if save_path else 'Classification completed.'

    def classify(self, file_path=None, folder_path=None, limit=None, save_path=None,
                 extensions=None, label_processor=None):
        """Dispatch helper to classify a single file or a folder.

        Args:
            file_path (str, optional): Single mesh file to classify (takes precedence).
            folder_path (str, optional): Directory of mesh files to classify.
            limit (int, optional): Limit for batch processing.
            save_path (str, optional): Path to save batch results.
            extensions (str | Iterable[str], optional): Suffix filter for batch processing.
            label_processor (optional): Pre-built ``LabelProcessor`` for batch processing.

        Raises:
            ValueError: If neither `file_path` nor `folder_path` is provided.

        Returns:
            Any: Return value from `classify_one` or `classify_batch`.
        """
        if file_path is not None:
            return self.classify_one(file_path)
        if folder_path is not None:
            return self.classify_batch(
                folder_path,
                limit=limit,
                save_path=save_path,
                extensions=extensions,
                label_processor=label_processor,
            )
        raise ValueError("Either file_path or folder_path must be provided.")