diff --git a/backends/cadence/aot/tests/test_quantizer_ops.py b/backends/cadence/aot/tests/test_quantizer_ops.py
index fa53c1818f5..66d61c5c5e3 100644
--- a/backends/cadence/aot/tests/test_quantizer_ops.py
+++ b/backends/cadence/aot/tests/test_quantizer_ops.py
@@ -64,6 +64,7 @@
 # Test case definitions for quantizer annotation tests.
 # Format: (name, graph_builder_fn, quantizer_instance, target_op, expected_output_qspec, expected_input_qspecs)
 # Adding a new quantizer test only requires adding a tuple to this list.
+# Note: Use None in expected_input_qspecs to skip comparison for that input (e.g., for DerivedQuantizationSpec).
 QUANTIZER_ANNOTATION_TEST_CASES: list[
     tuple[
         str,
@@ -71,7 +72,7 @@
         CadenceQuantizer,
         OpOverload,
         QuantizationSpec,
-        list[QuantizationSpec],
+        list[QuantizationSpec | None],
     ]
 ] = [
     (
@@ -192,6 +193,16 @@
         # For relu: only input_activation
         [qconfig_A8W8.input_activation],
     ),
+    (
+        "default_addmm_A8W8",
+        lambda self: self._build_addmm_graph(),
+        CadenceDefaultQuantizer(),
+        torch.ops.aten.addmm.default,
+        qconfig_A8W8.output_activation,
+        # For addmm: [bias (DerivedQuantizationSpec), mat1, mat2]
+        # Use None to skip comparison for bias since it's a DerivedQuantizationSpec
+        [None, qconfig_A8W8.input_activation, qconfig_A8W8.weight],
+    ),
 ]
 
 # Derive the set of tested quantizer classes from the test cases.
@@ -408,6 +419,31 @@ def _build_relu_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
         self.assertEqual(len(relu_nodes), 1, "Should find exactly one relu node")
         return gm, relu_nodes[0]
 
+    def _build_addmm_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
+        """Build a simple graph with an addmm operation."""
+        builder = GraphBuilder()
+        # addmm: bias + (mat1 @ mat2)
+        # args: (bias, mat1, mat2)
+        bias = builder.placeholder("bias", torch.randn(5))
+        mat1 = builder.placeholder("mat1", torch.randn(1, 10))
+        mat2 = builder.placeholder("mat2", torch.randn(10, 5))
+        addmm = builder.call_operator(
+            op=torch.ops.aten.addmm.default,
+            args=(bias, mat1, mat2),
+            meta=NodeMetadata(
+                {"source_fn_stack": [("addmm", torch.ops.aten.addmm.default)]}
+            ),
+        )
+        builder.output([addmm])
+        gm = builder.get_graph_module()
+
+        addmm_nodes = gm.graph.find_nodes(
+            op="call_function",
+            target=torch.ops.aten.addmm.default,
+        )
+        self.assertEqual(len(addmm_nodes), 1, "Should find exactly one addmm node")
+        return gm, addmm_nodes[0]
+
     @parameterized.expand(QUANTIZER_ANNOTATION_TEST_CASES)
     def test_quantizer_annotation(
         self,
@@ -416,7 +452,7 @@ def test_quantizer_annotation(
         quantizer: CadenceQuantizer,
         target: OpOverload,
         expected_output_qspec: QuantizationSpec,
-        expected_input_qspecs: list[QuantizationSpec],
+        expected_input_qspecs: list[QuantizationSpec | None],
     ) -> None:
         """Parameterized test for quantizer annotations."""
         gm, op_node = graph_builder_fn(self)
@@ -431,21 +467,24 @@ def test_quantizer_annotation(
 
         # Verify input annotations
         self.assertEqual(len(annotation.input_qspec_map), len(expected_input_qspecs))
-        for i, (input_node, input_qspec) in enumerate(
-            annotation.input_qspec_map.items()
-        ):
-            expected_arg = op_node.args[i]
-            assert isinstance(expected_arg, torch.fx.Node)
-            self.assertEqual(
-                input_node,
-                expected_arg,
-                f"Input node mismatch at index {i}",
-            )
-            self.assertEqual(
-                input_qspec,
-                expected_input_qspecs[i],
-                f"Input qspec mismatch at index {i}",
+        for input_node, input_qspec in annotation.input_qspec_map.items():
+            # Find the index of this input node in the op's args
+            arg_index = None
+            for i, arg in enumerate(op_node.args):
+                if arg is input_node:
+                    arg_index = i
+                    break
+            self.assertIsNotNone(
+                arg_index,
+                f"Input node {input_node} not found in op_node.args",
             )
+            # Skip comparison if expected qspec is None (e.g., for DerivedQuantizationSpec)
+            if expected_input_qspecs[arg_index] is not None:
+                self.assertEqual(
+                    input_qspec,
+                    expected_input_qspecs[arg_index],
+                    f"Input qspec mismatch at arg index {arg_index}",
+                )
 
     def test_all_quantizers_have_annotation_tests(self) -> None:
         """Ensure every CadenceQuantizer subclass is either tested or explicitly excluded."""
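For reference, a minimal standalone sketch of the matching convention the rewritten loop relies on: `input_qspec_map` may iterate in any order, so each annotated input is located in the op's args by identity rather than by position, and a `None` entry in the expected list skips comparison for that slot (as for the addmm bias, whose qspec is a `DerivedQuantizationSpec`). All objects and qspec values below are illustrative stand-ins, not the real torch.fx node or quantizer spec types.

```python
# Stand-ins for fx nodes and qspecs; identity (is) is what matters, as in the test.
bias, mat1, mat2 = object(), object(), object()
args = (bias, mat1, mat2)                        # positional args of the op node
expected = [None, "act_qspec", "weight_qspec"]   # None skips the bias position
qspec_map = {mat2: "weight_qspec", bias: "derived_qspec", mat1: "act_qspec"}

for node, qspec in qspec_map.items():            # arbitrary iteration order
    idx = next(i for i, a in enumerate(args) if a is node)
    if expected[idx] is not None:                # skip derived-qspec slots
        assert qspec == expected[idx], f"Input qspec mismatch at arg index {idx}"
```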