lots of refactoring - also a fix for #56

pull/69/head
James Batt 4 years ago
parent 075ed916ce
commit 88370da756

@ -11,6 +11,8 @@ require (
github.com/docker/docker v1.13.1 // indirect
github.com/docker/libnetwork v0.8.0-dev.2.0.20200217033114-6659f7f4d8c1
github.com/golang/protobuf v1.4.2
github.com/google/uuid v1.1.1
github.com/gorilla/handlers v1.4.2
github.com/gorilla/mux v1.7.4
github.com/gorilla/sessions v1.2.0
github.com/gorilla/websocket v1.4.2 // indirect

@ -68,8 +68,12 @@ github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4=
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.0 h1:/QaMHBdZ26BB3SSst0Iwl10Epc+xhTquomWX0oZEB6w=
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY=
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8=
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
github.com/gorilla/handlers v1.4.2 h1:0QniY0USkHQ1RGCLfKxeNHK9bkDHGRYGNDFBCS+YARg=
github.com/gorilla/handlers v1.4.2/go.mod h1:Qkdc/uu4tH4g6mTK6auzZ766c4CA0Ng8+o/OAirnOIQ=
github.com/gorilla/mux v1.7.4 h1:VuZ8uybHlWmqV03+zRzdwKL4tUnIp1MAQtp1mIFE1bc=
github.com/gorilla/mux v1.7.4/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=

@ -3,7 +3,6 @@ package config
import (
"fmt"
"io/ioutil"
"net"
"os"
"path/filepath"
"runtime"
@ -223,19 +222,14 @@ func defaultInterface() (string, error) {
return "", errors.New("could not determine the default network interface name")
}
func linkIPAddr(name string) (net.IP, error) {
link, err := netlink.LinkByName(name)
if err != nil {
return nil, errors.Wrapf(err, "failed to find network interface %s", name)
}
routes, err := netlink.RouteList(link, 4)
if err != nil {
return nil, errors.Wrapf(err, "failed to list routes for interface %s", link.Attrs().Name)
}
for _, route := range routes {
if route.Src != nil {
return route.Src, nil
}
}
return nil, fmt.Errorf("no source IP found for interface %s", link.Attrs().Name)
}
// func randomPassword() string {
// letterRunes := []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
// length := 12
// b := make([]rune, length)
// for i := range b {
// b[i] = letterRunes[rand.Intn(len(letterRunes))]
// }
// return string(b)
// }

@ -11,7 +11,7 @@ import (
func metadataLoop(d *DeviceManager) {
for {
syncMetrics(d)
time.Sleep(5 * time.Second)
time.Sleep(30 * time.Second)
}
}

@ -0,0 +1,55 @@
package services
import (
"context"
"fmt"
"math"
"net/http"
grpc_logrus "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus"
"github.com/improbable-eng/grpc-web/go/grpcweb"
"github.com/place1/wg-access-server/internal/config"
"github.com/place1/wg-access-server/internal/devices"
"github.com/place1/wg-access-server/internal/traces"
"github.com/place1/wg-access-server/proto/proto"
"google.golang.org/grpc"
)
type ApiServices struct {
Config *config.AppConfig
DeviceManager *devices.DeviceManager
}
func ApiRouter(deps *ApiServices) http.Handler {
// Native GRPC server
server := grpc.NewServer([]grpc.ServerOption{
grpc.MaxRecvMsgSize(int(1 * math.Pow(2, 20))), // 1MB
grpc.UnaryInterceptor(func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
return grpc_logrus.UnaryServerInterceptor(traces.Logger(ctx))(ctx, req, info, handler)
}),
}...)
// Register GRPC services
proto.RegisterServerServer(server, &ServerService{
Config: deps.Config,
})
proto.RegisterDevicesServer(server, &DeviceService{
DeviceManager: deps.DeviceManager,
})
// Grpc Web in process proxy (wrapper)
grpcServer := grpcweb.WrapServer(server,
grpcweb.WithAllowNonRootResource(true),
)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if grpcServer.IsGrpcWebRequest(r) {
grpcServer.ServeHTTP(w, r)
return
}
w.WriteHeader(400)
fmt.Fprintln(w, "expected grpc request")
return
})
}

