Commit e9ec7fcf authored by Brewer, Wes's avatar Brewer, Wes
Browse files

Save model as jit trace

parent b6e8a224
Loading
Loading
Loading
Loading
+7 −7
Original line number Diff line number Diff line
@@ -47,15 +47,15 @@ model_module = importlib.import_module('archs.' + args.arch)
torch_model = model_module.build_model(input_shape)
print(torch_model)

model_path = f'{args.arch}_model.pth'
torch_model.load_state_dict(torch.load(model_path))
torch_model.eval()  # Set the model to evaluation mode
# Load the TorchScript model
model_path = f'{args.arch}_model.jit'
torch_model = torch.jit.load(model_path)
torch_model.eval()  # Ensure the model is in evaluation mode

# Convert the model to TorchScript
example_forward_input = torch.rand(models[args.arch]['shape']) 
module = torch.jit.trace(torch_model, example_forward_input)
# Serialize the loaded TorchScript model into a byte buffer
model_buffer = io.BytesIO()
torch.jit.save(module, model_buffer)
torch.jit.save(torch_model, model_buffer)
model_buffer.seek(0)  # Reset buffer position to the beginning

# Get the database address and create a SmartRedis client
client = Client(address="localhost:6780", cluster=False)
+5 −3
Original line number Diff line number Diff line
@@ -62,6 +62,8 @@ for epoch in range(epochs):
        optimizer.step()
    print(f'Epoch {epoch+1}, Loss: {loss.item()}')

# Save model - PyTorch way
torch.save(model.state_dict(), f"{args.arch}_model.pth")
# Save model
model.eval() 
example_input = torch.rand(1, *input_shape)
scripted_model = torch.jit.trace(model, example_input)
scripted_model.save(f"{args.arch}_model.jit")