add-shape-inference
Add or update type and shape inference for an ONNX operator. Use when asked to implement TypeAndShapeInferenceFunction, propagate shapes, add shape inference tests, fix shape inference bugs, or handle broadcasting logic.
See also: docs/ShapeInference.md
| Component | File |
|---|---|
| Inference function | onnx/defs/<domain>/defs.cc (inline with schema) |
| Utility functions | onnx/defs/shape_inference.h |
| Tests | onnx/test/shape_inference_test.py |
Type inference (element type) is often handled automatically by type constraints. When "T" is shared between input and output, the framework infers output type automatically.
However, many existing ops still explicitly call propagateElemTypeFromInputToOutput as a best practice for robustness.
Explicit type inference logic is only needed when:
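- The output type is not tied to an input by a shared type constraint (e.g., Cast).
- Variadic (repeated) inputs/outputs do not all share one type, as in Loop/Scan. This case applies only to variadic arguments. The inference method must explicitly propagate types for each argument.

When the output has the same type and shape as the first input, pass the helper directly:

```cpp
.TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput)
```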
Binary op with broadcasting:

```cpp
static void InferShapeForBinaryOp(InferenceContext& ctx) {
  // Element type comes from input 0; shape only when both inputs have one.
  propagateElemTypeFromInputToOutput(ctx, 0, 0);
  if (hasNInputShapes(ctx, 2)) {
    bidirectionalBroadcastShapeInference(
        ctx.getInputType(0)->tensor_type().shape(),
        ctx.getInputType(1)->tensor_type().shape(),
        *ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape());
  }
}
```
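To make the broadcasting rule concrete, here is a minimal pure-Python sketch of right-aligned, NumPy-style broadcasting over shapes that may contain symbolic or unknown dims. It is an illustration under stated assumptions, not the ONNX implementation: `broadcast_dim` and `broadcast_shapes` are hypothetical names, and dims are modeled as `int` (known), `str` (symbolic) or `None` (unknown).

```python
from itertools import zip_longest


def broadcast_dim(a, b):
    """Broadcast one pair of dims: int = known, str = symbolic, None = unknown."""
    known = [d for d in (a, b) if isinstance(d, int) and d != 1]
    if len(known) == 2 and known[0] != known[1]:
        raise ValueError(f"incompatible dimensions: {a} vs {b}")
    if known:
        return known[0]  # a concrete size other than 1 always wins
    if a == 1:
        return b  # a size of 1 stretches to whatever the other side is
    if b == 1:
        return a
    # Two non-concrete dims: keep the symbol only if both sides agree on it.
    return a if a == b and a is not None else None


def broadcast_shapes(lhs, rhs):
    """Align shapes on the right; missing leading dims behave like 1."""
    pairs = zip_longest(reversed(lhs), reversed(rhs), fillvalue=1)
    return tuple(reversed([broadcast_dim(a, b) for a, b in pairs]))


print(broadcast_shapes((2, 3, 4), (3, 1)))      # (2, 3, 4)
print(broadcast_shapes(("N", 1), (1, "M")))     # ('N', 'M')
```

Returning `None` when two different symbols meet is a deliberate choice in this sketch: the result is left unknown rather than guessed.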
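Attribute-dependent shape (Transpose):

```cpp
static void InferShapeForTranspose(InferenceContext& ctx) {
  propagateElemTypeFromInputToOutput(ctx, 0, 0);
  if (!hasNInputShapes(ctx, 1)) return;
  const auto& input_shape = ctx.getInputType(0)->tensor_type().shape();
  const int rank = input_shape.dim_size();
  std::vector<int64_t> perm;
  if (!getRepeatedAttribute(ctx, "perm", perm)) {
    // No "perm" attribute: Transpose defaults to reversing the dimensions.
    for (int i = rank - 1; i >= 0; --i) perm.push_back(i);
  } else if (static_cast<int>(perm.size()) != rank) {
    fail_shape_inference("perm must have one entry per input dimension");
  }
  auto* output_shape = getOutputShape(ctx, 0);
  for (int i = 0; i < rank; ++i) {
    // Reject out-of-range axes before indexing into the input shape.
    if (perm[i] < 0 || perm[i] >= rank) fail_shape_inference("perm entry out of range");
    *output_shape->add_dim() = input_shape.dim(static_cast<int>(perm[i]));
  }
}
```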
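The permutation rule for Transpose can be checked in isolation with a small pure-Python sketch. It assumes the documented Transpose behavior that a missing `perm` reverses the dimensions; `transpose_shape` is a hypothetical helper, and dims may be `int`, `str` (symbolic) or `None` (unknown), all of which are copied through unchanged.

```python
def transpose_shape(shape, perm=None):
    """Return the output shape of a transpose; output dim i is input dim perm[i]."""
    rank = len(shape)
    if perm is None:
        # Default: reverse the dimension order.
        perm = list(range(rank - 1, -1, -1))
    if sorted(perm) != list(range(rank)):
        raise ValueError(f"perm {perm} is not a permutation of 0..{rank - 1}")
    return tuple(shape[axis] for axis in perm)


print(transpose_shape((2, 3, 4)))               # (4, 3, 2)
print(transpose_shape((2, "N", 4), [1, 0, 2]))  # ('N', 2, 4)
```

Because whole dims are copied rather than their numeric values, a symbolic dim such as `"N"` survives the permutation, which is the same reason the C++ version assigns `input_shape.dim(...)` instead of reading `dim_value()`.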
| Function | Purpose |
|---|---|
| propagateElemTypeFromInputToOutput(ctx, in, out) | Copy element type |
| propagateShapeFromInputToOutput(ctx, in, out) | Copy entire shape |
| propagateShapeAndTypeFromFirstInput(ctx) | Both type and shape from input 0 |
| hasNInputShapes(ctx, n) | Check first n inputs have shapes |
| getOutputShape(ctx, out) | Get mutable output shape |
| bidirectionalBroadcastShapeInference(L, R, out) | Numpy broadcasting |
| getRepeatedAttribute(ctx, "name", vec) | Get repeated attr values |
| getAttribute(ctx, "name", default) | Get single attr value |
| mergeInDimensionInfo(src, dst, dim_idx) | Merge dimension info |
| fail_shape_inference("msg") | Throw inference error |
Dimension arithmetic helpers:

```cpp
Dim operator*(const Dim& a, const Dim& b);
Dim operator*(const Dim& a, int64_t val);
Dim operator/(const Dim& a, int64_t divisor);
Dim multiplyDims(const TensorShapeProto& shape, int from, int upto);
```
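Dimension arithmetic has to stay well defined when a dim has no concrete value. The pure-Python sketch below shows one plausible set of semantics, stated as an assumption rather than a description of the ONNX helpers: the product is concrete only when both operands are, multiplying by a known 1 returns the other operand unchanged, and anything else is unknown. `mul_dims` and `multiply_dims` are hypothetical names, and `stop` is treated as exclusive.

```python
def mul_dims(a, b):
    """Multiply two dims: int = known, str = symbolic, None = unknown."""
    if isinstance(a, int) and isinstance(b, int):
        return a * b
    if a == 1:
        return b  # 1 * x keeps x, including its symbolic name
    if b == 1:
        return a
    return None  # no concrete product can be stated


def multiply_dims(shape, start, stop):
    """Product of shape[start:stop], e.g. the flattened size of several axes."""
    result = 1
    for dim in shape[start:stop]:
        result = mul_dims(result, dim)
    return result


print(multiply_dims((2, 3, 4), 1, 3))    # 12
print(multiply_dims((2, "N", 4), 0, 3))  # None
print(multiply_dims(("N", 1), 0, 2))     # N
```

A Flatten-style op is the typical consumer: each of its two output dims is a product over a slice of the input shape, and either one may end up unknown.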
The _make_graph / _assert_inferred helpers are right for parameterized op-version sweeps:
```python
@parameterized.expand(all_versions_for("OpName"))
def test_opname(self, _, version) -> None:
    graph = self._make_graph(
        [("X", TensorProto.FLOAT, (2, 3, 4))],
        [make_node("OpName", ["X"], ["Y"], attr_name=attr_value)],
        [],
    )
    self._assert_inferred(
        graph,
        [make_tensor_value_info("Y", TensorProto.FLOAT, expected_shape)],
        opset_imports=[helper.make_opsetid(ONNX_DOMAIN, version)],
    )
```
For one-off fixtures (anything with attributes, body subgraphs, or non-trivial type info), prefer the onnxtxt skill's parser-based fixtures, which also cover the C++ unk__* materialization gotcha for free dims.
Cover: known shapes, partial shapes (None), rank inference, error cases, broadcasting, attribute-dependent shapes.
Define inference functions as separate named functions rather than inline lambdas. The macro expansion makes breakpoints on inline lambdas unreliable.
Short one-liners (e.g., propagateShapeAndTypeFromFirstInput) are fine as direct references.
- Check hasNInputShapes(ctx, n) before accessing shapes
- Check has_dim_value() before using dim_value()
- Preserve symbolic dimensions (dim_param) when possible

```shell
pytest onnx/test/shape_inference_test.py -k "test_opname" -x
python onnx/defs/gen_doc.py
lintrunner -a --output oneline
```
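As an illustration of the three checks above (guard on missing shapes, test for a concrete value before doing arithmetic, keep symbolic names), here is a pure-Python sketch of Concat-style inference. It is a simplified model, not ONNX code: `concat_shape` and `merge_dim` are hypothetical names, a missing input shape is `None`, and dims are `int`, `str` (symbolic) or `None` (unknown).

```python
def merge_dim(cur, new):
    """Combine two views of the same non-concat dim, preferring concrete info."""
    if cur is None:
        return new
    if new is None:
        return cur
    if isinstance(cur, int) and isinstance(new, int) and cur != new:
        raise ValueError(f"dimension mismatch: {cur} vs {new}")
    return new if isinstance(new, int) else cur


def concat_shape(shapes, axis):
    # Guard: with any input shape missing, infer nothing (like hasNInputShapes).
    if any(shape is None for shape in shapes):
        return None
    out = list(shapes[0])
    for shape in shapes[1:]:
        if len(shape) != len(out):
            raise ValueError("all inputs must have the same rank")
        for i, dim in enumerate(shape):
            if i == axis:
                # Only add when both sizes are concrete (like has_dim_value()).
                both_known = isinstance(out[i], int) and isinstance(dim, int)
                out[i] = out[i] + dim if both_known else None
            else:
                out[i] = merge_dim(out[i], dim)  # symbolic names survive here
    return tuple(out)


print(concat_shape([(2, "N"), (3, "N")], axis=0))  # (5, 'N')
print(concat_shape([("B", 3), (4, 3)], axis=0))    # (None, 3)
print(concat_shape([None, (1, 2)], axis=0))        # None
```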
Add a function body definition to an ONNX operator, defining how it decomposes into simpler ops. Use when asked to make an op decomposable, add a FunctionBody, implement SetContextDependentFunctionBodyBuilder, or express an op in terms of other ONNX operators.
Add a new ONNX operator or update an existing operator to a new opset version. Use when asked to define an operator schema, register an op, add inputs/outputs/attributes to an op, move an op to old.cc, or bump an op's opset version.
Read or write ONNX text format ("onnxtxt"). Use when authoring `.FunctionBody(R"ONNX(...)")` blocks, writing tests with `onnx.parser.parse_model` / `parse_graph`, using the C++ `OnnxParser`, debugging parser errors, or interpreting `Constant <value = ...>` and body-subgraph syntax.