@ -4,13 +4,13 @@ import (
"context"
"time"
"github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus"
"github.com/place1/wg-access-server/pkg/authnz/authsession"
"github.com/golang/protobuf/ptypes/empty"
"github.com/place1/wg-access-server/internal/devices"
"github.com/place1/wg-access-server/internal/storage"
"github.com/place1/wg-access-server/proto/proto"
"github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
@ -27,7 +27,7 @@ func (d *DeviceService) AddDevice(ctx context.Context, req *proto.AddDeviceReq)
device, err := d.DeviceManager.AddDevice(user, req.GetName(), req.GetPublicKey())
if err != nil {
logrus.Error(err)
ctxlogrus.Extract(ctx).Error(err)
return nil, status.Errorf(codes.Internal, "failed to add device")
}
@ -42,7 +42,7 @@ func (d *DeviceService) ListDevices(ctx context.Context, req *proto.ListDevicesR
devices, err := d.DeviceManager.ListDevices(user.Subject)
if err != nil {
logrus.Error(err)
ctxlogrus.Extract(ctx).Error(err)
return nil, status.Errorf(codes.Internal, "failed to retrieve devices")
}
return &proto.ListDevicesRes{
@ -57,7 +57,7 @@ func (d *DeviceService) DeleteDevice(ctx context.Context, req *proto.DeleteDevic
}
if err := d.DeviceManager.DeleteDevice(user.Subject, req.GetName()); err != nil {
logrus.Error(err)
ctxlogrus.Extract(ctx).Error(err)
return nil, status.Errorf(codes.Internal, "failed to delete device")
}
@ -76,7 +76,7 @@ func (d *DeviceService) ListAllDevices(ctx context.Context, req *proto.ListAllDe
devices, err := d.DeviceManager.ListAllDevices()
if err != nil {
logrus.Error(err)
ctxlogrus.Extract(ctx).Error(err)
return nil, status.Errorf(codes.Internal, "failed to retrieve devices")
}

@ -0,0 +1,13 @@
package services
import (
"fmt"
"net/http"
)
func HealthEndpoint() http.Handler {
return http.HandlerFunc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
fmt.Fprintf(w, "ok")
}))
}

@ -0,0 +1,31 @@
package services
import (
"fmt"
"net/http"
"runtime/debug"
"github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus"
"github.com/place1/wg-access-server/internal/traces"
)
func TracesMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(w, r.WithContext(traces.WithTraceID(r.Context())))
})
}
func RecoveryMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if err := recover(); err != nil {
ctxlogrus.Extract(r.Context()).
WithField("stack", string(debug.Stack())).
Error(err)
w.WriteHeader(500)
fmt.Fprintf(w, "server error\ntrace = %s\n", traces.TraceID(r.Context()))
}
}()
next.ServeHTTP(w, r)
})
}

