diff --git a/framework/codemodder-base/src/main/java/io/codemodder/remediation/xss/ResponseEntityFixStrategy.java b/framework/codemodder-base/src/main/java/io/codemodder/remediation/xss/ResponseEntityFixStrategy.java index 978b2d5ad..0e9b41364 100644 --- a/framework/codemodder-base/src/main/java/io/codemodder/remediation/xss/ResponseEntityFixStrategy.java +++ b/framework/codemodder-base/src/main/java/io/codemodder/remediation/xss/ResponseEntityFixStrategy.java @@ -2,7 +2,9 @@ import com.github.javaparser.ast.CompilationUnit; import com.github.javaparser.ast.Node; +import com.github.javaparser.ast.expr.Expression; import com.github.javaparser.ast.expr.ObjectCreationExpr; +import com.github.javaparser.resolution.types.ResolvedType; import io.codemodder.remediation.RemediationStrategy; import io.codemodder.remediation.SuccessOrReason; import java.util.Optional; @@ -32,6 +34,17 @@ static boolean match(final Node node) { "ResponseEntity".equals(c.getTypeAsString()) || c.getTypeAsString().startsWith("ResponseEntity<")) .filter(c -> !c.getArguments().isEmpty()) + .filter( + c -> { + Expression firstArg = c.getArguments().getFirst().get(); + try { + ResolvedType resolvedType = firstArg.calculateResolvedType(); + return "java.lang.String".equals(resolvedType.describe()); + } catch (Exception e) { + // this is expected often, and indicates its a non-String type anyway + return false; + } + }) .isPresent(); } } diff --git a/framework/codemodder-base/src/test/java/io/codemodder/remediation/xss/ResponseEntityFixStrategyTest.java b/framework/codemodder-base/src/test/java/io/codemodder/remediation/xss/ResponseEntityFixStrategyTest.java index 07e8450cd..fdd8878a7 100644 --- a/framework/codemodder-base/src/test/java/io/codemodder/remediation/xss/ResponseEntityFixStrategyTest.java +++ b/framework/codemodder-base/src/test/java/io/codemodder/remediation/xss/ResponseEntityFixStrategyTest.java @@ -2,13 +2,15 @@ import static org.assertj.core.api.Assertions.assertThat; -import com.github.javaparser.StaticJavaParser; +import com.github.javaparser.JavaParser; import com.github.javaparser.ast.CompilationUnit; import com.github.javaparser.printer.lexicalpreservation.LexicalPreservingPrinter; import io.codemodder.CodemodFileScanningResult; import io.codemodder.codetf.DetectorRule; +import io.codemodder.javaparser.JavaParserFactory; import io.codemodder.remediation.FixCandidateSearcher; import io.codemodder.remediation.SearcherStrategyRemediator; +import java.io.IOException; import java.util.List; import java.util.Optional; import java.util.stream.Stream; @@ -21,10 +23,12 @@ final class ResponseEntityFixStrategyTest { private ResponseEntityFixStrategy fixer; private DetectorRule rule; + private JavaParser parser; @BeforeEach - void setup() { + void setup() throws IOException { this.fixer = new ResponseEntityFixStrategy(); + this.parser = JavaParserFactory.newFactory().create(List.of()); this.rule = new DetectorRule("xss", "XSS", null); } @@ -67,7 +71,7 @@ ResponseEntity should_be_fixed(String s) { @ParameterizedTest @MethodSource("fixableSamples") void it_fixes_obvious_response_write_methods(final String beforeCode, final String afterCode) { - CompilationUnit cu = StaticJavaParser.parse(beforeCode); + CompilationUnit cu = parser.parse(beforeCode).getResult().orElseThrow(); LexicalPreservingPrinter.setup(cu); var result = scanAndFix(cu, 3); @@ -100,7 +104,7 @@ private CodemodFileScanningResult scanAndFix(final CompilationUnit cu, final int @ParameterizedTest @MethodSource("unfixableSamples") void it_does_not_fix_unfixable_samples(final String beforeCode, final int line) { - CompilationUnit cu = StaticJavaParser.parse(beforeCode); + CompilationUnit cu = parser.parse(beforeCode).getResult().orElseThrow(); LexicalPreservingPrinter.setup(cu); var result = scanAndFix(cu, line); assertThat(result.changes()).isEmpty(); @@ -110,13 +114,24 @@ private static Stream unfixableSamples() { return Stream.of( // this is not a ResponseEntity, shouldn't touch it Arguments.of( + // this is not a ResponseEntity, shouldn't touch it """ class Samples { - String should_be_fixed(String s) { + String should_not_be_fixed(String s) { return new NotResponseEntity(s, HttpStatus.OK); } } """, + 3), + Arguments.of( + // this is not a String, shouldn't touch it + """ + class Samples { + String should_not_be_fixed(BodyType s) { + return new ResponseEntity(s, HttpStatus.OK); + } + } + """, 3)); } }