Hagia
log in
morj / dwierz
overview
files
history
wiki
Viewing at
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE DerivingVia #-}

module Main where

import qualified Avahi
import qualified Data.Binary.Builder as Builder
import qualified Data.ByteString as BS
import qualified Data.ByteString.Char8 as BS8
import qualified Data.HashMap.Strict as HashMap
import qualified Data.Map.Strict as Map
import qualified Data.Ron as Ron
import qualified Dormouse.Uri as Uri
import qualified Dormouse.Uri.Encode as Uri
import qualified Network.HTTP.Client as Client
import qualified Network.Wai as Wai
import qualified Network.Wai.Handler.Warp as Warp

import Control.Concurrent.MVar (MVar, newMVar, withMVar)
import Data.ByteString (ByteString)
import Data.Function ((&))
import Data.HashMap.Strict (HashMap)
import Data.List (stripPrefix)
import Data.String (fromString)
import Data.Text.Encoding (encodeUtf8)
import Dormouse.Uri (Uri (..))
import GHC.Generics (Generic)
import Network.Wai (Application)
import System.Environment (getArgs)
import System.IO.Unsafe (unsafePerformIO)
import Text.Read (readMaybe)

-- | Process-wide lock held while writing a log line to stdout, so lines
-- produced by concurrently running request handlers are written one at a time.
--
-- Safety of 'unsafePerformIO': this creates a single top-level mutable cell.
-- The NOINLINE pragma below is essential. It stops GHC from inlining the
-- definition, which would run 'newMVar' more than once and hand different
-- call sites different locks.
logLock :: MVar ()
logLock = unsafePerformIO (newMVar ())
{-# NOINLINE logLock #-}

-- | Write one message (followed by a newline) to stdout while holding
-- 'logLock', so messages from different threads do not interleave.
logMessage :: String -> IO ()
logMessage msg = withMVar logLock (const (putStrLn msg))

-- | One upstream service that requests can be proxied to.
--
-- The RON decoder is derived via 'Ron.RonWith' with
-- 'Ron.SkipSingleConstructor'. Presumably this means the RON source does not
-- need to spell the @Service@ constructor name; confirm against the Ron docs.
data Service = Service
  { port :: !Int
  -- ^ TCP port on @localhost@ that requests for this service are forwarded to.
  }
  deriving stock (Eq, Show, Generic)
  deriving (Ron.FromRon) via Ron.RonWith '[Ron.EncodeWith Ron.SkipSingleConstructor] Service

-- | All configured services, keyed by the value matched against the incoming
-- request's @Host@ header.
newtype ServicesConfig = ServicesConfig {getServicesConfig :: HashMap ByteString Service}
  deriving stock (Eq, Show)

-- | Decodes the RON value as an ordered 'Map' of services, then re-keys it
-- into a 'HashMap' for lookups.
instance Ron.FromRon ServicesConfig where
  fromRon value = ServicesConfig . HashMap.fromList . Map.toList <$> Ron.fromRon value

-- | Runtime configuration of the proxy, populated from command-line flags.
data Settings = Settings
  { bindAddress :: !String
  -- ^ Address Warp listens on.
  , bindPort :: !Int
  -- ^ TCP port Warp listens on.
  , servicesPath :: !FilePath
  -- ^ Path of the RON file describing the upstream services.
  }
  deriving stock (Eq, Show)

-- | Settings used for every flag that is not given on the command line.
defaultSettings :: Settings
defaultSettings =
  Settings
    { bindAddress = "127.0.0.1"
    , bindPort = 3000
    , servicesPath = "./services.ron"
    }

-- | Fold command-line arguments into 'Settings', starting from the given
-- defaults. Arguments are processed left to right, so a later occurrence of a
-- flag overrides an earlier one.
--
-- Recognised flags, each accepted as either @--flag value@ or @--flag=value@:
--
--   * @--bind@: listen address
--   * @--port@: listen port, which must be an integer in 0..65535
--   * @--services@: path of the services file
--
-- Any other argument is ignored, including a trailing @--bind@/@--port@/
-- @--services@ that has no value after it.
--
-- An unparsable or out-of-range port raises an 'error' with a descriptive
-- message. Previously it gave an opaque @Prelude.read: no parse@, or an
-- out-of-range value silently wrapped when converted to a 16-bit port.
parseCommandLine :: [String] -> Settings -> Settings
parseCommandLine = go
  where
    go [] !x = x
    go ("--bind" : host : rest) !x = go rest $! x{bindAddress = host}
    go ("--port" : portArg : rest) !x = go rest $! x{bindPort = parsePort portArg}
    go ("--services" : services : rest) !x = go rest $! x{servicesPath = services}
    go (arg : rest) !x
      | Just host <- stripPrefix "--bind=" arg = go rest $! x{bindAddress = host}
      | Just portArg <- stripPrefix "--port=" arg = go rest $! x{bindPort = parsePort portArg}
      | Just services <- stripPrefix "--services=" arg = go rest $! x{servicesPath = services}
      | otherwise = go rest x

    -- Total replacement for 'read'. Rejecting values outside the 16-bit range
    -- avoids silent wrap-around when the port is later narrowed.
    parsePort :: String -> Int
    parsePort raw = case readMaybe raw of
      Just p | p >= 0 && p <= 65535 -> p
      _ ->
        error $
          "parseCommandLine: invalid --port value "
            <> show raw
            <> " (expected an integer in 0..65535)"

-- | WAI application that routes each request to one configured service.
--
-- The service is looked up by the request's @Host@ header value. If the
-- header is missing or matches no entry, the request falls back to the first
-- service in 'HashMap' iteration order. That order is unspecified (it is not
-- the order of the services file), so with several services the fallback is
-- effectively arbitrary.
--
-- If the config contains no services at all, the client gets a 404.
-- Previously this case crashed on 'head' of an empty list.
reverseProxyApp :: ServicesConfig -> Client.Manager -> Application
reverseProxyApp (ServicesConfig config) manager incomingRequest respond = do
  logMessage $ "Got request:\n" <> show incomingRequest <> "\n"
  let mbService = (HashMap.!?) config =<< Wai.requestHeaderHost incomingRequest
  case mbService of
    Just Service{port} -> proxyTo port manager incomingRequest respond
    -- 'HashMap.elems' traverses in the same order as 'HashMap.toList', so the
    -- fallback is the same service the previous code picked.
    Nothing -> case HashMap.elems config of
      Service{port} : _ -> proxyTo port manager incomingRequest respond
      [] -> respond $ Wai.responseLBS (toEnum 404) mempty "Service not found"

-- | Forward an incoming WAI request to the HTTP service at
-- @localhost:targetPort@ and stream the upstream response back through
-- @respond@.
--
-- Request side:
--
--   * The body is streamed upstream chunk by chunk. A known length stays a
--     known length, and a chunked body stays chunked.
--   * Headers describing the original body framing/encoding are dropped
--     (see @strippedHeader@), because http-client frames the body itself.
--   * @Host@ is rewritten to @localhost:targetPort@. @Referer@ and @Origin@
--     get the same host/port when they parse as URIs ('changeUriHost').
--   * Redirects are not followed, and no cookie jar or outbound proxy is used.
--
-- Response side:
--
--   * Status and headers are forwarded, except @Transfer-Encoding@.
--     http-client has already de-chunked the body, and Warp chooses its own
--     framing for streamed responses. Forwarding the upstream value would
--     advertise a framing that does not match the bytes actually sent (e.g. a
--     duplicate @chunked@, or @chunked@ on a raw body to an HTTP/1.0 client).
--   * The body is forwarded verbatim, since decompression is disabled, so
--     @Content-Encoding@ and @Content-Length@ stay valid. It is flushed
--     periodically via 'resendFlushingBody'.
proxyTo :: Int -> Client.Manager -> Wai.Request -> (Wai.Response -> IO a) -> IO a
proxyTo targetPort manager incomingRequest respond = do
  let streamsRequestBody
        :: ((Client.Popper -> IO ()) -> IO ()) -> Client.RequestBody
      streamsRequestBody = case Wai.requestBodyLength incomingRequest of
        Wai.KnownLength len -> Client.RequestBodyStream $ fromIntegral len
        Wai.ChunkedBody -> Client.RequestBodyStreamChunked
  -- Hand http-client a popper that pulls the next chunk of the client's body.
  let requestBody =
        streamsRequestBody $ \needsPopper ->
          needsPopper $ Wai.getRequestBodyChunk incomingRequest

  let request =
        Client.defaultRequest
          { Client.method = Wai.requestMethod incomingRequest
          , Client.secure = False
          , Client.host = "localhost"
          , Client.port = targetPort
          , Client.path = Wai.rawPathInfo incomingRequest
          , Client.queryString = Wai.rawQueryString incomingRequest
          , Client.requestHeaders =
              map (fixHost . fixReferer)
                . filter (not . strippedHeader)
                $ Wai.requestHeaders incomingRequest
          , Client.requestBody = requestBody
          , Client.proxy = Nothing
          , Client.decompress = const False
          , Client.redirectCount = 0
          , Client.cookieJar = Nothing
          , Client.requestVersion = Wai.httpVersion incomingRequest
          }
  Client.withResponse request manager $ \resp -> do
    logMessage $ "Upstream response:\n" <> show resp{Client.responseBody = "<lazy body>" :: String} <> "\n"
    let status = Client.responseStatus resp
    let headers = filter (not . upstreamFramingHeader) $ Client.responseHeaders resp
    let nextBodyChunk = Client.responseBody resp
    -- Streaming happens inside 'withResponse', so the upstream connection
    -- stays open until the whole body has been relayed.
    respond $ Wai.responseStream status headers $ \sendChunk flush ->
      let sendChunk' = sendChunk . Builder.fromByteString
       in resendFlushingBody nextBodyChunk sendChunk' flush
  where
    -- Request headers not forwarded upstream. The framing headers are
    -- recomputed by http-client. Presumably accept-encoding is dropped so
    -- upstream replies uncompressed, and "Connection: close" so the pooled
    -- upstream connection is not torn down. Confirm intent.
    strippedHeader (k, v) =
      k `elem` ["accept-encoding", "content-encoding", "content-length", "transfer-encoding"]
        || k == "connection" && v == "close"
    -- Upstream response headers that describe framing http-client has
    -- already undone.
    upstreamFramingHeader (k, _) = k == "transfer-encoding"
    fixHost (k, v)
      | k == "host" = (k, "localhost:" <> BS8.pack (show targetPort))
      | otherwise = (k, v)
    fixReferer (k, v)
      | k == "referer" || k == "origin" = (k, changeUriHost targetPort v)
      | otherwise = (k, v)

-- | Point a URI's authority at @localhost:newPort@ and re-render it with
-- 'renderUri'.
--
-- User info is kept. Input that does not parse as a URI is returned
-- unchanged. A parsed URI without an authority keeps having none.
changeUriHost :: Int -> ByteString -> ByteString
changeUriHost newPort raw =
  case Uri.parseUri raw of
    Right uri -> renderUri uri{Uri.uriAuthority = pointAtLocalhost <$> Uri.uriAuthority uri}
    Left _err -> raw
  where
    pointAtLocalhost authority =
      authority
        { Uri.authorityHost = "localhost"
        , Uri.authorityPort = Just newPort
        }

-- | Serialise a 'Uri' as
-- @scheme://[[userinfo\@]host[:port]]path[?query][#fragment]@.
--
-- Path and query go through Dormouse's encoders. Scheme, user info, host and
-- fragment are emitted as plain UTF-8 without percent-encoding.
-- NOTE(review): confirm that Dormouse stores those components already
-- encoded; otherwise reserved characters in them would leak through.
renderUri :: Uri -> ByteString
renderUri uri =
  mconcat
    [ encodeUtf8 (Uri.unScheme (Uri.uriScheme uri))
    , "://"
    , foldMap renderAuthority (Uri.uriAuthority uri)
    , Uri.encodePath (Uri.uriPath uri)
    , prefixedWith "?" Uri.encodeQuery (Uri.uriQuery uri)
    , prefixedWith "#" (encodeUtf8 . Uri.unFragment) (Uri.uriFragment uri)
    ]
  where
    -- Render an optional component, preceded by its separator when present.
    prefixedWith sep render = maybe "" ((sep <>) . render)
    renderAuthority auth =
      mconcat
        [ maybe "" ((<> "@") . encodeUtf8 . Uri.unUserInfo) (Uri.authorityUserInfo auth)
        , encodeUtf8 (Uri.unHost (Uri.authorityHost auth))
        , prefixedWith ":" (BS8.pack . show) (Uri.authorityPort auth)
        ]

-- | Relay a body from @getNext@ to @sendNext@ until @getNext@ returns an
-- empty chunk, which is treated as end of stream.
--
-- @flush@ is called whenever at least 4096 bytes have been sent since the
-- previous flush, and once more at end of stream.
resendFlushingBody :: IO ByteString -> (ByteString -> IO ()) -> IO () -> IO ()
resendFlushingBody getNext sendNext flush = loop 0
  where
    flushThreshold :: Int
    flushThreshold = 4096

    loop !sinceFlush
      | sinceFlush >= flushThreshold = flush *> loop 0
      | otherwise = do
          chunk <- getNext
          if BS.null chunk
            then flush
            else do
              sendNext chunk
              loop (sinceFlush + BS.length chunk)

-- | Entry point. Parses flags, loads the services file, prints the host name
-- reported by Avahi, then serves the reverse proxy with Warp.
main :: IO ()
main = do
  arguments <- getArgs
  -- Binding with a constructor pattern forces the parsed settings here, so a
  -- bad flag fails before anything is printed.
  appSettings@Settings{bindAddress, bindPort, servicesPath} <-
    pure (parseCommandLine arguments defaultSettings)
  putStrLn ("App settings: " <> show appSettings)

  services <- Ron.decodeFile servicesPath
  putStrLn ("Read services: " <> show services)

  avahiClient <- Avahi.createClient
  hostName <- Avahi.getHostName avahiClient
  putStrLn ("Hostname: " <> show hostName)

  manager <- Client.newManager Client.defaultManagerSettings

  let warpSettings =
        Warp.setHost (fromString bindAddress) $
          Warp.setPort bindPort Warp.defaultSettings
  Warp.runSettings warpSettings (reverseProxyApp services manager)