ONNX & ONNX Runtime en 1 page
Format d'échange interopérable pour modèles de deep learning, avec runtime optimisé multi-cible.
C’est quoi ?
ONNX (Open Neural Network Exchange) = format de fichier .onnx qui décrit un graphe de calcul indépendamment du framework d’origine (PyTorch, TensorFlow, JAX…).
ONNX Runtime = moteur d’inférence qui lit ce format et l’exécute de façon optimisée sur CPU, GPU, NPU, etc.
Export depuis PyTorch
import torch
import torch.onnx
model.eval()
dummy_input = torch.randn(1, 80, 3000) # adapter à votre modèle
torch.onnx.export(
model,
dummy_input,
"model.onnx",
input_names=["input"],
output_names=["logits"],
dynamic_axes={"input": {0: "batch", 2: "time"}},
opset_version=17,
)
opset_version : utiliser 17+ pour les ops récents. Vérifier la compatibilité ONNX Runtime cible.
Inférence avec ONNX Runtime
import onnxruntime as ort
import numpy as np
sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
inputs = {sess.get_inputs()[0].name: audio_array.astype(np.float32)}
outputs = sess.run(None, inputs)
logits = outputs[0]
Quantification INT8
Dynamique (recommandé pour débuter)
from onnxruntime.quantization import quantize_dynamic, QuantType
quantize_dynamic(
model_input="model.onnx",
model_output="model_int8.onnx",
weight_type=QuantType.QInt8,
)
Pas besoin de données de calibration. Poids quantifiés, activations en FP32 à l’exécution.
Statique (meilleure perf, nécessite calibration)
from onnxruntime.quantization import quantize_static, CalibrationDataReader
class MyCalib(CalibrationDataReader):
def get_next(self):
# retourner dict {input_name: np.array}
...
quantize_static("model.onnx", "model_int8.onnx", MyCalib())
Providers disponibles
| Provider | Cible | Import |
|---|---|---|
CPUExecutionProvider | CPU (défaut) | toujours disponible |
CUDAExecutionProvider | GPU NVIDIA | onnxruntime-gpu |
CoreMLExecutionProvider | Apple Silicon | macOS uniquement |
QNNExecutionProvider | Snapdragon NPU | qualcomm-ai-engine |
TensorrtExecutionProvider | GPU NVIDIA (TRT) | onnxruntime-gpu |
Passer la liste par ordre de préférence : ORT choisit le premier disponible.
sess = ort.InferenceSession(
"model.onnx",
providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
)
Outils utiles
# Inspecter le graphe
pip install netron
netron model.onnx
# Vérifier le modèle
python -m onnxruntime.tools.check_onnx_model model.onnx
# Profiler l'inférence
sess_options = ort.SessionOptions()
sess_options.enable_profiling = True
sess = ort.InferenceSession("model.onnx", sess_options)
# → génère un fichier JSON chargeable dans chrome://tracing
Pièges courants
- Axes dynamiques oubliés → erreur à l’inférence si batch ≠ 1. Toujours déclarer
dynamic_axes. - Opset trop ancien → certains ops (LayerNorm fusionné, Attention) nécessitent opset ≥ 14.
- Quantification sur CPU arm64 → préférer
QUInt8pour les poids,QInt8pour les activations. - Pic mémoire sous-estimé → ONNX Runtime alloue des buffers intermédiaires. Mesurer avec
tracemallocou/proc/meminfo.