diff --git a/Dockerfile b/Dockerfile
index 8367f80..cf77d8a 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -4,13 +4,21 @@ FROM nvcr.io/nvidia/tensorrt:23.08-py3
 # Install system packages
 RUN apt-get update && apt-get install -y \
     python3-pip \
-    git
+    git \
+    libjpeg-dev \
+    libpng-dev
+
+# Copy the requirements.txt file into the container
+COPY requirements.txt /workspace/requirements.txt
+
+# Install Python packages
+RUN pip3 install --no-cache-dir -r /workspace/requirements.txt
+
+# Install torch-tensorrt from the NVIDIA Torch-TensorRT release index
+RUN pip3 install torch-tensorrt -f https://github.com/NVIDIA/Torch-TensorRT/releases
 
 # Set the working directory
 WORKDIR /workspace
 
 # Copy local project files to /workspace in the image
-COPY . /workspace
-
-# Install Python packages
-RUN pip3 install --no-cache-dir -r /workspace/requirements.txt
\ No newline at end of file
+COPY . /workspace
\ No newline at end of file
diff --git a/README.md b/README.md
index 9fbd3fd..b5904e9 100644
--- a/README.md
+++ b/README.md
@@ -1,18 +1,31 @@
-# ResNet-50 Inference with ONNX/TensorRT
+
+
+
 ## Table of Contents
 1. [Overview](#overview)
 2. [Requirements](#requirements)
-3. [Steps to Run](#steps-to-run)
-4. [Example Command](#example-command)
-5. [Inference Benchmark Results](#inference-benchmark-results)
-   - [Example of Results](#example-of-results)
-   - [Explanation of Results](#explanation-of-results)
-6. [ONNX Exporter](#onnx-exporter) ![New](https://img.shields.io/badge/-New-red)
-7. [Author](#author)
-8. [References](#references)
+   - [Steps to Run](#steps-to-run)
+   - [Example Command](#example-command)
+3. [RESULTS](#results) ![Static Badge](https://img.shields.io/badge/update-yellow)
+   - [Results explanation](#results-explanation)
+   - [Example Input](#example-input)
+4. [Benchmark Implementation Details](#benchmark-implementation-details) ![New](https://img.shields.io/badge/-New-red)
+   - [PyTorch CPU & CUDA](#pytorch-cpu--cuda)
+   - [TensorRT FP32 & FP16](#tensorrt-fp32--fp16)
+   - [ONNX](#onnx)
+   - [OpenVINO](#openvino)
+5. [Used methodologies](#used-methodologies) ![New](https://img.shields.io/badge/-New-red)
+   - [TensorRT Optimization](#tensorrt-optimization)
+   - [ONNX Exporter](#onnx-exporter)
+   - [OV Exporter](#ov-exporter)
+6. [Author](#author)
+7. [References](#references)
+
+
+
 ## Overview
-This project demonstrates how to perform inference with a PyTorch model and optimize it using ONNX or NVIDIA TensorRT. The script loads a pre-trained ResNet-50 model from torchvision, performs inference on a user-provided image, and prints the top-K predicted classes. Additionally, the script benchmarks the model's performance in the following configurations: CPU, CPU (ONNX), CUDA, TensorRT-FP32, and TensorRT-FP16, providing insights into the speedup gained through optimization.
+This project demonstrates how to perform inference with a PyTorch model and optimize it using ONNX, OpenVINO, or NVIDIA TensorRT. The script loads a pre-trained ResNet-50 model from torchvision, performs inference on a user-provided image, and prints the top-K predicted classes. Additionally, the script benchmarks the model's performance in the following configurations: CPU, CUDA, ONNX, OpenVINO, TensorRT-FP32, and TensorRT-FP16, providing insights into the speedup gained through optimization.
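For orientation, here is a minimal sketch of the core prediction step the script performs (pre-trained ResNet-50 from torchvision, top-K classes). It is an editorial illustration, not the project's code, and assumes a torchvision version that ships `ResNet50_Weights`; the sample image path is the repository default.

```python
import torch
from PIL import Image
from torchvision.models import resnet50, ResNet50_Weights

# Load the pre-trained ResNet-50 and its matching ImageNet preprocessing.
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights).eval()
preprocess = weights.transforms()

# Default sample image from this repo; any RGB image works.
img_batch = preprocess(Image.open("./inference/cat3.jpg")).unsqueeze(0)

with torch.no_grad():
    probabilities = torch.softmax(model(img_batch), dim=1)[0]

# Print the top-5 predicted classes with their confidence.
topk = torch.topk(probabilities, k=5)
for prob, idx in zip(topk.values, topk.indices):
    print(f"{int(prob * 100)}% {weights.meta['categories'][idx]}")
```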
 ## Requirements
 - This repo cloned
@@ -21,7 +34,7 @@ This project demonstrates how to perform inference with a PyTorch model and opti
 - Python 3.x
 - [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html#install-guide) (for running the Docker container with GPU support)
 
-## Steps to Run
+### Steps to Run
 
 ```sh
 # 1. Build the Docker Image
@@ -37,46 +50,132 @@ python src/main.py
 ### Arguments
 - `--image_path`: (Optional) Specifies the path to the image you want to predict.
 - `--topk`: (Optional) Specifies the number of top predictions to show. Defaults to 5 if not provided.
-- `--onnx`: (Optional) Specifies if we want export ResNet50 model to ONNX and run benchmark only for this model
+- `--mode`: Specifies the mode for exporting and running the model. Choices are: `onnx`, `ov`, `cuda`, or `all`.
 
-## Example Command
+### Example Command
 
 ```sh
-python src/main.py --image_path ./inference/cat3.jpg --topk 3 --onnx
+python src/main.py --topk 3 --mode=all
 ```
 
-This command will run predictions on the image at the specified path and show the top 3 predictions using both PyTorch and ONNX Runtime models. For the default 5 top predictions, omit the --topk argument or set it to 5.
+This command will run predictions on the default image (`./inference/cat3.jpg`), show the top 3 predictions, and run all models (PyTorch CPU, CUDA, ONNX, OpenVINO, TRT-FP32, TRT-FP16). At the end, the results plot will be saved to `./inference/plot.png`.
 
-## Inference Benchmark Results
+## RESULTS
+### Inference Benchmark Results
+
-The results of the predictions and benchmarks are saved to `model.log`. This log file contains information about the predicted class for the input image and the average batch time for the different configurations during the benchmark.
+### Results explanation
+ - `PyTorch_cpu: 973.52 ms` indicates the average batch time when running the `PyTorch` model on the `CPU` device.
+ - `PyTorch_cuda: 41.11 ms` indicates the average batch time when running the `PyTorch` model on the `CUDA` device.
+ - `TRT_fp32: 19.10 ms` shows the average batch time when running the model with `TensorRT` using `float32` precision.
+ - `TRT_fp16: 7.22 ms` indicates the average batch time when running the model with `TensorRT` using `float16` precision.
+ - `ONNX: 15.38 ms` indicates the average batch inference time when running the `PyTorch` model converted to `ONNX` on the `CPU` device.
+ - `OpenVINO: 14.04 ms` indicates the average batch inference time when running the `ONNX` model converted to `OpenVINO` on the `CPU` device.
 
-### Example of Results
-Here is an example of the contents of `model.log` after running predictions and benchmarks on this image:
+### Example Input
+Here is an example of the input image to run predictions and benchmarks on:
+
+## Benchmark Implementation Details
+Here you can see the flow for each model and benchmark.
+
+### PyTorch CPU & CUDA
+In the provided code, we perform inference using the native PyTorch framework on both CPU and GPU (CUDA) configurations. This serves as a baseline to compare the performance improvements gained from other optimization techniques.
+
+#### Flow:
+1. The ResNet-50 model is loaded from torchvision and, if available, transferred to the GPU.
+2. Inference is performed on the provided image using the specified model.
+3. Benchmark results, including average inference time, are logged for both the CPU and CUDA setups.
+
+### TensorRT FP32 & FP16
+TensorRT offers significant performance improvements by optimizing the neural network model.
+In this code, we utilize TensorRT's capabilities to run benchmarks in both FP32 (single precision) and FP16 (half precision) modes.
+
+#### Flow:
+1. Load the ResNet-50 model.
+2. Convert the PyTorch model to TensorRT format with the specified precision.
+3. Perform inference on the provided image.
+4. Log the benchmark results for the specified TensorRT precision mode.
+
+### ONNX
+The code includes an exporter that converts the PyTorch ResNet-50 model to ONNX format, allowing it to be inferred using ONNX Runtime. This provides a flexible, cross-platform solution for deploying the model.
+
+#### Flow:
+1. The ResNet-50 model is loaded.
+2. Using the ONNX exporter utility, the PyTorch model is converted to ONNX format.
+3. An ONNX Runtime session is created.
+4. Inference is performed on the provided image using the ONNX model.
+5. Benchmark results are logged for the ONNX model.
+
+### OpenVINO
+OpenVINO is a toolkit from Intel that optimizes deep learning model inference for Intel CPUs, GPUs, and other hardware. In the code, we convert the ONNX model to OpenVINO's format and then run benchmarks using the OpenVINO runtime.
+
+#### Flow:
+1. The ONNX model (created in the previous step) is loaded.
+2. Convert the ONNX model to OpenVINO's IR format.
+3. Create an inference engine using OpenVINO's runtime.
+4. Perform inference on the provided image using the OpenVINO model.
+5. Benchmark results, including average inference time, are logged for the OpenVINO model.
+
+## Used methodologies
+### TensorRT Optimization
+TensorRT is a high-performance deep learning inference optimizer and runtime library developed by NVIDIA. It is designed for optimizing and deploying trained neural network models in production environments. This project supports TensorRT optimizations in both FP32 (single precision) and FP16 (half precision) modes, offering different trade-offs between inference speed and model accuracy.
+
+#### Features
+- **Performance Boost**: TensorRT can significantly accelerate the inference of neural network models, making it suitable for deployment in resource-constrained environments.
+- **Precision Modes**: Supports FP32 for maximum accuracy and FP16 for faster performance with a minor trade-off in accuracy.
+- **Layer Fusion**: TensorRT fuses layers and tensors in the neural network to reduce memory access overhead and improve execution speed.
+- **Dynamic Tensor Memory**: Efficiently handles varying batch sizes without re-optimization.
+
+#### Usage
+To employ TensorRT optimizations in the project, use the `--mode all` argument when running the main script.
+This will run all models, including the PyTorch model compiled to TensorRT with `FP16` and `FP32` precision modes, and then, as one of the steps, run inference on the specified image using the TensorRT-optimized models.
+Example:
+```sh
+python src/main.py --mode all
 ```
-My prediction: %33 tabby
-My prediction: %26 Egyptian cat
-Running Benchmark for CPU
-Average batch time: 942.47 ms
-Average ONNX inference time: 15.59 ms
-Running Benchmark for CUDA
-Average batch time: 41.02 ms
-Compiling and Running Inference Benchmark for TensorRT with precision: torch.float32
-Average batch time: 19.20 ms
-Compiling and Running Inference Benchmark for TensorRT with precision: torch.float16
-Average batch time: 7.25 ms
+#### Requirements
+Ensure you have the TensorRT library and the torch_tensorrt package installed in your environment.
+Also, for FP16 optimizations, it's recommended to have a GPU that supports half-precision arithmetic (like NVIDIA GPUs with Tensor Cores).
+
+### ONNX Exporter
+The ONNX Model Exporter (`ONNXExporter`) utility is incorporated in this project to enable the conversion of the native PyTorch model into the ONNX format.
+Using the ONNX format, inference and benchmarking can be performed with the ONNX Runtime, which offers platform-agnostic optimizations and is widely supported across numerous platforms and devices.
+
+#### Features
+- **Standardized Format**: ONNX provides an open-source format for AI models. It defines an extensible computation graph model, as well as definitions of built-in operators and standard data types.
+- **Interoperability**: Models in ONNX format can be used across a variety of frameworks, tools, runtimes, and compilers.
+- **Optimizations**: The ONNX Runtime provides performance optimizations for both cloud and edge devices.
+
+#### Usage
+To leverage the `ONNXExporter` and conduct inference using the ONNX Runtime, utilize the `--mode onnx` argument when executing the main script.
+This will initiate the conversion process and then run inference on the specified image using the ONNX model.
+Example:
+```sh
+python src/main.py --mode onnx
 ```
-### Explanation of Results
-- First k lines show the topk predictions. For example, `My prediction: %33 tabby` displays the highest confidence prediction made by the model for the input image, confidence level (`%33`), and the predicted class (`tabby`).
-- The following lines provide information about the average batch time for running the model in different configurations:
- - `Running Benchmark for CPU` and `Average batch time: 942.47 ms` indicate the average batch time when running the model on the CPU.
- - `Average ONNX inference time: 15.59 ms` indicate the average batch time when running the ONNX model on the CPU.
- - `Running Benchmark for CUDA` and `Average batch time: 41.02 ms` indicate the average batch time when running the model on CUDA.
- - `Compiling and Running Inference Benchmark for TensorRT with precision: torch.float32` and `Average batch time: 19.20 ms` show the average batch time when running the model with TensorRT using `float32` precision.
- - `Compiling and Running Inference Benchmark for TensorRT with precision: torch.float16` and `Average batch time: 7.25 ms` indicate the average batch time when running the model with TensorRT using `float16` precision.
+#### Requirements
+Ensure the ONNX library is installed in your environment to use the ONNXExporter. Additionally, if you want to run inference using the ONNX model, make sure you have the ONNX Runtime installed.
+
+### OV Exporter
+The OpenVINO Model Exporter (`OVExporter`) utility has been integrated into this project to facilitate the conversion of the ONNX model to the OpenVINO format.
+This enables inference and benchmarking using OpenVINO, a framework optimized for Intel hardware, providing substantial speed improvements, especially on CPUs.
+
+#### Features
+- **Model Optimization**: Converts the ONNX model to OpenVINO's Intermediate Representation (IR) format. This optimized format allows for faster inference times on Intel hardware.
+- **Versatility**: OpenVINO can target a variety of Intel hardware devices such as CPUs, integrated GPUs, FPGAs, and VPUs.
+- **Ease of Use**: The `OVExporter` provides a seamless transition from ONNX to OpenVINO, abstracting the conversion details and providing a straightforward interface (see the sketch below).
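As a rough illustration of the conversion path that `OVExporter` wraps, the following editorial sketch reads an ONNX file with the OpenVINO runtime and compiles it for inference. It assumes the ONNX model has already been exported to the project's default path and uses OpenVINO's standard Python API rather than the project's classes.

```python
import numpy as np
import openvino as ov

core = ov.Core()
ov_model = core.read_model("./inference/model.onnx")  # ONNX graph produced by ONNXExporter
compiled = core.compile_model(ov_model, "AUTO")       # let OpenVINO pick the target device

# Run a single inference request on a dummy ResNet-50 input.
dummy = np.random.randn(1, 3, 224, 224).astype(np.float32)
result = compiled([dummy])[compiled.output(0)]
print(result.shape)  # (1, 1000) class scores
```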
+ +#### Usage +To utilize `OVExporter` and perform inference using OpenVINO, use the `--mode ov` argument when running the main script. +This will trigger the conversion process and subsequently run inference on the provided image using the optimized OpenVINO model. +Example: +```sh +python src/main.py --mode ov +``` + +#### Requirements +Ensure you have the OpenVINO Toolkit installed and the necessary dependencies set up to use OpenVINO's model optimizer and inference engine. + ## ONNX Exporter The ONNX Exporter utility is integrated into this project to allow the conversion of the PyTorch model to ONNX format, enabling inference and benchmarking using ONNX Runtime. The ONNX model can provide hardware-agnostic optimizations and is widely supported across various platforms and devices. diff --git a/inference/logo.png b/inference/logo.png new file mode 100644 index 0000000..02b10dd Binary files /dev/null and b/inference/logo.png differ diff --git a/inference/plot.png b/inference/plot.png new file mode 100644 index 0000000..43bb7bb Binary files /dev/null and b/inference/plot.png differ diff --git a/requirements.txt b/requirements.txt index 91b1f90..dee912e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,11 @@ torch torchvision -torch-tensorrt pandas Pillow numpy packaging onnx -onnxruntime \ No newline at end of file +onnxruntime +openvino==2023.1.0.dev20230811 +seaborn +matplotlib diff --git a/src/benchmark.py b/src/benchmark.py index 4708c28..cf24764 100644 --- a/src/benchmark.py +++ b/src/benchmark.py @@ -7,6 +7,7 @@ import torch.backends.cudnn as cudnn import logging import onnxruntime as ort +import openvino as ov # Configure logging logging.basicConfig(filename="model.log", level=logging.INFO) @@ -22,7 +23,7 @@ def __init__(self, nruns: int = 100, nwarmup: int = 50): self.nwarmup = nwarmup @abstractmethod - def run(self) -> None: + def run(self): """ Abstract method to run the benchmark. """ @@ -58,7 +59,7 @@ def __init__( cudnn.benchmark = True # Enable cuDNN benchmarking optimization - def run(self) -> None: + def run(self): """ Run the benchmark with the given model, input shape, and other parameters. Log the average batch time and print the input shape and output feature size. @@ -93,6 +94,7 @@ def run(self) -> None: print(f"Input shape: {input_data.size()}") print(f"Output features size: {features.size()}") logging.info(f"Average batch time: {np.mean(timings) * 1000:.2f} ms") + return np.mean(timings) * 1000 class ONNXBenchmark(Benchmark): @@ -113,7 +115,8 @@ def __init__( self.nwarmup = nwarmup self.nruns = nruns - def run(self) -> None: + + def run(self): print("Warming up ...") # Adjusting the batch size in the input shape to match the expected input size of the model. input_shape = (1,) + self.input_shape[1:] @@ -133,3 +136,64 @@ def run(self) -> None: avg_time = np.mean(timings) * 1000 logging.info(f"Average ONNX inference time: {avg_time:.2f} ms") + return avg_time + + +class OVBenchmark(Benchmark): + def __init__( + self, model: ov.frontend.FrontEnd, input_shape: Tuple[int, int, int, int] + ): + """ + Initialize the OVBenchmark with the OpenVINO model and the input shape. + + :param model: ov.frontend.FrontEnd + The OpenVINO model. + :param input_shape: Tuple[int, int, int, int] + The shape of the model input. 
+ """ + self.ov_model = model + self.core = ov.Core() + self.compiled_model = None + self.input_shape = input_shape + self.warmup_runs = 50 + self.num_runs = 100 + self.dummy_input = np.random.randn(*input_shape).astype(np.float32) + + def warmup(self): + """ + Compile the OpenVINO model for optimal execution on available hardware. + """ + self.compiled_model = self.core.compile_model(self.ov_model, "AUTO") + + def inference(self, input_data) -> dict: + """ + Perform inference on the input data using the compiled OpenVINO model. + + :param input_data: np.ndarray + The input data for the model. + :return: dict + The model's output as a dictionary. + """ + outputs = self.compiled_model(inputs={"input": input_data}) + return outputs + + def run(self): + """ + Run the benchmark on the OpenVINO model. It first warms up by compiling the model and then measures + the average inference time over a set number of runs. + """ + # Warm-up runs + logging.info("Warming up ...") + for _ in range(self.warmup_runs): + self.warmup() + + # Benchmarking + total_time = 0 + for _ in range(self.num_runs): + start_time = time.time() + _ = self.inference(self.dummy_input) + total_time += time.time() - start_time + + avg_time = total_time / self.num_runs + logging.info(f"Average inference time: {avg_time * 1000:.2f} ms") + return avg_time * 1000 \ No newline at end of file diff --git a/src/main.py b/src/main.py index 7507fdf..aed87bf 100644 --- a/src/main.py +++ b/src/main.py @@ -1,17 +1,21 @@ import argparse import os import logging -import onnx +import pandas as pd +import openvino as ov import torch import torch_tensorrt -from typing import List, Tuple, Union +from typing import List, Tuple, Union, Dict, Any import onnxruntime as ort import numpy as np +import seaborn as sns +import matplotlib.pyplot as plt from model import ModelLoader from image_processor import ImageProcessor -from benchmark import PyTorchBenchmark, ONNXBenchmark +from benchmark import PyTorchBenchmark, ONNXBenchmark, OVBenchmark from onnx_exporter import ONNXExporter +from ov_exporter import OVExporter # Configure logging logging.basicConfig(filename="model.log", level=logging.INFO) @@ -43,7 +47,7 @@ def run_benchmark( def make_prediction( - model: Union[torch.nn.Module, ort.InferenceSession], + model: Union[torch.nn.Module, ort.InferenceSession, ov.CompiledModel], img_batch: Union[torch.Tensor, np.ndarray], topk: int, categories: List[str], @@ -59,6 +63,7 @@ def make_prediction( :param precision: The data type to be used for the predictions (typically torch.float32 or torch.float16) for PyTorch models. """ is_onnx_model = isinstance(model, ort.InferenceSession) + is_ov_model = isinstance(model, ov.CompiledModel) if is_onnx_model: # Get the input name for the ONNX model. 
@@ -78,10 +83,23 @@ def make_prediction( # Apply Softmax to get probabilities prob = np.exp(prob) / np.sum(np.exp(prob)) + elif is_ov_model: + # For OV, the input name is usually the first input + input_name = next(iter(model.inputs)) + outputs = model(inputs={input_name: img_batch}) - else: # PyTorch Model - img_batch = img_batch.clone().to(precision) + # Assuming the model returns a dictionary with one key for class probabilities + prob_key = next(iter(outputs)) + prob = outputs[prob_key] + + # Apply Softmax to get probabilities + prob = np.exp(prob[0]) / np.sum(np.exp(prob[0])) + else: # PyTorch Model + if isinstance(img_batch, np.ndarray): + img_batch = torch.tensor(img_batch) + else: + img_batch = img_batch.clone().to(precision) model.eval() with torch.no_grad(): outputs = model(img_batch.to(precision)) @@ -101,6 +119,91 @@ def make_prediction( logging.info(f"#{i + 1}: {int(probability * 100)}% {class_label}") +def run_all_benchmarks( + models: Dict[str, Any], img_batch: np.ndarray +) -> Dict[str, float]: + """ + Run benchmarks for all models and return a dictionary of average inference times. + + :param models: Dictionary of models. Key is model type ("onnx", "ov", "pytorch", "trt_fp32", "trt_fp16"), value is the model. + :param img_batch: The batch of images to run the benchmark on. + :return: Dictionary of average inference times. Key is model type, value is average inference time. + """ + results = {} + + # ONNX benchmark + onnx_benchmark = ONNXBenchmark(models["onnx"], img_batch.shape) + avg_time_onnx = onnx_benchmark.run() + results["ONNX"] = avg_time_onnx + + # OpenVINO benchmark + ov_benchmark = OVBenchmark(models["ov"], img_batch.shape) + avg_time_ov = ov_benchmark.run() + results["OpenVINO"] = avg_time_ov + + # PyTorch + TRT benchmark + configs = [ + ("cpu", torch.float32, False), + ("cuda", torch.float32, False), + ("cuda", torch.float32, True), + ("cuda", torch.float16, True), + ] + for device, precision, is_trt in configs: + model_to_use = models["pytorch"].to(device) + + if not is_trt: + pytorch_benchmark = PyTorchBenchmark( + model_to_use, device=device, dtype=precision + ) + avg_time_pytorch = pytorch_benchmark.run() + results[f"PyTorch_{device}"] = avg_time_pytorch + + else: + # TensorRT benchmarks + if precision == torch.float32 or precision == torch.float16: + mode = "fp32" if precision == torch.float32 else "fp16" + trt_benchmark = PyTorchBenchmark( + models[f"trt_{mode}"], device=device, dtype=precision + ) + avg_time_trt = trt_benchmark.run() + results[f"TRT_{mode}"] = avg_time_trt + + return results + + +def plot_benchmark_results(results: Dict[str, float]): + """ + Plot the benchmark results using Seaborn. + + :param results: Dictionary of average inference times. Key is model type, value is average inference time. 
+ """ + # Convert dictionary to two lists for plotting + models = list(results.keys()) + times = list(results.values()) + + # Create a DataFrame for plotting + data = pd.DataFrame({"Model": models, "Time": times}) + + # Sort the DataFrame by Time + data = data.sort_values("Time", ascending=True) + + # Plot + plt.figure(figsize=(10, 6)) + ax = sns.barplot(x=data["Time"], y=data["Model"], palette="rocket") + + # Adding the actual values on the bars + for index, value in enumerate(data["Time"]): + ax.text(value, index, f"{value:.2f} ms", color="black", ha="left", va="center") + + plt.xlabel("Average Inference Time (ms)") + plt.ylabel("Model Type") + plt.title("ResNet50 - Inference Benchmark Results") + + # Save the plot to a file + plt.savefig("./inference/plot.png", bbox_inches="tight") + plt.show() + + def main() -> None: """ Main function to run inference, benchmarks, and predictions on the model @@ -117,18 +220,23 @@ def main() -> None: parser.add_argument( "--topk", type=int, default=5, help="Number of top predictions to show" ) - parser.add_argument( - "--onnx", action="store_true", help="If we want export model to ONNX format" - ) parser.add_argument( "--onnx_path", type=str, default="./inference/model.onnx", help="Path where model in ONNX format will be exported", ) + parser.add_argument( + "--mode", + choices=["onnx", "ov", "cuda", "all"], + required=True, + help="Mode for exporting and running the model. Choices are: onnx, ov, cuda or all.", + ) args = parser.parse_args() + models = {} + # Setup device device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -137,21 +245,22 @@ def main() -> None: img_processor = ImageProcessor(img_path=args.image_path, device=device) img_batch = img_processor.process_image() - if args.onnx: + if args.mode == "onnx" or args.mode == "all": onnx_path = args.onnx_path - if not os.path.exists(onnx_path): - # Export the model to ONNX format using ONNXExporter - onnx_exporter = ONNXExporter(model_loader.model, device, onnx_path) - onnx_exporter.export_model() + # Export the model to ONNX format using ONNXExporter + onnx_exporter = ONNXExporter(model_loader.model, device, onnx_path) + onnx_exporter.export_model() # Create ONNX Runtime session ort_session = ort.InferenceSession( onnx_path, providers=["CPUExecutionProvider"] ) + models["onnx"] = ort_session + # Run benchmark - run_benchmark(None, None, None, ort_session, onnx=True) + # run_benchmark(None, None, None, ort_session, onnx=True) # Make prediction print(f"Making prediction with {ort.get_device()} for ONNX model") @@ -161,7 +270,29 @@ def main() -> None: topk=args.topk, categories=model_loader.categories, ) - else: + if args.mode == "ov" or args.mode == "all": + # Export the ONNX model to OpenVINO + ov_exporter = OVExporter(args.onnx_path) + ov_model = ov_exporter.export_model() + + models["ov"] = ov_model + + # Benchmark the OpenVINO model + ov_benchmark = OVBenchmark(ov_model, input_shape=(1, 3, 224, 224)) + ov_benchmark.run() + + # Run inference using the OpenVINO model + img_batch_ov = ( + img_processor.process_image().cpu().numpy() + ) # Assuming batch size of 1 + print(f"Making prediction with OpenVINO model") + make_prediction( + ov_benchmark.compiled_model, + img_batch_ov, + topk=args.topk, + categories=model_loader.categories, + ) + if args.mode == "cuda" or args.mode == "all": # Define configurations for which to run benchmarks and make predictions configs = [ ("cpu", torch.float32), @@ -170,26 +301,31 @@ def main() -> None: ] for device, precision in configs: - model = 
model_loader.model.to(device) + model_to_use = model_loader.model.to(device) + models["pytorch"] = model_loader.model if device == "cuda": print(f"Tracing {device} model") - model = torch.jit.trace( - model, [torch.randn((1, 3, 224, 224)).to(device)] + model_to_use = torch.jit.trace( + model_to_use, [torch.randn((1, 3, 224, 224)).to(device)] ) - if device == "cuda" and precision == torch.float16: + if precision == torch.float32 or precision == torch.float16: print("Compiling TensorRT model") - model = torch_tensorrt.compile( - model, + model_to_use = torch_tensorrt.compile( + model_to_use, inputs=[torch_tensorrt.Input((32, 3, 224, 224), dtype=precision)], enabled_precisions={precision}, truncate_long_and_double=True, ) + if precision == torch.float32: + models["trt_fp32"] = model_to_use + else: + models["trt_fp16"] = model_to_use - print(f"Making prediction with {device} model in {precision} precision") + """print(f"Making prediction with {device} model in {precision} precision") make_prediction( - model, + model_to_use, img_batch.to(device), args.topk, model_loader.categories, @@ -197,7 +333,13 @@ def main() -> None: ) print(f"Running Benchmark for {device} model in {precision} precision") - run_benchmark(model, device, precision) + run_benchmark(model_to_use, device, precision) """ + if args.mode == "all": + # Run all benchmarks + results = run_all_benchmarks(models, img_batch) + + # Plot results + plot_benchmark_results(results) if __name__ == "__main__": diff --git a/src/ov_exporter.py b/src/ov_exporter.py new file mode 100644 index 0000000..9e5f594 --- /dev/null +++ b/src/ov_exporter.py @@ -0,0 +1,32 @@ +import os +import openvino as ov + + +class OVExporter: + """ + OVExporter handles the conversion of an ONNX model to OpenVINO's internal representation. + """ + + def __init__(self, onnx_model_path: str): + """ + Initialize the OVExporter with the path to the ONNX model. + + :param onnx_model_path: str + Path to the ONNX model file. + """ + self.onnx_path = onnx_model_path + self.core = ov.Core() + + def export_model(self) -> ov.Model: + """ + Convert the ONNX model to OpenVINO's internal representation. + + :return: ov.ie.IENetwork + The converted OpenVINO model. + """ + if not os.path.isfile(self.onnx_path): + raise ValueError(f"ONNX model wasn't found in path: {self.onnx_path}") + + # Convert the ONNX model to OpenVINO's internal representation + ov_model = self.core.read_model(self.onnx_path) + return ov_model
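For reference, a minimal sketch of how the new `OVExporter` and `OVBenchmark` helpers introduced in this diff compose, mirroring what `main.py` does for `--mode ov`. It assumes the ONNX file already exists at the repository's default path and that the script is run from `src/` so the local imports resolve.

```python
# Standalone sketch using the classes added in this PR (src/ov_exporter.py, src/benchmark.py).
from ov_exporter import OVExporter
from benchmark import OVBenchmark

ov_model = OVExporter("./inference/model.onnx").export_model()   # ONNX -> ov.Model
benchmark = OVBenchmark(ov_model, input_shape=(1, 3, 224, 224))  # ResNet-50 input shape
avg_ms = benchmark.run()  # compiles the model, times the runs, returns average latency in ms
print(f"OpenVINO average batch time: {avg_ms:.2f} ms")
```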