{-# LANGUAGE ScopedTypeVariables #-}
-- | Support for making connections via the OpenSSL library.
module Network.HTTP.Client.OpenSSL
    ( opensslManagerSettings
    -- , defaultMakeContext
    , withOpenSSL
    ) where

import Network.HTTP.Client
import Network.HTTP.Client.Internal
import Control.Exception
import Network.Socket.ByteString (sendAll, recv)
import OpenSSL
import qualified Data.ByteString as S
import qualified Network.Socket as N
import qualified OpenSSL.Session       as SSL

-- | Build 'ManagerSettings' whose HTTPS connections use OpenSSL. This covers
-- both direct connections and connections tunnelled through a proxy with
-- @CONNECT@.
--
-- It is the caller's responsibility to pass in an appropriately configured
-- context (certificate verification, protocol versions, ciphers). Future
-- versions of http-client-openssl will hopefully include a sane, safe default.
--
-- The @IO 'SSL.SSLContext'@ action runs each time the resulting
-- @managerTlsConnection@ or @managerTlsProxyConnection@ action is executed
-- (presumably once per manager; confirm against http-client). Every
-- connection opened by that factory shares the resulting context.
--
-- Name resolution is restricted to IPv4 and only the first address returned
-- is tried. 'withOpenSSL' is re-exported by this module; HsOpenSSL requires
-- OpenSSL operations to run inside it.
opensslManagerSettings :: IO SSL.SSLContext -> ManagerSettings
opensslManagerSettings mkContext = defaultManagerSettings
    { managerTlsConnection = do
        ctx <- mkContext
        return $ \ha host' port' ->
            -- Connect straight to the origin server. If http-client supplied
            -- a pre-resolved address it is used instead of a DNS lookup.
            -- SNI carries the server's name.
            withResolvedSocket ha host' port' $
                startTlsConnection ctx host'

    , managerTlsProxyConnection = do
        ctx <- mkContext
        return $ \connstr checkConn serverName _ha host' port' ->
            -- host'/port' name the proxy. The HostAddress argument is not
            -- used here. NOTE(review): it presumably refers to the target
            -- server rather than the proxy; confirm against http-client
            -- before using it.
            withResolvedSocket Nothing host' port' $ \sock -> do
                -- Talk plaintext to the proxy first: send the CONNECT request
                -- and let http-client check the reply. Closing this plaintext
                -- view is a no-op, so the socket stays open for TLS.
                plain <- makeConnection
                    (recv sock tlsReadChunkSize)
                    (sendAll sock)
                    (return ())
                connectionWrite plain connstr
                checkConn plain
                -- Then run the TLS handshake through the established tunnel,
                -- sending the origin server's name via SNI.
                startTlsConnection ctx serverName sock

    , managerRetryableException = \se ->
        case fromException se of
            -- An abruptly terminated TLS connection is retryable. Everything
            -- else defers to http-client's default policy.
            Just (_ :: SSL.ConnectionAbruptlyTerminated) -> True
            Nothing -> managerRetryableException defaultManagerSettings se

    , managerWrapException = \req ->
        let
          -- Socket I/O and OpenSSL failures are re-thrown as an
          -- 'InternalException' attached to the request. Any other
          -- exception is re-thrown unchanged.
          wrap :: SomeException -> SomeException
          wrap se
            | Just (_ :: IOException)                      <- fromException se = se'
            | Just (_ :: SSL.SomeSSLException)             <- fromException se = se'
            | Just (_ :: SSL.ConnectionAbruptlyTerminated) <- fromException se = se'
            | Just (_ :: SSL.ProtocolError)                <- fromException se = se'
            | otherwise                                                        = se
            where
              se' = toException (HttpExceptionRequest req (InternalException se))
        in
          handle (throwIO . wrap)
    }

-- | Number of bytes requested per read from the socket or the TLS session.
tlsReadChunkSize :: Int
tlsReadChunkSize = 32752

-- | Open an IPv4 TCP socket connected to @host'@:@port'@ and pass it to
-- @inner@.
--
-- If @inner@ throws, the socket is closed. On success the socket is left
-- open for whatever @inner@ returns; for example, a 'Connection' whose close
-- action closes it.
--
-- When a pre-resolved address is supplied, it is used directly. Otherwise
-- @host'@ is resolved with 'N.getAddrInfo' (IPv4 only, numeric service) and
-- the first result is used. Adapted from openssl-streams.
withResolvedSocket
    :: Maybe N.HostAddress   -- ^ optional pre-resolved IPv4 address
    -> String                -- ^ host name to resolve
    -> Int                   -- ^ TCP port
    -> (N.Socket -> IO a)    -- ^ action run with the connected socket
    -> IO a
withResolvedSocket mha host' port' inner = do
    let hints = N.defaultHints
            { N.addrFlags      = [N.AI_ADDRCONFIG, N.AI_NUMERICSERV]
            , N.addrFamily     = N.AF_INET
            , N.addrSocketType = N.Stream
            }
    addrInfo <- case mha of
        Just ha ->
            return (hints { N.addrAddress = N.SockAddrInet (fromIntegral port') ha })
        Nothing -> do
            infos <- N.getAddrInfo (Just hints) (Just host') (Just (show port'))
            case infos of
                info : _ -> return info
                []       -> ioError $ userError $
                    "opensslManagerSettings: no address found for " ++ host'
    bracketOnError
        (N.socket (N.addrFamily addrInfo) (N.addrSocketType addrInfo) (N.addrProtocol addrInfo))
        N.close
        (\sock -> N.connect sock (N.addrAddress addrInfo) >> inner sock)

-- | Run a client-side TLS handshake over an already-connected socket, sending
-- @serverName@ via SNI, and expose the session as a 'Connection'. Closing the
-- 'Connection' closes the socket.
startTlsConnection :: SSL.SSLContext -> String -> N.Socket -> IO Connection
startTlsConnection ctx serverName sock = do
    ssl <- SSL.connection ctx sock
    SSL.setTlsextHostName ssl serverName
    SSL.connect ssl
    makeConnection
        -- An abrupt close by the peer is turned into an empty read rather
        -- than an exception (presumably treated as end of input by
        -- http-client).
        (SSL.read ssl tlsReadChunkSize
            `catch` \(_ :: SSL.ConnectionAbruptlyTerminated) -> return S.empty)
        (SSL.write ssl)
        (N.close sock)