System UnrealPlugin InferenceInterfaces Onnx - kcccr123/ue-reinforcement-learning GitHub Wiki
`UInferenceInterfaceOnnx` is a concrete implementation of `UInferenceInterface` using ONNX Runtime. It loads `.onnx` files and performs inference given a vector of float observations, returning a serialized action string.
This class serves as the default backend for model inference in the framework.
ONNX (Open Neural Network Exchange) is a widely used standard format for trained neural networks. This implementation wraps ONNX Runtime and plugs directly into the RL bridge infrastructure.
Key features:
- Loads `.onnx` files from disk
- Runs inference on CPU
- Designed for 1D float observations with shape `[1, N]`
- Outputs flat float arrays as comma-separated strings
Instantiate `UInferenceInterfaceOnnx` and pass it to a bridge or component via `SetInferenceInterface()`.
After calling `LoadModel()`, you may begin using `RunInference()` with live observations.
`LoadModel()` initializes ONNX Runtime if it is not already active, then creates a session from the given model path.
- Model must have input name `"obs"` and output name `"actions"`
- Path can be absolute or project-relative
- Returns `true` if the session loads successfully
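The path rule above (absolute or project-relative) can be illustrated with a small standalone sketch. It uses `std::filesystem` in place of Unreal's path utilities, and both `ResolveModelPath` and `HasOnnxExtension` are hypothetical helpers introduced only for illustration; the plugin's actual resolution logic may differ.

```cpp
#include <filesystem>
#include <string>

// Hypothetical helper (not part of the plugin): resolves a model path.
// Absolute paths are kept as given; relative paths are joined to ProjectDir.
std::filesystem::path ResolveModelPath(const std::filesystem::path& ProjectDir,
                                       const std::filesystem::path& ModelPath)
{
    if (ModelPath.is_absolute())
    {
        return ModelPath.lexically_normal();
    }
    return (ProjectDir / ModelPath).lexically_normal();
}

// Hypothetical check: accept only files that carry the .onnx extension.
bool HasOnnxExtension(const std::filesystem::path& ModelPath)
{
    return ModelPath.extension() == ".onnx";
}
```

Inside the engine, the project directory would presumably come from something like `FPaths::ProjectDir()` rather than a caller-supplied argument.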
`RunInference()` performs inference using the loaded session:
- Wraps input array into ONNX tensor format
- Runs the model on CPU
- Reads result from `"actions"` output tensor
- Converts output array into a comma-separated `FString`
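The last step, turning the output array into a comma-separated string, can be sketched in plain C++ along with the reverse operation a consumer of the action string might need. This is a minimal illustration that assumes `std::vector<float>` and `std::string` as stand-ins for `TArray<float>` and `FString`; `JoinActions` and `ParseActions` are hypothetical names, and the plugin's exact float formatting may differ.

```cpp
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical helper: serializes a flat float array as "a,b,c".
// Uses default stream formatting, so 0.1f is written as "0.1".
std::string JoinActions(const std::vector<float>& Actions)
{
    std::ostringstream Out;
    for (std::size_t i = 0; i < Actions.size(); ++i)
    {
        if (i > 0)
        {
            Out << ',';
        }
        Out << Actions[i];
    }
    return Out.str();
}

// Hypothetical helper: parses "a,b,c" back into floats.
// std::stof throws std::invalid_argument if a token is not a number.
std::vector<float> ParseActions(const std::string& ActionString)
{
    std::vector<float> Actions;
    std::istringstream In(ActionString);
    std::string Token;
    while (std::getline(In, Token, ','))
    {
        Actions.push_back(std::stof(Token));
    }
    return Actions;
}
```

In Unreal code the parsing side would more likely use `FString::ParseIntoArray` and `FCString::Atof`; the standard-library version is shown only so the sketch compiles on its own.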
EXPECTED FORMAT: Input array must match the model's observation space. For example:
TArray<float> Observation = { 1.0f, 0.0f, -0.5f, 3.2f };
This array will be treated as a `[1, 4]` tensor and passed to ONNX Runtime.
The output contains one value per dimension of your policy's action space. For example, if your action space has a size of 2:
"0.1,0.9"
Returns `true` if the ONNX session is valid and ready.
- Uses `Ort::Env`, `Ort::Session`, and `Ort::SessionOptions` to manage runtime state
- Globals `GEnv` and `GSessionOptions` are shared across instances
- Inference expects a 2D shape `[1, N]` where `N = Observation.Num()`
- Output must be cast to float array and reshaped manually
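The shape handling in the last two notes can be sketched without ONNX Runtime itself. The helpers below are hypothetical and use `std::vector<float>` in place of `TArray<float>`: one builds the `[1, N]` input shape, one computes how many floats an output of a given shape holds, and one copies that many values out of a raw float pointer. The sketch assumes every dimension is non-negative, that is, no dynamic `-1` dimensions remain in the reported shape.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: builds the 2D input shape [1, N] for one observation.
std::array<std::int64_t, 2> MakeInputShape(const std::vector<float>& Observation)
{
    return {1, static_cast<std::int64_t>(Observation.size())};
}

// Hypothetical helper: number of elements in a tensor of the given shape.
// Assumes all dimensions are non-negative.
std::size_t ElementCount(const std::vector<std::int64_t>& Shape)
{
    std::size_t Count = 1;
    for (std::int64_t Dim : Shape)
    {
        Count *= static_cast<std::size_t>(Dim);
    }
    return Count;
}

// Hypothetical helper: copies a raw output buffer into a flat float array.
// Data must point to at least ElementCount(Shape) floats.
std::vector<float> CopyOutput(const float* Data, const std::vector<std::int64_t>& Shape)
{
    return std::vector<float>(Data, Data + ElementCount(Shape));
}
```

With ONNX Runtime, the shape passed to `ElementCount` would typically come from the output value's `GetTensorTypeAndShapeInfo().GetShape()`, and the data pointer from its float tensor data accessor.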