@ -4,13 +4,13 @@ import (
"context"
"strings"
"github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus"
"github.com/place1/wg-access-server/internal/network"
"github.com/place1/wg-access-server/internal/config"
"github.com/place1/wg-access-server/pkg/authnz/authsession"
"github.com/place1/wg-access-server/proto/proto"
"github.com/place1/wg-embed/pkg/wgembed"
"github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
@ -27,7 +27,7 @@ func (s *ServerService) Info(ctx context.Context, req *proto.InfoReq) (*proto.In
publicKey, err := wgembed.PublicKey(s.Config.WireGuard.InterfaceName)
if err != nil {
logrus.Error(err)
ctxlogrus.Extract(ctx).Error(err)
return nil, status.Errorf(codes.Internal, "failed to get public key")
}

@ -0,0 +1,89 @@
package services
import (
"net/http"
"net/http/httputil"
"net/url"
"os"
"path"
"path/filepath"
"strings"
"github.com/gorilla/mux"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
func WebsiteRouter() *mux.Router {
router := mux.NewRouter()
staticFiles, err := filepath.Abs("website/build")
if err != nil {
logrus.Fatal(errors.Wrap(err, "failed to create absolut path to website static files"))
}
if _, err := os.Stat(staticFiles); os.IsNotExist(err) {
// if the static files directory doesn't exist
// then proxy to a local webpack development server
// i.e. we're developing wg-access-server locally
logrus.Info("missing ./website/build - will reverse proxy to website dev server")
u, _ := url.Parse("http://localhost:3000")
router.NotFoundHandler = httputil.NewSingleHostReverseProxy(u)
} else {
// if the static files directory exists then
// handle static file requests.
// the react app handles routing so we also
// add a catch-all route to serve the react index page.
logrus.Info("serving website from ./website/build")
router.PathPrefix("/").Handler(
FileServerWith404(
http.Dir(staticFiles),
func(w http.ResponseWriter, r *http.Request) bool {
http.ServeFile(w, r, filepath.Join(staticFiles, "index.html"))
return false
},
),
)
}
return router
}
// credit: https://gist.github.com/lummie/91cd1c18b2e32fa9f316862221a6fd5c
type FSHandler404 = func(w http.ResponseWriter, r *http.Request) (doDefaultFileServe bool)
// credit: https://gist.github.com/lummie/91cd1c18b2e32fa9f316862221a6fd5c
func FileServerWith404(root http.FileSystem, handler404 FSHandler404) http.Handler {
fs := http.FileServer(root)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
//make sure the url path starts with /
upath := r.URL.Path
if !strings.HasPrefix(upath, "/") {
upath = "/" + upath
r.URL.Path = upath
}
upath = path.Clean(upath)
// attempt to open the file via the http.FileSystem
f, err := root.Open(upath)
if err != nil {
if os.IsNotExist(err) {
// call handler
if handler404 != nil {
doDefault := handler404(w, r)
if !doDefault {
return
}
}
}
}
// close if successfully opened
if err == nil {
f.Close()
}
// default serve
fs.ServeHTTP(w, r)
})
}

@ -0,0 +1,34 @@
package traces
import (
"context"
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
const (
TraceIDKey = "trace.id"
)
func WithTraceID(ctx context.Context) context.Context {
id, err := uuid.NewRandom()
if err != nil {
logrus.Warn(errors.Wrap(err, "failed to generate trace id"))
return ctx
}
return context.WithValue(ctx, TraceIDKey, id.String())
}
func Logger(ctx context.Context) *logrus.Entry {
return logrus.WithField("trace.id", TraceID(ctx))
}
func TraceID(ctx context.Context) string {
if id, ok := ctx.Value(TraceIDKey).(string); ok {
return id
}
return "<no-trace-id>"
}

@ -2,16 +2,12 @@ package main
import (
"fmt"
"math"
"net/http"
"net/url"
"os"
"runtime/debug"
"github.com/place1/wg-access-server/internal/services"
"github.com/place1/wg-access-server/internal/storage"
"github.com/improbable-eng/grpc-web/go/grpcweb"
"github.com/place1/wg-access-server/proto/proto"
"github.com/place1/wg-access-server/pkg/authnz"
"github.com/place1/wg-access-server/pkg/authnz/authsession"
"github.com/gorilla/mux"
"github.com/place1/wg-embed/pkg/wgembed"
@ -21,17 +17,7 @@ import (
"github.com/place1/wg-access-server/internal/devices"
"github.com/place1/wg-access-server/internal/dnsproxy"
"github.com/place1/wg-access-server/internal/network"
"github.com/place1/wg-access-server/internal/services"
"github.com/place1/wg-access-server/pkg/authnz"
"github.com/place1/wg-access-server/pkg/authnz/authsession"
"github.com/sirupsen/logrus"
"net/http/httputil"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpc_logrus "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus"
grpc_recovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery"
"google.golang.org/grpc"
)
func main() {
@ -94,75 +80,50 @@ func main() {
logrus.Fatal(errors.Wrap(err, "failed to sync"))
}
// Router
router := mux.NewRouter()
router.Use(services.TracesMiddleware)
router.Use(services.RecoveryMiddleware)
// if the built website exists, serve that
// otherwise proxy to a local webpack development server
if _, err := os.Stat("website/build"); os.IsNotExist(err) {
u, _ := url.Parse("http://localhost:3000")
router.NotFoundHandler = httputil.NewSingleHostReverseProxy(u)
} else {
router.PathPrefix("/").Handler(http.FileServer(http.Dir("website/build")))
}
// GRPC Server
server := grpc.NewServer([]grpc.ServerOption{
grpc.MaxRecvMsgSize(int(1 * math.Pow(2, 20))), // 1MB
grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
grpc_logrus.UnaryServerInterceptor(logrus.NewEntry(logrus.StandardLogger())),
grpc_recovery.UnaryServerInterceptor(),
)),
}...)
proto.RegisterServerServer(server, &services.ServerService{
Config: conf,
})
proto.RegisterDevicesServer(server, &services.DeviceService{
DeviceManager: deviceManager,
})
grpcServer := grpcweb.WrapServer(server)
var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if err := recover(); err != nil {
logrus.WithField("stack", string(debug.Stack())).Error(err)
}
}()
if grpcServer.IsGrpcWebRequest(r) {
grpcServer.ServeHTTP(w, r)
} else {
if authsession.Authenticated(r.Context()) {
router.ServeHTTP(w, r)
} else {
http.Redirect(w, r, "/signin", http.StatusTemporaryRedirect)
}
}
})
// Health check endpoint
router.PathPrefix("/health").Handler(services.HealthEndpoint())
// Authentication middleware
if conf.Auth.IsEnabled() {
handler = authnz.New(conf.Auth, func(user *authsession.Identity) error {
if user.Subject == conf.AdminSubject {
user.Claims.Add("admin", "true")
}
return nil
}).Wrap(handler)
router.Use(authnz.NewMiddleware(conf.Auth, claimsMiddleware(conf)))
} else {
base := handler
handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
base.ServeHTTP(w, r.WithContext(authsession.SetIdentityCtx(r.Context(), &authsession.AuthSession{
Identity: &authsession.Identity{
Subject: "",
},
})))
logrus.Warn("[DEPRECATION NOTICE] using wg-access-server without an admin user is deprecated and will be removed in an upcoming minior release.")
router.Use(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(w, r.WithContext(authsession.SetIdentityCtx(r.Context(), &authsession.AuthSession{
Identity: &authsession.Identity{
Subject: "",
},
})))
})
})
}
publicRouter := mux.NewRouter()
publicRouter.Handle("/health", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
fmt.Fprintf(w, "ok")
})).Methods("GET")
publicRouter.NotFoundHandler = handler
// Subrouter for our site (web + api)
site := router.PathPrefix("/").Subrouter()
site.Use(authnz.RequireAuthentication)
// Grpc api
site.PathPrefix("/api").Handler(services.ApiRouter(&services.ApiServices{
Config: conf,
DeviceManager: deviceManager,
}))
// Static website
site.PathPrefix("/").Handler(services.WebsiteRouter())
// publicRouter.NotFoundHandler = authMiddleware.Wrap(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// if authsession.Authenticated(r.Context()) {
// router.ServeHTTP(w, r)
// } else {
// http.Redirect(w, r, "/signin", http.StatusTemporaryRedirect)
// }
// }))
publicRouter := router
// Listen
address := fmt.Sprintf("0.0.0.0:%d", conf.Port)
@ -177,3 +138,12 @@ func main() {
logrus.Fatal(errors.Wrap(err, "unable to start http server"))
}
}
func claimsMiddleware(conf *config.AppConfig) authsession.ClaimsMiddleware {
return func(user *authsession.Identity) error {
if user.Subject == conf.AdminSubject {
user.Claims.Add("admin", "true")
}
return nil
}
}

