System UnrealPlugin InferenceInterfaces Onnx - kcccr123/ue-reinforcement-learning GitHub Wiki

ONNX Inference Implementation

UInferenceInterfaceOnnx is a concrete implementation of UInferenceInterface backed by ONNX Runtime. It loads .onnx files, performs inference on a vector of float observations, and returns a serialized action string.

This class serves as the default backend for model inference in the framework.


Overview

ONNX (Open Neural Network Exchange) is a widely used standard format for trained neural networks. This implementation wraps ONNX Runtime and plugs directly into the RL bridge infrastructure.

Key features:

  • Loads .onnx files from disk
  • Runs inference on CPU
  • Designed for flat float observation vectors, fed to the model as a [1, N] tensor
  • Outputs flat float arrays as comma-separated strings

Usage

Instantiate UInferenceInterfaceOnnx and pass it to a bridge or component via SetInferenceInterface().

After calling LoadModel(), you may begin using RunInference() with live observations.
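The flow above can be sketched as follows. This is an illustrative snippet, not code from the plugin: it assumes an Unreal context where `this` is a valid UObject outer and `Bridge` is some object exposing SetInferenceInterface() as described above; the model path is a placeholder.

```cpp
// Create the ONNX inference backend (hypothetical surrounding context).
UInferenceInterfaceOnnx* Inference = NewObject<UInferenceInterfaceOnnx>(this);

// Load the model first; only use the interface if loading succeeded.
if (Inference->LoadModel(TEXT("Models/Policy.onnx")))
{
    // Hand the backend to the bridge.
    Bridge->SetInferenceInterface(Inference);

    // Run inference on a live observation vector.
    TArray<float> Observation = { 1.0f, 0.0f, -0.5f, 3.2f };
    FString Actions = Inference->RunInference(Observation);
}
```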


Key Methods

bool LoadModel(const FString& ModelPath)

Initializes ONNX Runtime if not already active, then creates a session from the given model path.

  • Model must have input name "obs" and output name "actions"
  • Path can be absolute or project-relative
  • Returns true if the session loads successfully
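The "absolute or project-relative" rule can be illustrated with a small standalone helper. This is a sketch of the idea only (the function name and use of std::filesystem are assumptions, not the plugin's actual implementation, which would use Unreal's path utilities):

```cpp
#include <filesystem>
#include <string>

// Hypothetical helper: if ModelPath is already absolute, use it as-is;
// otherwise resolve it against the project root directory.
std::string ResolveModelPath(const std::string& ProjectRoot,
                             const std::string& ModelPath)
{
    namespace fs = std::filesystem;
    fs::path Path(ModelPath);
    if (Path.is_absolute())
        return Path.lexically_normal().string();
    return (fs::path(ProjectRoot) / Path).lexically_normal().string();
}
```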

FString RunInference(const TArray<float>& Observation)

Performs inference using the loaded session:

  • Wraps input array into ONNX tensor format
  • Runs the model on CPU
  • Reads result from "actions" output tensor
  • Converts output array into a comma-separated FString

EXPECTED FORMAT: Input array must match the model's observation space. For example:

TArray<float> Observation = { 1.0f, 0.0f, -0.5f, 3.2f };

This array will be treated as a [1, 4] tensor and passed to ONNX Runtime.

The output string contains one value per dimension of your policy's action space. For example, if your action space has a size of 2:

"0.1,0.9"

bool IsModelLoaded() const

Returns true if the ONNX session is valid and ready.


Implementation Notes

  • Uses Ort::Env, Ort::Session, and Ort::SessionOptions to manage runtime state
  • Globals GEnv and GSessionOptions are shared across instances
  • Inference expects a 2D shape [1, N] where N = Observation.Num()
  • Output must be cast to float array and reshaped manually
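On the consuming side, the comma-separated action string has to be split back into floats before it can drive an agent. A minimal standalone sketch (this helper is an assumption for illustration, not part of the plugin):

```cpp
#include <sstream>
#include <string>
#include <vector>

// Parse a comma-separated action string such as "0.1,0.9"
// back into a flat float vector.
std::vector<float> ParseActionString(const std::string& ActionString)
{
    std::vector<float> Values;
    std::stringstream Stream(ActionString);
    std::string Token;
    while (std::getline(Stream, Token, ','))
        Values.push_back(std::stof(Token));
    return Values;
}
```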
