· Intermédiaire

ONNX & ONNX Runtime en 1 page

Format d'échange interopérable pour modèles de deep learning, avec runtime optimisé multi-cible.

C’est quoi ?

ONNX (Open Neural Network Exchange) = format de fichier .onnx qui décrit un graphe de calcul indépendamment du framework d’origine (PyTorch, TensorFlow, JAX…).

ONNX Runtime = moteur d’inférence qui lit ce format et l’exécute de façon optimisée sur CPU, GPU, NPU, etc.


Export depuis PyTorch

import torch
import torch.onnx

model.eval()
dummy_input = torch.randn(1, 80, 3000)  # adapter à votre modèle

torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=["input"],
    output_names=["logits"],
    dynamic_axes={"input": {0: "batch", 2: "time"}},
    opset_version=17,
)

opset_version : utiliser 17+ pour les ops récents. Vérifier la compatibilité ONNX Runtime cible.


Inférence avec ONNX Runtime

import onnxruntime as ort
import numpy as np

sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])

inputs = {sess.get_inputs()[0].name: audio_array.astype(np.float32)}
outputs = sess.run(None, inputs)
logits = outputs[0]

Quantification INT8

Dynamique (recommandé pour débuter)

from onnxruntime.quantization import quantize_dynamic, QuantType

quantize_dynamic(
    model_input="model.onnx",
    model_output="model_int8.onnx",
    weight_type=QuantType.QInt8,
)

Pas besoin de données de calibration. Poids quantifiés, activations en FP32 à l’exécution.

Statique (meilleure perf, nécessite calibration)

from onnxruntime.quantization import quantize_static, CalibrationDataReader

class MyCalib(CalibrationDataReader):
    def get_next(self):
        # retourner dict {input_name: np.array}
        ...

quantize_static("model.onnx", "model_int8.onnx", MyCalib())

Providers disponibles

ProviderCibleImport
CPUExecutionProviderCPU (défaut)toujours disponible
CUDAExecutionProviderGPU NVIDIAonnxruntime-gpu
CoreMLExecutionProviderApple SiliconmacOS uniquement
QNNExecutionProviderSnapdragon NPUqualcomm-ai-engine
TensorrtExecutionProviderGPU NVIDIA (TRT)onnxruntime-gpu

Passer la liste par ordre de préférence : ORT choisit le premier disponible.

sess = ort.InferenceSession(
    "model.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
)

Outils utiles

# Inspecter le graphe
pip install netron
netron model.onnx

# Vérifier le modèle
python -m onnxruntime.tools.check_onnx_model model.onnx

# Profiler l'inférence
sess_options = ort.SessionOptions()
sess_options.enable_profiling = True
sess = ort.InferenceSession("model.onnx", sess_options)
# → génère un fichier JSON chargeable dans chrome://tracing

Pièges courants

  • Axes dynamiques oubliés → erreur à l’inférence si batch ≠ 1. Toujours déclarer dynamic_axes.
  • Opset trop ancien → certains ops (LayerNorm fusionné, Attention) nécessitent opset ≥ 14.
  • Quantification sur CPU arm64 → préférer QUInt8 pour les poids, QInt8 pour les activations.
  • Pic mémoire sous-estimé → ONNX Runtime alloue des buffers intermédiaires. Mesurer avec tracemalloc ou /proc/meminfo.