@ -47,6 +47,7 @@ func basicAuthLogin(c *BasicAuthConfig, runtime *authruntime.ProviderRuntime) ht
},
})
runtime.Done(w, r)
return
}
w.Header().Set("WWW-Authenticate", `Basic realm="site"`)

@ -5,8 +5,8 @@ import (
"net/http"
"strconv"
"github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"github.com/place1/wg-access-server/pkg/authnz/authconfig"
"github.com/place1/wg-access-server/pkg/authnz/authruntime"
@ -21,18 +21,15 @@ import (
type AuthMiddleware struct {
config authconfig.AuthConfig
claimsMiddleware authsession.ClaimsMiddleware
router *mux.Router
runtime *authruntime.ProviderRuntime
}
func New(config authconfig.AuthConfig, claimsMiddleware authsession.ClaimsMiddleware) *AuthMiddleware {
return &AuthMiddleware{config, claimsMiddleware}
}
func (m *AuthMiddleware) Wrap(next http.Handler) http.Handler {
runtime := authruntime.NewProviderRuntime(sessions.NewCookieStore([]byte(authutil.RandomString(32))))
router := mux.NewRouter()
providers := m.config.Providers()
store := sessions.NewCookieStore([]byte(authutil.RandomString(32)))
runtime := authruntime.NewProviderRuntime(store)
providers := config.Providers()
for _, p := range providers {
if p.RegisterRoutes != nil {
@ -49,9 +46,9 @@ func (m *AuthMiddleware) Wrap(next http.Handler) http.Handler {
router.HandleFunc("/signin/{index}", func(w http.ResponseWriter, r *http.Request) {
index, err := strconv.Atoi(mux.Vars(r)["index"])
if err != nil || (index < 0 || index >= len(providers)) {
fmt.Fprintf(w, "unknown provider")
if err != nil || index < 0 || len(providers) <= index {
w.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(w, "unknown provider")
return
}
provider := providers[index]
@ -63,11 +60,35 @@ func (m *AuthMiddleware) Wrap(next http.Handler) http.Handler {
runtime.Restart(w, r)
})
router.PathPrefix("/").Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if s, err := runtime.GetSession(r); err == nil {
return &AuthMiddleware{
config,
claimsMiddleware,
router,
runtime,
}
}
func NewMiddleware(config authconfig.AuthConfig, claimsMiddleware authsession.ClaimsMiddleware) mux.MiddlewareFunc {
return New(config, claimsMiddleware).Middleware
}
func (m *AuthMiddleware) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// check if the request is for an auth
// related page i.e. /signin
// to be handled by our own router
if ok := m.router.Match(r, &mux.RouteMatch{}); ok {
m.router.ServeHTTP(w, r)
return
}
// otherwise we apply the standard middleware
// functionality i.e. annotate the request context
// with the request user (identity)
if s, err := m.runtime.GetSession(r); err == nil {
if m.claimsMiddleware != nil {
if err := m.claimsMiddleware(s.Identity); err != nil {
logrus.Error(errors.Wrap(err, "authz middleware failure"))
ctxlogrus.Extract(r.Context()).Error(errors.Wrap(err, "authz middleware failure"))
http.Error(w, "internal server error", http.StatusInternalServerError)
return
}
@ -76,11 +97,15 @@ func (m *AuthMiddleware) Wrap(next http.Handler) http.Handler {
} else {
next.ServeHTTP(w, r)
}
}))
return router
})
}
func indexHandler(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Index")
func RequireAuthentication(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if authsession.Authenticated(r.Context()) {
next.ServeHTTP(w, r)
} else {
http.Redirect(w, r, "/signin", http.StatusTemporaryRedirect)
}
})
}

@ -2,7 +2,7 @@ import { Timestamp } from 'google-protobuf/google/protobuf/timestamp_pb';
import { Devices } from './sdk/devices_pb';
import { Server } from './sdk/server_pb';
const backend = window.location.origin;
const backend = window.location.origin + '/api';
export const grpc = {
server: new Server(backend),

Loading…
Cancel
Save