The script loads the pretrained DeepLabV3 + MobileNetV3 model from PyTorch Hub.
It then saves it as a TorchScript .pt file,
which can be opened in Netron.
It then converts it to an .mlpackage,
which cannot be opened in Netron: it gives the error "File has no content", even though the package is about 22 MB.
Keep in mind that the model input should be an `ImageType` rather than a `MultiArray`.
I've tried this with many different versions of torch and torchvision. The latest version of torch that Core ML Tools (coremltools) supports is 2.4, so a correspondingly older torchvision (0.19) is required.
pip install "torch==2.4" "torchvision==0.19"
pip install coremltools
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import coremltools as ct
# Load the pretrained DeepLabV3 model with a MobileNetV3-Large backbone.
# `pretrained=True` has been deprecated since torchvision 0.13 in favour of the
# `weights` argument; "DEFAULT" resolves to the same pretrained checkpoint
# (COCO_WITH_VOC_LABELS_V1) without the deprecation warning.
model = torch.hub.load(
    "pytorch/vision:v0.19.0",
    "deeplabv3_mobilenet_v3_large",
    weights="DEFAULT",
)
# Switch layers such as BatchNorm/Dropout to inference behaviour so the traced
# graph is deterministic.
model.eval()
# wrapper class
class Deeplabv3Wrapper(nn.Module):
    """Expose only the ``"out"`` tensor of a torchvision DeepLabV3 model.

    torchvision segmentation models return a dict of tensors. A single tensor
    output is simpler to trace with ``torch.jit.trace`` and to name during
    Core ML conversion, so this wrapper returns ``outputs["out"]`` only.

    Args:
        model: The segmentation model to wrap; its forward must return a
            mapping containing an ``"out"`` key.
    """

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        self.model = model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the wrapped model on ``x`` and return its ``"out"`` tensor."""
        outputs = self.model(x)
        # Return the 'out' tensor for the segmentation mask; any other entries
        # (e.g. an auxiliary head) are dropped.
        return outputs["out"]
# ImageNet normalisation constants (values from the torchvision docs). They are
# defined once so the PyTorch preprocessing and the Core ML input preprocessing
# below cannot drift apart.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Define preprocessing transforms (used here only to build the tracing input).
preprocess = transforms.Compose([
    transforms.ToTensor(),  # PIL image -> float tensor scaled to [0, 1]
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
dummy_image = Image.new("RGB", (513, 513), color="white")
dummy_tensor = preprocess(dummy_image).unsqueeze(0)  # Add batch dimension

# Wrap and trace the model using TorchScript. no_grad avoids recording autograd
# state, which tracing for inference does not need.
wrapped_model = Deeplabv3Wrapper(model)
with torch.no_grad():
    traced_model = torch.jit.trace(wrapped_model, dummy_tensor)
print("Model tracing completed")

# Save the traced model (a single file, so netron.app can open it directly).
torch.jit.save(traced_model, "traced_deeplabv3.pt")
print("Model successfully saved as pt")

traced_model.eval()
example_input = torch.randn(1, 3, 513, 513)

# Core ML's ImageType preprocessing computes y = scale * x + bias per channel on
# raw 0-255 pixels. torchvision's Normalize computes (x / 255 - mean) / std,
# i.e. x / (255 * std) - mean / std. So the bias must be -mean / std; a bias of
# -mean alone omits the std division and feeds the model mis-scaled inputs.
# ImageType accepts only a scalar `scale`, so the per-channel std is
# approximated by its average (~0.226), as in the coremltools image-input guide.
mean_std = sum(IMAGENET_STD) / len(IMAGENET_STD)
model_from_trace = ct.convert(
    traced_model,
    inputs=[
        ct.ImageType(
            name="input_image",
            shape=example_input.shape,
            scale=1 / (255.0 * mean_std),
            bias=[-m / s for m, s in zip(IMAGENET_MEAN, IMAGENET_STD)],
        )
    ],
    outputs=[ct.TensorType(name="output")],
    # Explicit so the result is an ML Program (saved as .mlpackage) regardless
    # of coremltools version; releases before 7.0 defaulted to "neuralnetwork".
    convert_to="mlprogram",
)
print("Model conversion completed")

# NOTE(review): an .mlpackage is a directory bundle, not a single file. Uploading
# it to netron.app likely hands the browser an empty entry, hence "File has no
# content". Try opening
# DeepLabV3_MobileNet.mlpackage/Data/com.apple.CoreML/model.mlmodel instead,
# or use the Netron desktop app.
model_from_trace.save("DeepLabV3_MobileNet.mlpackage")
print("Model successfully saved as mlpackage")
发布者:admin,转转请注明出处:http://www.yc00.com/questions/1745671799a4639448.html
评论列表(0条)