Skip to content

Commit 7442007

Browse files
authored
Merge pull request #119 from awslabs/spring
Fixing forward and include filter issues with Spring - this closes all issues discovered while investigating #105.
2 parents 7d4718b + e345b1d commit 7442007

File tree

7 files changed

+64
-11
lines changed

7 files changed

+64
-11
lines changed

aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpServletRequest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ public abstract class AwsHttpServletRequest implements HttpServletRequest {
101101

102102
@Override
103103
public String getRequestedSessionId() {
104-
throw new UnsupportedOperationException();
104+
return null;
105105
}
106106

107107

aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsLambdaServletContainerHandler.java

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,16 @@
2222
import org.slf4j.Logger;
2323
import org.slf4j.LoggerFactory;
2424

25+
import javax.servlet.DispatcherType;
2526
import javax.servlet.FilterChain;
2627
import javax.servlet.Servlet;
2728
import javax.servlet.ServletContext;
2829
import javax.servlet.ServletException;
30+
import javax.servlet.ServletResponse;
2931
import javax.servlet.http.HttpServletRequest;
3032
import javax.servlet.http.HttpServletResponse;
33+
import javax.servlet.http.HttpServletResponseWrapper;
34+
3135
import java.io.IOException;
3236

3337
/**
@@ -85,7 +89,7 @@ protected AwsLambdaServletContainerHandler(RequestReader<RequestType, ContainerR
8589
public void forward(ContainerRequestType servletRequest, ContainerResponseType servletResponse)
8690
throws ServletException, IOException {
8791
try {
88-
handleRequest(servletRequest, servletResponse, lambdaContext);
92+
handleRequest(servletRequest, (ContainerResponseType)getServletResponse(servletResponse), lambdaContext);
8993
} catch (Exception e) {
9094
log.error("Could not forward request", e);
9195
throw new ServletException(e);
@@ -103,13 +107,28 @@ public void forward(ContainerRequestType servletRequest, ContainerResponseType s
103107
public void include(ContainerRequestType servletRequest, ContainerResponseType servletResponse)
104108
throws ServletException, IOException {
105109
try {
106-
handleRequest(servletRequest, servletResponse, lambdaContext);
110+
handleRequest(servletRequest, (ContainerResponseType)getServletResponse(servletResponse), lambdaContext);
107111
} catch (Exception e) {
108112
log.error("Could not include request", e);
109113
throw new ServletException(e);
110114
}
111115
}
112116

117+
private HttpServletResponse getServletResponse(ContainerResponseType resp) {
118+
if (HttpServletResponseWrapper.class.isAssignableFrom(resp.getClass())) {
119+
ServletResponse servletResp = ((HttpServletResponseWrapper)resp).getResponse();
120+
assert servletResp instanceof HttpServletResponse : servletResp.getClass();
121+
return (HttpServletResponse)servletResp;
122+
}
123+
124+
if (HttpServletResponse.class.isAssignableFrom(resp.getClass())) {
125+
return resp;
126+
}
127+
128+
129+
throw new UnsupportedOperationException("Response type of " + resp.getClass().getName() + " is not supported");
130+
}
131+
113132

114133
/**
115134
* You can use the <code>onStartup</code> to intercept the ServletContext as the Spring application is
@@ -189,6 +208,11 @@ protected FilterChain getFilterChain(ContainerRequestType req, Servlet servlet)
189208
protected void doFilter(ContainerRequestType request, ContainerResponseType response, Servlet servlet) throws IOException, ServletException {
190209
FilterChain chain = getFilterChain(request, servlet);
191210
chain.doFilter(request, response);
211+
212+
// if for some reason the response wasn't flushed yet, we force it here.
213+
if (request.getDispatcherType() != DispatcherType.FORWARD && request.getDispatcherType() != DispatcherType.INCLUDE && !response.isCommitted()) {
214+
response.flushBuffer();
215+
}
192216
}
193217

194218
//-------------------------------------------------------------

aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsProxyRequestDispatcher.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package com.amazonaws.serverless.proxy.internal.servlet;
22

33

4+
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
5+
46
import javax.servlet.DispatcherType;
57
import javax.servlet.RequestDispatcher;
68
import javax.servlet.ServletException;
@@ -61,6 +63,7 @@ public void forward(ServletRequest servletRequest, ServletResponse servletRespon
6163
((AwsProxyHttpServletRequest) servletRequest).setDispatcherType(DispatcherType.FORWARD);
6264
((AwsProxyHttpServletRequest) servletRequest).getAwsProxyRequest().setPath(dispatchPath);
6365

66+
assert servletResponse instanceof HttpServletResponse : servletResponse.getClass();
6467
lambdaContainerHandler.forward((HttpServletRequest)servletRequest, (HttpServletResponse)servletResponse);
6568
}
6669

@@ -80,6 +83,7 @@ public void include(ServletRequest servletRequest, ServletResponse servletRespon
8083
((AwsProxyHttpServletRequest) servletRequest).setDispatcherType(DispatcherType.INCLUDE);
8184
((AwsProxyHttpServletRequest) servletRequest).getAwsProxyRequest().setPath(dispatchPath);
8285

86+
assert servletResponse instanceof HttpServletResponse : servletResponse.getClass();
8387
lambdaContainerHandler.include((HttpServletRequest)servletRequest, (HttpServletResponse)servletResponse);
8488
}
8589
}

aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/FilterChainHolder.java

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import org.slf4j.LoggerFactory;
1818

1919
import javax.servlet.*;
20-
import javax.servlet.http.HttpServletRequest;
2120

2221
import java.io.IOException;
2322
import java.util.ArrayList;
@@ -86,11 +85,6 @@ public void doFilter(ServletRequest servletRequest, ServletResponse servletRespo
8685
log.debug("Executed {}: filter {}-{}", servletRequest.getDispatcherType(),
8786
currentFilter, holder.getFilterName());
8887
}
89-
90-
// if for some reason the response wasn't flushed yet, we force it here.
91-
if (!servletResponse.isCommitted()) {
92-
servletResponse.flushBuffer();
93-
}
9488
}
9589

9690

aws-serverless-java-container-core/src/test/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpServletResponseTest.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717

1818

1919
public class AwsHttpServletResponseTest {
20+
// we use this int to compare the cookie expiration time in the tests. The date we generate to compare to
21+
// may be slight off compared to the date generated during the request processing
22+
private static final int COOKIE_GRACE_COMPARE_MILLIS = 2000;
2023
private static final String COOKIE_NAME = "session_id";
2124
private static final String COOKIE_VALUE = "123";
2225
private static final String COOKIE_PATH = "/api";
@@ -121,8 +124,9 @@ public void cookie_addCookie_positiveMaxAgeExpiresDate() {
121124
Calendar expiration = getExpires(cookieHeader);
122125
System.out.println("Cookie date: " + dateFormat.format(expiration.getTime()));
123126
System.out.println("Test date: " + dateFormat.format(testExpiration.getTime()));
124-
// we need to compare strings because the millis time will be off
125-
assertEquals(dateFormat.format(testExpiration.getTime()), dateFormat.format(expiration.getTime()));
127+
128+
long dateDiff = testExpiration.getTimeInMillis() - expiration.getTimeInMillis();
129+
assertTrue(Math.abs(dateDiff) < COOKIE_GRACE_COMPARE_MILLIS);
126130
}
127131

128132
@Test

aws-serverless-java-container-jersey/src/main/java/com/amazonaws/serverless/proxy/jersey/JerseyServletResponseWriter.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ class JerseyServletResponseWriter
6060
* @param resp The current ServletResponse from the container
6161
*/
6262
public JerseyServletResponseWriter(ServletResponse resp, CountDownLatch latch) {
63+
assert resp instanceof HttpServletResponse;
6364
servletResponse = (HttpServletResponse)resp;
6465
jerseyLatch = latch;
6566
}

aws-serverless-java-container-spring/src/test/java/com/amazonaws/serverless/proxy/spring/SpringBootAppTest.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,13 @@
99
import com.amazonaws.serverless.proxy.spring.springbootapp.LambdaHandler;
1010
import com.amazonaws.serverless.proxy.spring.springbootapp.TestController;
1111

12+
import com.fasterxml.jackson.core.JsonProcessingException;
13+
import com.fasterxml.jackson.databind.JsonNode;
1214
import com.fasterxml.jackson.databind.ObjectMapper;
1315
import org.junit.Test;
1416

1517
import java.io.IOException;
18+
import java.util.Map;
1619

1720
import static org.junit.Assert.*;
1821

@@ -30,6 +33,29 @@ public void testMethod_springSecurity_doesNotThrowException() {
3033
validateSingleValueModel(resp, TestController.TEST_VALUE);
3134
}
3235

36+
@Test
37+
public void defaultError_requestForward_springBootForwardsToDefaultErrorPage() {
38+
AwsProxyRequest req = new AwsProxyRequestBuilder("/test2", "GET").build();
39+
AwsProxyResponse resp = handler.handleRequest(req, context);
40+
assertNotNull(resp);
41+
assertEquals(404, resp.getStatusCode());
42+
assertNotNull(resp.getHeaders());
43+
assertTrue(resp.getHeaders().containsKey("Content-Type"));
44+
assertEquals("application/json;charset=UTF-8", resp.getHeaders().get("Content-Type"));
45+
try {
46+
JsonNode errorData = mapper.readTree(resp.getBody());
47+
assertNotNull(errorData.findValue("status"));
48+
assertEquals(404, errorData.findValue("status").asInt());
49+
assertNotNull(errorData.findValue("message"));
50+
assertEquals("No message available", errorData.findValue("message").asText());
51+
52+
} catch (IOException e) {
53+
e.printStackTrace();
54+
fail();
55+
}
56+
57+
}
58+
3359
private void validateSingleValueModel(AwsProxyResponse output, String value) {
3460
try {
3561
SingleValueModel response = mapper.readValue(output.getBody(), SingleValueModel.class);

0 commit comments

Comments
 (0)