diff --git a/src/Snap/Http/Server.hs b/src/Snap/Http/Server.hs index 28b9e32e..ab36bd0b 100644 --- a/src/Snap/Http/Server.hs +++ b/src/Snap/Http/Server.hs @@ -83,6 +83,7 @@ simpleHttpServe config handler = do ($ dat) (getStartupHook conf)) (runSnap handler) + (fromJust $ getMaxPOSTBodySize conf) -------------------------------------------------------------------------- mkStartupInfo sockets conf = diff --git a/src/Snap/Http/Server/Config.hs b/src/Snap/Http/Server/Config.hs index ef9865ec..9dbf3e11 100644 --- a/src/Snap/Http/Server/Config.hs +++ b/src/Snap/Http/Server/Config.hs @@ -35,6 +35,7 @@ module Snap.Http.Server.Config , getSSLPort , getVerbose , getStartupHook + , getMaxPOSTBodySize , setAccessLog , setBind @@ -53,6 +54,7 @@ module Snap.Http.Server.Config , setSSLPort , setVerbose , setStartupHook + , setMaxPOSTBodySize , StartupInfo , getStartupSockets , getStartupConfig diff --git a/src/Snap/Internal/Http/Server.hs b/src/Snap/Internal/Http/Server.hs index 8e28ff39..ea40f596 100644 --- a/src/Snap/Internal/Http/Server.hs +++ b/src/Snap/Internal/Http/Server.hs @@ -128,6 +128,7 @@ data ServerState = ServerState , _sessionPort :: SessionInfo , _logAccess :: Request -> Response -> IO () , _logError :: ByteString -> IO () + , _maxPOSTBodySize :: Int64 } @@ -137,10 +138,11 @@ runServerMonad :: ByteString -- ^ local host name -> (Request -> Response -> IO ()) -- ^ access log function -> (ByteString -> IO ()) -- ^ error log function -> ServerMonad a -- ^ monadic action to run + -> Int64 -- ^ maximum POST body size -> Iteratee ByteString IO a -runServerMonad lh s la le m = evalStateT m st +runServerMonad lh s la le m mpbs = evalStateT m st where - st = ServerState False lh s la le + st = ServerState False lh s la le mpbs ------------------------------------------------------------------------------ @@ -155,8 +157,9 @@ httpServe :: Int -- ^ default timeout -> Maybe (ByteString -> IO ()) -- ^ error log action -> ([Socket] -> IO ()) -- ^ initialisation -> ServerHandler -- ^ handler procedure + -> Int64 -- ^ maximum post body size -> IO () -httpServe defaultTimeout ports localHostname alog' elog' initial handler = +httpServe defaultTimeout ports localHostname alog' elog' initial handler mpbs = withSocketsDo $ spawnAll alog' elog' `catches` errorHandlers where @@ -207,7 +210,7 @@ httpServe defaultTimeout ports localHostname alog' elog' initial handler = let socks = map (\x -> case x of ListenHttp s -> s; ListenHttps s _ -> s) nports (simpleEventLoop defaultTimeout nports numCapabilities (logE elog) (initial socks) - $ runHTTP defaultTimeout alog elog handler localHostname) + $ runHTTP defaultTimeout alog elog handler localHostname mpbs) `finally` do logE elog "Server.httpServe: SHUTDOWN" @@ -273,6 +276,7 @@ runHTTP :: Int -- ^ default timeout -> Maybe (ByteString -> IO ()) -- ^ error logger -> ServerHandler -- ^ handler procedure -> ByteString -- ^ local host name + -> Int64 -- ^ maximum POST body size -> SessionInfo -- ^ session port information -> Enumerator ByteString IO () -- ^ read end of socket -> Iteratee ByteString IO () -- ^ write end of socket @@ -280,7 +284,7 @@ runHTTP :: Int -- ^ default timeout -- ^ sendfile end -> ((Int -> Int) -> IO ()) -- ^ timeout tickler -> IO () -runHTTP defaultTimeout alog elog handler lh sinfo readEnd writeEnd onSendFile +runHTTP defaultTimeout alog elog handler lh mpbs sinfo readEnd writeEnd onSendFile tickle = go `catches` [ Handler $ \(_ :: TerminatedBeforeHandlerException) -> do return () @@ -304,9 +308,11 @@ runHTTP defaultTimeout alog elog handler lh sinfo readEnd writeEnd onSendFile go = do buf <- allocBuffer 16384 - let iter1 = runServerMonad lh sinfo (logA alog) (logE elog) $ - httpSession defaultTimeout writeEnd buf - onSendFile tickle handler + let iter1 = runServerMonad lh sinfo (logA alog) (logE elog) + (httpSession defaultTimeout writeEnd buf + onSendFile tickle handler) + mpbs + let iter = iterateeDebugWrapper "httpSession iteratee" iter1 debug "runHTTP/go: prepping iteratee for start" @@ -615,19 +621,17 @@ receiveRequest writeEnd = do mbCT' = liftM trimIt mbCT doIt = mbCT' == Just "application/x-www-form-urlencoded" - maximumPOSTBodySize :: Int64 - maximumPOSTBodySize = 10*1024*1024 - getIt :: ServerMonad Request getIt = {-# SCC "receiveRequest/parseForm/getIt" #-} do debug "parseForm: got application/x-www-form-urlencoded" debug "parseForm: reading POST body" senum <- liftIO $ readIORef $ rqBody req + mpbs <- gets _maxPOSTBodySize let (SomeEnumerator enum) = senum consumeStep <- liftIO $ runIteratee consume step <- liftIO $ runIteratee $ - joinI $ takeNoMoreThan maximumPOSTBodySize consumeStep + joinI $ takeNoMoreThan mpbs consumeStep body <- liftM S.concat $ lift $ enum step let newParams = parseUrlEncoded body diff --git a/src/Snap/Internal/Http/Server/Config.hs b/src/Snap/Internal/Http/Server/Config.hs index 27c5701e..8f9d8238 100644 --- a/src/Snap/Internal/Http/Server/Config.hs +++ b/src/Snap/Internal/Http/Server/Config.hs @@ -20,6 +20,7 @@ import Control.Monad import qualified Data.ByteString.Char8 as B import Data.ByteString (ByteString) import Data.Char +import Data.Int import Data.Function import Data.List import Data.Maybe @@ -78,24 +79,25 @@ instance Show ConfigLog where -- Any fields which are unspecified in the 'Config' passed to 'httpServe' (and -- this is the norm) are filled in with default values from 'defaultConfig'. data Config m a = Config - { hostname :: Maybe ByteString - , accessLog :: Maybe ConfigLog - , errorLog :: Maybe ConfigLog - , locale :: Maybe String - , port :: Maybe Int - , bind :: Maybe ByteString - , sslport :: Maybe Int - , sslbind :: Maybe ByteString - , sslcert :: Maybe FilePath - , sslkey :: Maybe FilePath - , compression :: Maybe Bool - , verbose :: Maybe Bool - , errorHandler :: Maybe (SomeException -> m ()) - , defaultTimeout :: Maybe Int - , other :: Maybe a - , backend :: Maybe ConfigBackend - , proxyType :: Maybe ProxyType - , startupHook :: Maybe (StartupInfo m a -> IO ()) + { hostname :: Maybe ByteString + , accessLog :: Maybe ConfigLog + , errorLog :: Maybe ConfigLog + , locale :: Maybe String + , port :: Maybe Int + , bind :: Maybe ByteString + , sslport :: Maybe Int + , sslbind :: Maybe ByteString + , sslcert :: Maybe FilePath + , sslkey :: Maybe FilePath + , compression :: Maybe Bool + , verbose :: Maybe Bool + , errorHandler :: Maybe (SomeException -> m ()) + , defaultTimeout :: Maybe Int + , other :: Maybe a + , backend :: Maybe ConfigBackend + , proxyType :: Maybe ProxyType + , startupHook :: Maybe (StartupInfo m a -> IO ()) + , maxPOSTBodySize :: Maybe Int64 } #if MIN_VERSION_base(4,7,0) deriving (Typeable) @@ -148,45 +150,47 @@ emptyConfig = mempty ------------------------------------------------------------------------------ instance Monoid (Config m a) where mempty = Config - { hostname = Nothing - , accessLog = Nothing - , errorLog = Nothing - , locale = Nothing - , port = Nothing - , bind = Nothing - , sslport = Nothing - , sslbind = Nothing - , sslcert = Nothing - , sslkey = Nothing - , compression = Nothing - , verbose = Nothing - , errorHandler = Nothing - , defaultTimeout = Nothing - , other = Nothing - , backend = Nothing - , proxyType = Nothing - , startupHook = Nothing + { hostname = Nothing + , accessLog = Nothing + , errorLog = Nothing + , locale = Nothing + , port = Nothing + , bind = Nothing + , sslport = Nothing + , sslbind = Nothing + , sslcert = Nothing + , sslkey = Nothing + , compression = Nothing + , verbose = Nothing + , errorHandler = Nothing + , defaultTimeout = Nothing + , other = Nothing + , backend = Nothing + , proxyType = Nothing + , startupHook = Nothing + , maxPOSTBodySize = Nothing } a `mappend` b = Config - { hostname = ov hostname - , accessLog = ov accessLog - , errorLog = ov errorLog - , locale = ov locale - , port = ov port - , bind = ov bind - , sslport = ov sslport - , sslbind = ov sslbind - , sslcert = ov sslcert - , sslkey = ov sslkey - , compression = ov compression - , verbose = ov verbose - , errorHandler = ov errorHandler - , defaultTimeout = ov defaultTimeout - , other = ov other - , backend = ov backend - , proxyType = ov proxyType - , startupHook = ov startupHook + { hostname = ov hostname + , accessLog = ov accessLog + , errorLog = ov errorLog + , locale = ov locale + , port = ov port + , bind = ov bind + , sslport = ov sslport + , sslbind = ov sslbind + , sslcert = ov sslcert + , sslkey = ov sslkey + , compression = ov compression + , verbose = ov verbose + , errorHandler = ov errorHandler + , defaultTimeout = ov defaultTimeout + , other = ov other + , backend = ov backend + , proxyType = ov proxyType + , startupHook = ov startupHook + , maxPOSTBodySize = ov maxPOSTBodySize } where ov f = getLast $! (mappend `on` (Last . f)) a b @@ -212,18 +216,19 @@ instance (Typeable1 m) => Typeable1 (Config m) where -- | These are the default values for the options defaultConfig :: MonadSnap m => Config m a defaultConfig = mempty - { hostname = Just "localhost" - , accessLog = Just $ ConfigFileLog "log/access.log" - , errorLog = Just $ ConfigFileLog "log/error.log" - , locale = Just "en_US" - , compression = Just True - , verbose = Just True - , errorHandler = Just defaultErrorHandler - , bind = Just "0.0.0.0" - , sslbind = Just "0.0.0.0" - , sslcert = Just "cert.pem" - , sslkey = Just "key.pem" - , defaultTimeout = Just 60 + { hostname = Just "localhost" + , accessLog = Just $ ConfigFileLog "log/access.log" + , errorLog = Just $ ConfigFileLog "log/error.log" + , locale = Just "en_US" + , compression = Just True + , verbose = Just True + , errorHandler = Just defaultErrorHandler + , bind = Just "0.0.0.0" + , sslbind = Just "0.0.0.0" + , sslcert = Just "cert.pem" + , sslkey = Just "key.pem" + , defaultTimeout = Just 60 + , maxPOSTBodySize = Just (10*1024*1024) } @@ -304,6 +309,8 @@ getProxyType = proxyType getStartupHook :: Config m a -> Maybe (StartupInfo m a -> IO ()) getStartupHook = startupHook +getMaxPOSTBodySize :: Config m a -> Maybe Int64 +getMaxPOSTBodySize = maxPOSTBodySize ------------------------------------------------------------------------------ setHostname :: ByteString -> Config m a -> Config m a @@ -360,6 +367,8 @@ setProxyType x c = c { proxyType = Just x } setStartupHook :: (StartupInfo m a -> IO ()) -> Config m a -> Config m a setStartupHook x c = c { startupHook = Just x } +setMaxPOSTBodySize :: Int64 -> Config m a -> Config m a +setMaxPOSTBodySize x c = c { maxPOSTBodySize = Just x } ------------------------------------------------------------------------------