grim/containers/oauth-kludge

Initial proof of concept

2020-11-22, Gary Kramlich
dcbfc4fd58e8
Initial proof of concept
package main
import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"os"
	"os/signal"
	"strings"
	"syscall"
	"time"

	"github.com/go-http-utils/logger"
	"github.com/kelseyhightower/envconfig"
)
// config holds the runtime settings of the proxy. main fills it in from the
// environment via envconfig.Process with the "OAUTH_KLUDGE" prefix.
type config struct {
	// ListenAddr is the address the HTTP server binds to (default ":8080").
	ListenAddr string `envconfig:"LISTEN_ADDR" default:":8080"`
	// Scope is added to forwarded token requests that carry no "scope" field.
	Scope string `envconfig:"SCOPE" required:"true"`
	// TokenEndpoint is the upstream URL every request is forwarded to.
	TokenEndpoint *url.URL `envconfig:"TOKEN_ENDPOINT" required:"true"`
}

var (
	// cfg is the process-wide configuration, populated once in main.
	cfg config

	// client performs every upstream call; its timeout bounds each request
	// made to the token endpoint.
	client = &http.Client{Timeout: 5 * time.Second}
)
// sendError replies with the given HTTP status and a JSON body of the form
// {"error":"<message>"}.
//
// The body is produced with encoding/json so quotes, backslashes and control
// characters in the message are escaped correctly. Splicing the message into
// a JSON template by hand yields invalid JSON for common errors such as
// *url.Error, whose text contains double quotes (`Post "https://...": ...`).
// A nil err is reported using the standard status text for status.
func sendError(w http.ResponseWriter, err error, status int) {
	msg := http.StatusText(status)
	if err != nil {
		msg = err.Error()
	}

	// Marshalling a struct with a single string field cannot fail (invalid
	// UTF-8 is replaced, not rejected), so the error is deliberately ignored.
	body, _ := json.Marshal(struct {
		Error string `json:"error"`
	}{Error: msg})

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	// The status line is already committed; if the client has gone away there
	// is nothing further we can report, so the write error is ignored.
	_, _ = w.Write(body)
}
// copyHeader copies all values of the header called name from src to dest,
// replacing whatever dest held for it. Nothing happens when src lacks it.
//
// http.Header stores keys in canonical form ("Content-Type"), so name is
// canonicalized before the lookup; a raw map lookup with "content-type" would
// silently miss. A literal-name lookup is kept as a fallback for entries that
// were written into the map directly without canonicalization. The value slice
// is cloned so dest never shares a backing array with src.
func copyHeader(src, dest http.Header, name string) {
	key := http.CanonicalHeaderKey(name)
	values, found := src[key]
	if !found && key != name {
		key = name
		values, found = src[key]
	}
	if !found {
		return
	}
	dest[key] = append([]string(nil), values...)
}
// kludge returns a handler that forwards token requests to target, adding
// scope to the form when the caller did not send a "scope" field at all.
//
// The incoming form (as merged by Request.ParseForm: query string plus a
// url-encoded body) is re-encoded as an application/x-www-form-urlencoded
// body and sent to target using the caller's method. The upstream status
// code, body, and the Content-Type, Cache-Control and Pragma headers are
// relayed back to the caller.
//
// Failures are reported through sendError:
//   - 400 if the incoming form cannot be parsed,
//   - 500 if the upstream request cannot be constructed,
//   - 502 if the upstream endpoint cannot be reached or times out.
func kludge(target *url.URL, scope string) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if err := r.ParseForm(); err != nil {
			sendError(w, err, http.StatusBadRequest)
			return
		}

		// Only inject the configured scope when the caller sent none; an
		// explicitly supplied (even empty) scope is forwarded unchanged.
		if _, found := r.Form["scope"]; !found {
			r.Form.Set("scope", scope)
		}

		// Bind the upstream call to the incoming request's context so it is
		// abandoned when the client disconnects or the server shuts down.
		upstream, err := http.NewRequestWithContext(
			r.Context(),
			r.Method,
			target.String(),
			strings.NewReader(r.Form.Encode()),
		)
		if err != nil {
			sendError(w, err, http.StatusInternalServerError)
			return
		}

		// An explicitly empty User-Agent stops net/http from sending its
		// default one.
		upstream.Header.Set("User-Agent", "")
		// The body is always the url-encoded form built above, regardless of
		// what encoding (if any) the caller declared, so label it as such.
		upstream.Header.Set("Content-Type", "application/x-www-form-urlencoded")

		resp, err := client.Do(upstream)
		if err != nil {
			// The token endpoint failed to answer: a gateway error, not a
			// fault in this proxy.
			sendError(w, err, http.StatusBadGateway)
			return
		}
		defer resp.Body.Close()

		for _, name := range []string{"Content-Type", "Cache-Control", "Pragma"} {
			copyHeader(resp.Header, w.Header(), name)
		}
		w.WriteHeader(resp.StatusCode)

		// Headers are already sent; a failed copy (typically a client
		// disconnect) can no longer be reported, so the error is ignored.
		_, _ = io.Copy(w, resp.Body)
	})
}
func newServer() *http.Server {
handler := logger.Handler(
kludge(cfg.TokenEndpoint, cfg.Scope),
os.Stdout,
logger.CommonLoggerType,
)
return &http.Server{
Addr: cfg.ListenAddr,
Handler: handler,
ReadTimeout: 5 * time.Second,
WriteTimeout: 10 * time.Second,
}
}
// main loads configuration from the environment (prefix "OAUTH_KLUDGE"),
// serves the token proxy, and runs until the listener fails or SIGINT/SIGTERM
// arrives.
//
// On a signal, the server stops accepting connections and gives in-flight
// requests up to 10 seconds to finish. Configuration, listener and shutdown
// errors are written to stderr and make the process exit with status 1, so a
// supervisor can tell a failure from a clean stop.
func main() {
	if err := envconfig.Process("OAUTH_KLUDGE", &cfg); err != nil {
		fmt.Fprintf(os.Stderr, "error: %v\n", err)
		os.Exit(1)
	}

	// signal.Notify never blocks on delivery, so the channel needs a buffer
	// or a signal arriving before the select below could be dropped.
	signalChan := make(chan os.Signal, 1)
	signal.Notify(signalChan, syscall.SIGINT, syscall.SIGTERM)

	// Buffered so the serving goroutine can always hand off its result and
	// exit, even after main has stopped receiving from errChan.
	errChan := make(chan error, 1)
	server := newServer()

	go func() {
		fmt.Printf("Listening on %s\n", cfg.ListenAddr)
		errChan <- server.ListenAndServe()
	}()

	select {
	case err := <-errChan:
		// Shutdown/Close are not called on this path, so any return from
		// ListenAndServe is a real failure (e.g. the address is in use).
		fmt.Fprintf(os.Stderr, "error: %v\n", err)
		os.Exit(1)
	case s := <-signalChan:
		fmt.Printf("caught %v\n", s)
		// Drain in-flight proxied requests, bounded so a hung upstream cannot
		// block termination forever.
		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		err := server.Shutdown(ctx)
		cancel()
		if err != nil {
			fmt.Fprintf(os.Stderr, "error: shutdown: %v\n", err)
			_ = server.Close()
			os.Exit(1)
		}
	}
}